diff --git a/src/Abstractions/TaskOrchestrationContext.cs b/src/Abstractions/TaskOrchestrationContext.cs index 95cc116e..9cf5fe7e 100644 --- a/src/Abstractions/TaskOrchestrationContext.cs +++ b/src/Abstractions/TaskOrchestrationContext.cs @@ -222,13 +222,28 @@ public virtual Task CreateTimer(TimeSpan delay, CancellationToken cancellationTo /// /// The amount of time to wait before cancelling the external event task. /// - public async Task WaitForExternalEvent(string eventName, TimeSpan timeout) + public Task WaitForExternalEvent(string eventName, TimeSpan timeout) + { + return this.WaitForExternalEvent(eventName, timeout, CancellationToken.None); + } + + /// + /// The name of the event to wait for. Event names are case-insensitive. External event names can be reused any + /// number of times; they are not required to be unique. + /// + /// The amount of time to wait before cancelling the external event task. + /// A CancellationToken to use to abort waiting for the event. + /// + public async Task WaitForExternalEvent(string eventName, TimeSpan timeout, CancellationToken cancellationToken) { // Timeouts are implemented using durable timers. using CancellationTokenSource timerCts = new(); Task timeoutTask = this.CreateTimer(timeout, timerCts.Token); - using CancellationTokenSource eventCts = new(); + // Create a linked cancellation token source from the external cancellation token. + // This allows us to cancel the event wait either when the external token is cancelled + // or when the timeout fires (by calling eventCts.Cancel()). + using CancellationTokenSource eventCts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken); Task externalEventTask = this.WaitForExternalEvent(eventName, eventCts.Token); // Wait for either task to complete and then cancel the one that didn't. diff --git a/test/Grpc.IntegrationTests/OrchestrationPatterns.cs b/test/Grpc.IntegrationTests/OrchestrationPatterns.cs index be7c0d1d..d52c271c 100644 --- a/test/Grpc.IntegrationTests/OrchestrationPatterns.cs +++ b/test/Grpc.IntegrationTests/OrchestrationPatterns.cs @@ -1338,4 +1338,172 @@ public async Task CatchingActivityExceptionsByType() Assert.Equal("Success", results[2]); Assert.Equal("Caught base Exception", results[3]); } + + [Fact] + public async Task WaitForExternalEvent_WithTimeoutAndCancellationToken_EventWins() + { + const string EventName = "TestEvent"; + const string EventPayload = "test-payload"; + TaskName orchestratorName = nameof(WaitForExternalEvent_WithTimeoutAndCancellationToken_EventWins); + + await using HostTestLifetime server = await this.StartWorkerAsync(b => + { + b.AddTasks(tasks => tasks.AddOrchestratorFunc(orchestratorName, async ctx => + { + using CancellationTokenSource cts = new(); + Task eventTask = ctx.WaitForExternalEvent(EventName, TimeSpan.FromDays(7), cts.Token); + string result = await eventTask; + return result; + })); + }); + + string instanceId = await server.Client.ScheduleNewOrchestrationInstanceAsync(orchestratorName); + await server.Client.WaitForInstanceStartAsync(instanceId, this.TimeoutToken); + + // Send event - should complete the wait + await server.Client.RaiseEventAsync(instanceId, EventName, EventPayload); + + OrchestrationMetadata metadata = await server.Client.WaitForInstanceCompletionAsync( + instanceId, getInputsAndOutputs: true, this.TimeoutToken); + Assert.NotNull(metadata); + Assert.Equal(OrchestrationRuntimeStatus.Completed, metadata.RuntimeStatus); + + string? result = metadata.ReadOutputAs(); + Assert.Equal(EventPayload, result); + } + + [Fact] + public async Task WaitForExternalEvent_WithTimeoutAndCancellationToken_CancellationWins() + { + TaskName orchestratorName = nameof(WaitForExternalEvent_WithTimeoutAndCancellationToken_CancellationWins); + + await using HostTestLifetime server = await this.StartWorkerAsync(b => + { + b.AddTasks(tasks => tasks.AddOrchestratorFunc(orchestratorName, async ctx => + { + using CancellationTokenSource cts = new(); + + // Create two event waiters with cancellation tokens + Task event1Task = ctx.WaitForExternalEvent("Event1", TimeSpan.FromDays(7), cts.Token); + + using CancellationTokenSource cts2 = new(); + Task event2Task = ctx.WaitForExternalEvent("Event2", TimeSpan.FromDays(7), cts2.Token); + + // Wait for any to complete + Task winner = await Task.WhenAny(event1Task, event2Task); + + // Cancel the other one + if (winner == event1Task) + { + cts2.Cancel(); + return $"Event1: {await event1Task}"; + } + else + { + cts.Cancel(); + return $"Event2: {await event2Task}"; + } + })); + }); + + string instanceId = await server.Client.ScheduleNewOrchestrationInstanceAsync(orchestratorName); + await server.Client.WaitForInstanceStartAsync(instanceId, this.TimeoutToken); + + // Send Event1 - should complete and cancel Event2 + await server.Client.RaiseEventAsync(instanceId, "Event1", "first-event"); + + OrchestrationMetadata metadata = await server.Client.WaitForInstanceCompletionAsync( + instanceId, getInputsAndOutputs: true, this.TimeoutToken); + Assert.NotNull(metadata); + Assert.Equal(OrchestrationRuntimeStatus.Completed, metadata.RuntimeStatus); + + string? result = metadata.ReadOutputAs(); + Assert.Equal("Event1: first-event", result); + } + + [Fact] + public async Task WaitForExternalEvent_WithTimeoutAndCancellationToken_TimeoutWins() + { + const string EventName = "TestEvent"; + TaskName orchestratorName = nameof(WaitForExternalEvent_WithTimeoutAndCancellationToken_TimeoutWins); + + await using HostTestLifetime server = await this.StartWorkerAsync(b => + { + b.AddTasks(tasks => tasks.AddOrchestratorFunc(orchestratorName, async ctx => + { + using CancellationTokenSource cts = new(); + Task eventTask = ctx.WaitForExternalEvent(EventName, TimeSpan.FromMilliseconds(500), cts.Token); + + try + { + string result = await eventTask; + return $"Event: {result}"; + } + catch (OperationCanceledException) + { + return "Timeout occurred"; + } + })); + }); + + string instanceId = await server.Client.ScheduleNewOrchestrationInstanceAsync(orchestratorName); + + OrchestrationMetadata metadata = await server.Client.WaitForInstanceCompletionAsync( + instanceId, getInputsAndOutputs: true, this.TimeoutToken); + Assert.NotNull(metadata); + Assert.Equal(OrchestrationRuntimeStatus.Completed, metadata.RuntimeStatus); + + string? result = metadata.ReadOutputAs(); + Assert.Equal("Timeout occurred", result); + } + + [Fact] + public async Task WaitForExternalEvent_WithTimeoutAndCancellationToken_ExternalCancellationWins() + { + const string EventName = "TestEvent"; + TaskName orchestratorName = nameof(WaitForExternalEvent_WithTimeoutAndCancellationToken_ExternalCancellationWins); + + await using HostTestLifetime server = await this.StartWorkerAsync(b => + { + b.AddTasks(tasks => tasks.AddOrchestratorFunc(orchestratorName, async ctx => + { + using CancellationTokenSource cts = new(); + + // Create a timer that will fire and trigger cancellation + Task cancelTrigger = ctx.CreateTimer(TimeSpan.FromMilliseconds(100), CancellationToken.None); + + // Wait for external event with a long timeout + Task eventTask = ctx.WaitForExternalEvent(EventName, TimeSpan.FromDays(7), cts.Token); + + // Wait for either the cancel trigger or the event + Task winner = await Task.WhenAny(cancelTrigger, eventTask); + + if (winner == cancelTrigger) + { + // Cancel the external cancellation token + cts.Cancel(); + } + + try + { + string result = await eventTask; + return $"Event: {result}"; + } + catch (OperationCanceledException) + { + return "External cancellation occurred"; + } + })); + }); + + string instanceId = await server.Client.ScheduleNewOrchestrationInstanceAsync(orchestratorName); + + OrchestrationMetadata metadata = await server.Client.WaitForInstanceCompletionAsync( + instanceId, getInputsAndOutputs: true, this.TimeoutToken); + Assert.NotNull(metadata); + Assert.Equal(OrchestrationRuntimeStatus.Completed, metadata.RuntimeStatus); + + string? result = metadata.ReadOutputAs(); + Assert.Equal("External cancellation occurred", result); + } }