Skip to content

Commit 81c9149

Browse files
authored
Merge pull request #21 from koenbeuk/trigger-priority-logic
Fixed broken recursion
2 parents 88e7f17 + b687273 commit 81c9149

File tree

5 files changed

+62
-29
lines changed

5 files changed

+62
-29
lines changed

src/EntityFrameworkCore.Triggered/Internal/ITriggerContextDiscoveryStrategy.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,6 @@ namespace EntityFrameworkCore.Triggered.Internal
1010
{
1111
public interface ITriggerContextDiscoveryStrategy
1212
{
13-
IEnumerable<ITriggerContextDescriptor> Discover(TriggerOptions options, TriggerContextTracker tracker, ILogger logger);
13+
IEnumerable<IEnumerable<ITriggerContextDescriptor>> Discover(TriggerOptions options, TriggerContextTracker tracker, ILogger logger);
1414
}
1515
}

src/EntityFrameworkCore.Triggered/Internal/NonRecursiveTriggerContextDiscoveryStrategy.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,13 +22,13 @@ public NonRecursiveTriggerContextDiscoveryStrategy(string name)
2222
_name = name ?? throw new ArgumentNullException(nameof(name));
2323
}
2424

25-
public IEnumerable<ITriggerContextDescriptor> Discover(TriggerOptions options, TriggerContextTracker tracker, ILogger logger)
25+
public IEnumerable<IEnumerable<ITriggerContextDescriptor>> Discover(TriggerOptions options, TriggerContextTracker tracker, ILogger logger)
2626
{
2727
var changes = tracker.DiscoveredChanges ?? throw new InvalidOperationException("Trigger discovery process has not yet started. Please ensure that TriggerSession.DiscoverChanges() or TriggerSession.RaiseBeforeSaveTriggers() has been called");
2828

2929
_changesDetected(logger, changes.Count(), _name, null);
3030

31-
return changes;
31+
return Enumerable.Repeat(changes, 1);
3232
}
3333
}
3434
}

src/EntityFrameworkCore.Triggered/Internal/RecursiveTriggerContextDiscoveryStrategy.cs

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ public RecursiveTriggerContextDiscoveryStrategy(string name, bool skipDetectedCh
3131
_skipDetectedChanges = skipDetectedChanges;
3232
}
3333

