Skip to content
15 changes: 15 additions & 0 deletions src/EFCore.Relational/Update/Internal/SharedTableEntryMap.cs
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,21 @@ private IUpdateEntry GetMainEntry(IUpdateEntry entry)
{
return GetMainEntry(principalEntry);
}

// When TPH entity replacement causes FindPrincipal to return null (the replacing entry
// is an incompatible sibling type), check if the replacing entry's SharedIdentityEntry
// is a compatible principal, and use the replacing entry as the main entry.
if (principalEntry == null)
{
var keyValues = foreignKey.Properties.Select(p => entry.GetCurrentValue(p)).ToArray();

var identityEntry = _updateAdapter.TryGetEntry(foreignKey.PrincipalKey, keyValues);
if (identityEntry?.SharedIdentityEntry != null
&& foreignKey.PrincipalEntityType.IsAssignableFrom(identityEntry.SharedIdentityEntry.EntityType))
{
return GetMainEntry(identityEntry);
}
}
}

return entry;
Expand Down
115 changes: 115 additions & 0 deletions test/EFCore.Relational.Tests/Update/CommandBatchPreparerTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -1206,6 +1206,121 @@ public void BatchCommands_handles_null_values_when_sensitive_logging_enabled()
Assert.DoesNotContain("Object reference not set", exception.Message);
}

[ConditionalFact]
public void BatchCommands_creates_valid_batch_for_replaced_entity_with_TPH_and_owned_type_and_concurrency_token()
{
var modelBuilder = FakeRelationalTestHelpers.Instance.CreateConventionBuilder();

modelBuilder.Entity<EntityBase37169>(b =>
{
b.HasDiscriminator<string>("Type")
.HasValue<EntityA37169>(nameof(EntityA37169))
.HasValue<EntityB37169>(nameof(EntityB37169));

b.HasKey(x => x.Id);
b.Property(x => x.Id).HasMaxLength(10);

b.Property(x => x.RowVersion)
.IsRequired()
.IsRowVersion()
.IsConcurrencyToken()
.HasConversion<byte[]>();
});

modelBuilder.Entity<EntityA37169>(b =>
{
b.Property(x => x.SomeValue);
b.OwnsOne(x => x.Owned, x =>
{
x.Property(p => p.CreationDate);
});
});

modelBuilder.Entity<EntityB37169>(b =>
{
b.Property(x => x.Name).HasMaxLength(100);
});

var model = modelBuilder.Model.FinalizeModel();
var currentDbContext = CreateContextServices(model).GetRequiredService<ICurrentDbContext>();
var stateManager = currentDbContext.GetDependencies().StateManager;

// Create "existing" EntityA with an owned entity
var entityA = new EntityA37169 { Id = "SOMEID", SomeValue = true, Owned = new OwnedEntity37169 { CreationDate = DateTime.UtcNow } };
var entityAEntry = stateManager.GetOrCreateEntry(entityA);
entityAEntry.SetEntityState(EntityState.Unchanged);

// Track the owned entity
var ownedEntityType = model.FindEntityType(typeof(OwnedEntity37169), "Owned", model.FindEntityType(typeof(EntityA37169)));
var ownedEntry = stateManager.GetOrCreateEntry(entityA.Owned, ownedEntityType);
ownedEntry.SetEntityState(EntityState.Unchanged);

// Delete EntityA (owned will cascade)
entityAEntry.SetEntityState(EntityState.Deleted);
ownedEntry.SetEntityState(EntityState.Deleted);

// Add EntityB with the same PK
var entityB = new EntityB37169 { Id = "SOMEID", Name = "Any" };
var entityBEntry = stateManager.GetOrCreateEntry(entityB);
entityBEntry.SetEntityState(EntityState.Added);

// Verify SharedIdentityEntry is set bidirectionally
Assert.NotNull(entityBEntry.SharedIdentityEntry);
Assert.Same(entityAEntry, entityBEntry.SharedIdentityEntry);
Assert.NotNull(entityAEntry.SharedIdentityEntry);
Assert.Same(entityBEntry, entityAEntry.SharedIdentityEntry);

var modelData = new UpdateAdapter(stateManager);

var commandBatches = CreateBatches(
stateManager.GetEntriesToSave(cascadeChanges: true).ToArray(), modelData);

// Should create valid batch(es) without errors
Assert.NotEmpty(commandBatches);

// Find the command for the replaced entity
var allCommands = commandBatches.SelectMany(b => b.ModificationCommands).ToList();

// The owned entity entry should be part of the same Modified command, not a separate Deleted command
var deletedCommands = allCommands.Where(c => c.EntityState == EntityState.Deleted).ToList();
Assert.Empty(deletedCommands);

var modifiedCommand = Assert.Single(allCommands, c => c.EntityState == EntityState.Modified);

// The modified command should contain both EntityB and OwnedEntity entries
Assert.True(modifiedCommand.Entries.Count() >= 2,
$"Expected at least 2 entries in Modified command, but got {modifiedCommand.Entries.Count()}. " +
$"Total commands: {allCommands.Count}, states: [{string.Join(", ", allCommands.Select(c => c.EntityState))}]");

// RowVersion should be a condition (used in WHERE clause)
var rvModification = modifiedCommand.ColumnModifications
.FirstOrDefault(cm => cm.ColumnName == "RowVersion");
Assert.NotNull(rvModification);
Assert.True(rvModification.IsCondition);
}

private abstract class EntityBase37169
{
public string Id { get; set; }
public long RowVersion { get; set; }
}

private class OwnedEntity37169
{
public DateTime CreationDate { get; set; }
}

private class EntityA37169 : EntityBase37169
{
public bool SomeValue { get; set; }
public OwnedEntity37169 Owned { get; set; }
}

private class EntityB37169 : EntityBase37169
{
public string Name { get; set; }
}

private class AnotherFakeEntity
{
public int Id { get; set; }
Expand Down
Loading