Skip to content
10 changes: 10 additions & 0 deletions src/EFCore.Relational/Update/Internal/SharedTableEntryMap.cs
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,16 @@ private IUpdateEntry GetMainEntry(IUpdateEntry entry)
{
return GetMainEntry(principalEntry);
}

// When TPH entity replacement causes FindPrincipal to return null, check if the replacing
// entry's SharedIdentityEntry has a compatible principal and use it as the main entry.
var keyValues = foreignKey.Properties.Select(p => entry.GetCurrentValue(p)).ToArray();
var replacingPrincipal = _updateAdapter.TryGetEntry(foreignKey.PrincipalKey, keyValues);
if (replacingPrincipal?.SharedIdentityEntry != null
&& foreignKey.PrincipalEntityType.IsAssignableFrom(replacingPrincipal.SharedIdentityEntry.EntityType))
{
return GetMainEntry(replacingPrincipal);
}
}

return entry;
Expand Down
117 changes: 117 additions & 0 deletions test/EFCore.Relational.Tests/Update/CommandBatchPreparerTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -1206,6 +1206,123 @@ public void BatchCommands_handles_null_values_when_sensitive_logging_enabled()
Assert.DoesNotContain("Object reference not set", exception.Message);
}

[ConditionalFact]
public void BatchCommands_creates_valid_batch_for_replaced_entity_with_TPH_and_owned_type_and_concurrency_token()
{
var modelBuilder = FakeRelationalTestHelpers.Instance.CreateConventionBuilder();

modelBuilder.Entity<EntityBase37588>(b =>
{
b.HasDiscriminator<string>("Type")
.HasValue<EntityA37588>(nameof(EntityA37588))
.HasValue<EntityB37588>(nameof(EntityB37588));

b.HasKey(x => x.Id);
b.Property(x => x.Id).HasMaxLength(10);

b.Property(x => x.RowVersion)
.IsRequired()
.IsRowVersion()
.IsConcurrencyToken()
.HasConversion<byte[]>();
});

modelBuilder.Entity<EntityA37588>(b =>
{
b.Property(x => x.SomeValue);
b.OwnsOne(x => x.Owned, x =>
{
x.Property(p => p.CreationDate);
});
});

modelBuilder.Entity<EntityB37588>(b =>
{
b.Property(x => x.Name).HasMaxLength(100);
});

var model = modelBuilder.Model.FinalizeModel();
var currentDbContext = CreateContextServices(model).GetRequiredService<ICurrentDbContext>();
var stateManager = currentDbContext.GetDependencies().StateManager;

// Create "existing" EntityA with an owned entity
var entityA = new EntityA37588 { Id = "SOMEID", SomeValue = true, Owned = new OwnedEntity37588 { CreationDate = DateTime.UtcNow } };
var entityAEntry = stateManager.GetOrCreateEntry(entityA);
entityAEntry.SetEntityState(EntityState.Unchanged);

// Track the owned entity
var ownedEntityType = model.FindEntityType(typeof(OwnedEntity37588), "Owned", model.FindEntityType(typeof(EntityA37588)));
var ownedEntry = stateManager.GetOrCreateEntry(entityA.Owned, ownedEntityType);
ownedEntry.SetEntityState(EntityState.Unchanged);

// Delete EntityA (owned will cascade)
entityAEntry.SetEntityState(EntityState.Deleted);
ownedEntry.SetEntityState(EntityState.Deleted);

// Add EntityB with the same PK
var entityB = new EntityB37588 { Id = "SOMEID", Name = "Any" };
var entityBEntry = stateManager.GetOrCreateEntry(entityB);
entityBEntry.SetEntityState(EntityState.Added);

// Verify SharedIdentityEntry is set bidirectionally
Assert.NotNull(entityBEntry.SharedIdentityEntry);
Assert.Same(entityAEntry, entityBEntry.SharedIdentityEntry);
Assert.NotNull(entityAEntry.SharedIdentityEntry);
Assert.Same(entityBEntry, entityAEntry.SharedIdentityEntry);

var modelData = new UpdateAdapter(stateManager);

var commandBatches = CreateBatches(
stateManager.GetEntriesToSave(cascadeChanges: true).ToArray(), modelData);

// Should create valid batch(es) without errors
Assert.NotEmpty(commandBatches);

// Find the command for the replaced entity
var allCommands = commandBatches.SelectMany(b => b.ModificationCommands).ToList();

// The owned entity entry should be part of the same Modified command, not a separate Deleted command
var deletedCommands = allCommands.Where(c => c.EntityState == EntityState.Deleted).ToList();
Assert.Empty(deletedCommands);

Assert.Single(allCommands);
var modifiedCommand = Assert.Single(allCommands);
Assert.Equal(EntityState.Modified, modifiedCommand.EntityState);

// The modified command should contain both EntityB and OwnedEntity entries
Assert.True(modifiedCommand.Entries.Count() >= 2,
$"Expected at least 2 entries in Modified command, but got {modifiedCommand.Entries.Count()}. " +
$"Total commands: {allCommands.Count}, states: [{string.Join(", ", allCommands.Select(c => c.EntityState))}]");

// RowVersion should be a condition (used in WHERE clause)
var rvModification = modifiedCommand.ColumnModifications
.FirstOrDefault(cm => cm.ColumnName == "RowVersion");
Assert.NotNull(rvModification);
Assert.True(rvModification.IsCondition);
}

private abstract class EntityBase37588
{
public string Id { get; set; }
public long RowVersion { get; set; }
}

private class OwnedEntity37588
{
public DateTime CreationDate { get; set; }
}

private class EntityA37588 : EntityBase37588
{
public bool SomeValue { get; set; }
public OwnedEntity37588 Owned { get; set; }
}

private class EntityB37588 : EntityBase37588
{
public string Name { get; set; }
}

private class AnotherFakeEntity
{
public int Id { get; set; }
Expand Down
Loading