34-
public IEnumerable<ITriggerContextDescriptor> Discover(TriggerOptions options, TriggerContextTracker tracker, ILogger logger)
34+
public IEnumerable<IEnumerable<ITriggerContextDescriptor>> Discover(TriggerOptions options, TriggerContextTracker tracker, ILogger logger)
3535
{
3636
var maxRecursion = options.MaxRecursion;
3737
_discoveryStarted(logger, _name, maxRecursion, null);
@@ -56,10 +56,7 @@ public IEnumerable<ITriggerContextDescriptor> Discover(TriggerOptions options, T
5656
{
5757
_changesDetected(logger, changes.Count(), _name, iteration, maxRecursion, null);
5858

59-
foreach (var change in changes)
60-
{
61-
yield return change;
62-
}
59+
yield return changes;
6360
}
6461
else
6562
{

src/EntityFrameworkCore.Triggered/TriggerSession.cs

Lines changed: 23 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -45,32 +45,34 @@ public async Task RaiseTriggers(Type openTriggerType, ITriggerContextDiscoverySt
4545

4646
cancellationToken.ThrowIfCancellationRequested();
4747

48-
var triggerContextDescriptors = triggerContextDiscoveryStrategy.Discover(_options, _tracker, _logger);
49-
IEnumerable<(ITriggerContextDescriptor triggerContextDescriptor, TriggerDescriptor triggerDescriptor)> triggerInvocations = triggerContextDescriptors
50-
.SelectMany(triggerContextDescriptor =>
51-
_triggerDiscoveryService
52-
.DiscoverTriggers(openTriggerType, triggerContextDescriptor.EntityType, triggerTypeDescriptorFactory)
53-
.Select(triggerDescriptor => (triggerContextDescriptor, triggerDescriptor))
54-
)
55-
.OrderBy(x => x.triggerDescriptor.Priority);
56-
57-
if (_logger.IsEnabled(LogLevel.Debug))
48+
var triggerContextDescriptorBatches = triggerContextDiscoveryStrategy.Discover(_options, _tracker, _logger);
49+
foreach (var triggerContextDescriptorBatch in triggerContextDescriptorBatches)
5850
{
51+
IEnumerable<(ITriggerContextDescriptor triggerContextDescriptor, TriggerDescriptor triggerDescriptor)> triggerInvocations = triggerContextDescriptorBatch
52+
.SelectMany(triggerContextDescriptor =>
53+
_triggerDiscoveryService
54+
.DiscoverTriggers(openTriggerType, triggerContextDescriptor.EntityType, triggerTypeDescriptorFactory)
55+
.Select(triggerDescriptor => (triggerContextDescriptor, triggerDescriptor))
56+
)
57+
.OrderBy(x => x.triggerDescriptor.Priority);
58+
59+
if (_logger.IsEnabled(LogLevel.Debug))
60+
{
61+
triggerInvocations = triggerInvocations.ToList();
62+
_logger.LogDebug("Discovered {triggers} triggers of type {openTriggerType}", triggerInvocations.Count(), openTriggerType);
63+
}
5964

60-
triggerInvocations = triggerInvocations.ToList();
61-
_logger.LogDebug("Discovered {triggers} triggers of type {openTriggerType}", triggerInvocations.Count(), openTriggerType);
62-
}
65+
foreach (var triggerInvocation in triggerInvocations)
66+
{
67+
cancellationToken.ThrowIfCancellationRequested();
6368

64-
foreach (var triggerInvocation in triggerInvocations)
65-
{
66-
cancellationToken.ThrowIfCancellationRequested();
69+
if (_logger.IsEnabled(LogLevel.Information))
70+
{
71+
_logger.LogInformation("Invoking trigger: {trigger} as {triggerType}", triggerInvocation.triggerDescriptor.GetType().Name, triggerInvocation.triggerDescriptor.TypeDescriptor.TriggerType.Name);
72+
}
6773

68-
if (_logger.IsEnabled(LogLevel.Information))
69-
{
70-
_logger.LogInformation("Invoking trigger: {trigger} as {triggerType}", triggerInvocation.triggerDescriptor.GetType().Name, triggerInvocation.triggerDescriptor.TypeDescriptor.TriggerType.Name);
74+
await triggerInvocation.triggerDescriptor.Invoke(triggerInvocation.triggerContextDescriptor.GetTriggerContext(), cancellationToken).ConfigureAwait(false);
7175
}
72-
73-
await triggerInvocation.triggerDescriptor.Invoke(triggerInvocation.triggerContextDescriptor .GetTriggerContext(), cancellationToken).ConfigureAwait(false);
7476
}
7577
}
7678

test/EntityFrameworkCore.Triggered.Tests/TriggerSessionTests.cs

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -256,7 +256,41 @@ public void RaiseBeforeSaveTriggers_MultipleEntities_SortByPriorities()
256256
Assert.Equal("Early", capturedInvocations[1].Item1);
257257
Assert.Equal("Late", capturedInvocations[2].Item1);
258258
Assert.Equal("Late", capturedInvocations[3].Item1);
259+
}
260+
261+
262+
[Fact]
263+
public void RaiseBeforeSaveTriggers_RecursiveAdd_RaisesSubsequentTriggers()
264+
{
265+
TestDbContext dbContext = null;
266+
267+
var trigger = new TriggerStub<TestModel> {
268+
Priority = CommonTriggerPriority.Early,
269+
BeforeSaveHandler = (context, _) => {
270+
if (context.Entity.Id == 1)
271+
{
272+
dbContext.TestModels.Add(new TestModel { Id = 2 });
273+
}
274+
return Task.CompletedTask;
275+
}
276+
};
277+
278+
var serviceProvider = new ServiceCollection()
279+
.AddSingleton<IBeforeSaveTrigger<TestModel>>(trigger)
280+
.AddTriggeredDbContext<TestDbContext>(options => {
281+
options.UseInMemoryDatabase("Test");
282+
})
283+
.BuildServiceProvider();
284+
285+
var scope = serviceProvider.CreateScope();
286+
dbContext = scope.ServiceProvider.GetRequiredService<TestDbContext>();
287+
var subject = CreateSubject(dbContext);
288+
289+
dbContext.TestModels.Add(new TestModel { Id = 1 });
290+
291+
subject.RaiseBeforeSaveTriggers();
259292

293+
Assert.Equal(2, trigger.BeforeSaveInvocations.Count);
260294
}
261295
}
262296
}

0 commit comments

Comments
 (0)