diff --git a/src/Abstractions/TaskOrchestrationContext.cs b/src/Abstractions/TaskOrchestrationContext.cs
index 95cc116e..9cf5fe7e 100644
--- a/src/Abstractions/TaskOrchestrationContext.cs
+++ b/src/Abstractions/TaskOrchestrationContext.cs
@@ -222,13 +222,28 @@ public virtual Task CreateTimer(TimeSpan delay, CancellationToken cancellationTo
///
/// The amount of time to wait before cancelling the external event task.
///
- public async Task WaitForExternalEvent(string eventName, TimeSpan timeout)
+ public Task WaitForExternalEvent(string eventName, TimeSpan timeout)
+ {
+ return this.WaitForExternalEvent(eventName, timeout, CancellationToken.None);
+ }
+
+ ///
+ /// The name of the event to wait for. Event names are case-insensitive. External event names can be reused any
+ /// number of times; they are not required to be unique.
+ ///
+ /// The amount of time to wait before cancelling the external event task.
+ /// A CancellationToken to use to abort waiting for the event.
+ ///
+ public async Task WaitForExternalEvent(string eventName, TimeSpan timeout, CancellationToken cancellationToken)
{
// Timeouts are implemented using durable timers.
using CancellationTokenSource timerCts = new();
Task timeoutTask = this.CreateTimer(timeout, timerCts.Token);
- using CancellationTokenSource eventCts = new();
+ // Create a linked cancellation token source from the external cancellation token.
+ // This allows us to cancel the event wait either when the external token is cancelled
+ // or when the timeout fires (by calling eventCts.Cancel()).
+ using CancellationTokenSource eventCts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken);
Task externalEventTask = this.WaitForExternalEvent(eventName, eventCts.Token);
// Wait for either task to complete and then cancel the one that didn't.
diff --git a/test/Grpc.IntegrationTests/OrchestrationPatterns.cs b/test/Grpc.IntegrationTests/OrchestrationPatterns.cs
index be7c0d1d..d52c271c 100644
--- a/test/Grpc.IntegrationTests/OrchestrationPatterns.cs
+++ b/test/Grpc.IntegrationTests/OrchestrationPatterns.cs
@@ -1338,4 +1338,172 @@ public async Task CatchingActivityExceptionsByType()
Assert.Equal("Success", results[2]);
Assert.Equal("Caught base Exception", results[3]);
}
+
+ [Fact]
+ public async Task WaitForExternalEvent_WithTimeoutAndCancellationToken_EventWins()
+ {
+ const string EventName = "TestEvent";
+ const string EventPayload = "test-payload";
+ TaskName orchestratorName = nameof(WaitForExternalEvent_WithTimeoutAndCancellationToken_EventWins);
+
+ await using HostTestLifetime server = await this.StartWorkerAsync(b =>
+ {
+ b.AddTasks(tasks => tasks.AddOrchestratorFunc(orchestratorName, async ctx =>
+ {
+ using CancellationTokenSource cts = new();
+ Task eventTask = ctx.WaitForExternalEvent(EventName, TimeSpan.FromDays(7), cts.Token);
+ string result = await eventTask;
+ return result;
+ }));
+ });
+
+ string instanceId = await server.Client.ScheduleNewOrchestrationInstanceAsync(orchestratorName);
+ await server.Client.WaitForInstanceStartAsync(instanceId, this.TimeoutToken);
+
+ // Send event - should complete the wait
+ await server.Client.RaiseEventAsync(instanceId, EventName, EventPayload);
+
+ OrchestrationMetadata metadata = await server.Client.WaitForInstanceCompletionAsync(
+ instanceId, getInputsAndOutputs: true, this.TimeoutToken);
+ Assert.NotNull(metadata);
+ Assert.Equal(OrchestrationRuntimeStatus.Completed, metadata.RuntimeStatus);
+
+ string? result = metadata.ReadOutputAs();
+ Assert.Equal(EventPayload, result);
+ }
+
+ [Fact]
+ public async Task WaitForExternalEvent_WithTimeoutAndCancellationToken_CancellationWins()
+ {
+ TaskName orchestratorName = nameof(WaitForExternalEvent_WithTimeoutAndCancellationToken_CancellationWins);
+
+ await using HostTestLifetime server = await this.StartWorkerAsync(b =>
+ {
+ b.AddTasks(tasks => tasks.AddOrchestratorFunc(orchestratorName, async ctx =>
+ {
+ using CancellationTokenSource cts = new();
+
+ // Create two event waiters with cancellation tokens
+ Task event1Task = ctx.WaitForExternalEvent("Event1", TimeSpan.FromDays(7), cts.Token);
+
+ using CancellationTokenSource cts2 = new();
+ Task event2Task = ctx.WaitForExternalEvent("Event2", TimeSpan.FromDays(7), cts2.Token);
+
+ // Wait for any to complete
+ Task winner = await Task.WhenAny(event1Task, event2Task);
+
+ // Cancel the other one
+ if (winner == event1Task)
+ {
+ cts2.Cancel();
+ return $"Event1: {await event1Task}";
+ }
+ else
+ {
+ cts.Cancel();
+ return $"Event2: {await event2Task}";
+ }
+ }));
+ });
+
+ string instanceId = await server.Client.ScheduleNewOrchestrationInstanceAsync(orchestratorName);
+ await server.Client.WaitForInstanceStartAsync(instanceId, this.TimeoutToken);
+
+ // Send Event1 - should complete and cancel Event2
+ await server.Client.RaiseEventAsync(instanceId, "Event1", "first-event");
+
+ OrchestrationMetadata metadata = await server.Client.WaitForInstanceCompletionAsync(
+ instanceId, getInputsAndOutputs: true, this.TimeoutToken);
+ Assert.NotNull(metadata);
+ Assert.Equal(OrchestrationRuntimeStatus.Completed, metadata.RuntimeStatus);
+
+ string? result = metadata.ReadOutputAs();
+ Assert.Equal("Event1: first-event", result);
+ }
+
+ [Fact]
+ public async Task WaitForExternalEvent_WithTimeoutAndCancellationToken_TimeoutWins()
+ {
+ const string EventName = "TestEvent";
+ TaskName orchestratorName = nameof(WaitForExternalEvent_WithTimeoutAndCancellationToken_TimeoutWins);
+
+ await using HostTestLifetime server = await this.StartWorkerAsync(b =>
+ {
+ b.AddTasks(tasks => tasks.AddOrchestratorFunc(orchestratorName, async ctx =>
+ {
+ using CancellationTokenSource cts = new();
+ Task eventTask = ctx.WaitForExternalEvent(EventName, TimeSpan.FromMilliseconds(500), cts.Token);
+
+ try
+ {
+ string result = await eventTask;
+ return $"Event: {result}";
+ }
+ catch (OperationCanceledException)
+ {
+ return "Timeout occurred";
+ }
+ }));
+ });
+
+ string instanceId = await server.Client.ScheduleNewOrchestrationInstanceAsync(orchestratorName);
+
+ OrchestrationMetadata metadata = await server.Client.WaitForInstanceCompletionAsync(
+ instanceId, getInputsAndOutputs: true, this.TimeoutToken);
+ Assert.NotNull(metadata);
+ Assert.Equal(OrchestrationRuntimeStatus.Completed, metadata.RuntimeStatus);
+
+ string? result = metadata.ReadOutputAs();
+ Assert.Equal("Timeout occurred", result);
+ }
+
+ [Fact]
+ public async Task WaitForExternalEvent_WithTimeoutAndCancellationToken_ExternalCancellationWins()
+ {
+ const string EventName = "TestEvent";
+ TaskName orchestratorName = nameof(WaitForExternalEvent_WithTimeoutAndCancellationToken_ExternalCancellationWins);
+
+ await using HostTestLifetime server = await this.StartWorkerAsync(b =>
+ {
+ b.AddTasks(tasks => tasks.AddOrchestratorFunc(orchestratorName, async ctx =>
+ {
+ using CancellationTokenSource cts = new();
+
+ // Create a timer that will fire and trigger cancellation
+ Task cancelTrigger = ctx.CreateTimer(TimeSpan.FromMilliseconds(100), CancellationToken.None);
+
+ // Wait for external event with a long timeout
+ Task eventTask = ctx.WaitForExternalEvent(EventName, TimeSpan.FromDays(7), cts.Token);
+
+ // Wait for either the cancel trigger or the event
+ Task winner = await Task.WhenAny(cancelTrigger, eventTask);
+
+ if (winner == cancelTrigger)
+ {
+ // Cancel the external cancellation token
+ cts.Cancel();
+ }
+
+ try
+ {
+ string result = await eventTask;
+ return $"Event: {result}";
+ }
+ catch (OperationCanceledException)
+ {
+ return "External cancellation occurred";
+ }
+ }));
+ });
+
+ string instanceId = await server.Client.ScheduleNewOrchestrationInstanceAsync(orchestratorName);
+
+ OrchestrationMetadata metadata = await server.Client.WaitForInstanceCompletionAsync(
+ instanceId, getInputsAndOutputs: true, this.TimeoutToken);
+ Assert.NotNull(metadata);
+ Assert.Equal(OrchestrationRuntimeStatus.Completed, metadata.RuntimeStatus);
+
+ string? result = metadata.ReadOutputAs();
+ Assert.Equal("External cancellation occurred", result);
+ }
}