Skip to content

Commit 96eb748

Browse files
CopilotYunchuWang
andcommitted
Add recursive termination for sub-orchestrations
Co-authored-by: YunchuWang <[email protected]>
1 parent 3efe03e commit 96eb748

File tree

3 files changed

+169
-59
lines changed

3 files changed

+169
-59
lines changed

src/Client/Grpc/GrpcDurableTaskClient.cs

Lines changed: 34 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -57,8 +57,8 @@ public GrpcDurableTaskClient(string name, GrpcDurableTaskClientOptions options,
5757
if (this.options.EnableEntitySupport)
5858
{
5959
this.entityClient = new GrpcDurableEntityClient(this.Name, this.DataConverter, this.sidecarClient, logger);
60-
}
61-
}
60+
}
61+
}
6262

6363
/// <inheritdoc/>
6464
public override DurableEntityClient Entities => this.entityClient
@@ -179,28 +179,28 @@ public override async Task RaiseEventAsync(
179179
await this.sidecarClient.RaiseEventAsync(request, cancellationToken: cancellation);
180180
}
181181

182-
/// <inheritdoc/>
183-
public override async Task TerminateInstanceAsync(
184-
string instanceId, TerminateInstanceOptions? options = null, CancellationToken cancellation = default)
185-
{
186-
object? output = options?.Output;
187-
bool recursive = options?.Recursive ?? false;
188-
189-
Check.NotNullOrEmpty(instanceId);
190-
Check.NotEntity(this.options.EnableEntitySupport, instanceId);
191-
192-
this.logger.TerminatingInstance(instanceId);
193-
194-
string? serializedOutput = this.DataConverter.Serialize(output);
195-
await this.sidecarClient.TerminateInstanceAsync(
196-
new P.TerminateRequest
197-
{
198-
InstanceId = instanceId,
199-
Output = serializedOutput,
200-
Recursive = recursive,
201-
},
202-
cancellationToken: cancellation);
203-
}
182+
/// <inheritdoc/>
183+
public override async Task TerminateInstanceAsync(
184+
string instanceId, TerminateInstanceOptions? options = null, CancellationToken cancellation = default)
185+
{
186+
object? output = options?.Output;
187+
bool recursive = options?.Recursive ?? false;
188+
189+
Check.NotNullOrEmpty(instanceId);
190+
Check.NotEntity(this.options.EnableEntitySupport, instanceId);
191+
192+
this.logger.TerminatingInstance(instanceId);
193+
194+
string? serializedOutput = this.DataConverter.Serialize(output);
195+
await this.sidecarClient.TerminateInstanceAsync(
196+
new P.TerminateRequest
197+
{
198+
InstanceId = instanceId,
199+
Output = serializedOutput,
200+
Recursive = recursive,
201+
},
202+
cancellationToken: cancellation);
203+
}
204204

205205
/// <inheritdoc/>
206206
public override async Task SuspendInstanceAsync(
@@ -598,11 +598,11 @@ async Task<PurgeResult> PurgeInstancesCoreAsync(
598598
throw new OperationCanceledException(
599599
$"The {nameof(this.PurgeAllInstancesAsync)} operation was canceled.", e, cancellation);
600600
}
601-
}
602-
603-
OrchestrationMetadata CreateMetadata(P.OrchestrationState state, bool includeInputsAndOutputs)
604-
{
605-
OrchestrationMetadata metadata = new OrchestrationMetadata(state.Name, state.InstanceId)
601+
}
602+
603+
OrchestrationMetadata CreateMetadata(P.OrchestrationState state, bool includeInputsAndOutputs)
604+
{
605+
OrchestrationMetadata metadata = new OrchestrationMetadata(state.Name, state.InstanceId)
606606
{
607607
CreatedAt = state.CreatedTimestamp.ToDateTimeOffset(),
608608
LastUpdatedAt = state.LastUpdatedTimestamp.ToDateTimeOffset(),
@@ -614,7 +614,8 @@ OrchestrationMetadata CreateMetadata(P.OrchestrationState state, bool includeInp
614614
DataConverter = includeInputsAndOutputs ? this.DataConverter : null,
615615
Tags = new Dictionary<string, string>(state.Tags),
616616
};
617-
618-
return metadata;
619-
}
620-
}
617+
618+
return metadata;
619+
}
620+
621+
}

src/InProcessTestHost/Sidecar/Grpc/TaskHubGrpcServer.cs

Lines changed: 55 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ public class TaskHubGrpcServer : P.TaskHubSidecarService.TaskHubSidecarServiceBa
4343
readonly IsConnectedSignal isConnectedSignal = new();
4444
readonly SemaphoreSlim sendWorkItemLock = new(initialCount: 1);
4545
readonly ConcurrentDictionary<string, List<P.HistoryEvent>> streamingPastEvents = new(StringComparer.OrdinalIgnoreCase);
46+
readonly ConcurrentDictionary<string, HashSet<string>> subOrchestrationChildren = new(StringComparer.OrdinalIgnoreCase);
4647

4748
volatile bool supportsHistoryStreaming;
4849

@@ -300,6 +301,22 @@ await this.client.SendTaskOrchestrationMessageAsync(
300301
await this.client.ForceTerminateTaskOrchestrationAsync(
301302
request.InstanceId,
302303
request.Output);
304+
305+
if (request.Recursive &&
306+
this.subOrchestrationChildren.TryGetValue(request.InstanceId, out HashSet<string>? childSet))
307+
{
308+
foreach (string childId in childSet)
309+
{
310+
P.TerminateRequest childRequest = new()
311+
{
312+
InstanceId = childId,
313+
Output = request.Output,
314+
Recursive = true,
315+
};
316+
317+
await this.TerminateInstance(childRequest, context);
318+
}
319+
}
303320
}
304321
catch (Exception e)
305322
{
@@ -574,6 +591,18 @@ static P.GetInstanceResponse CreateGetInstanceResponse(OrchestrationState state,
574591
OrchestrationActivityStartTime = request.OrchestrationTraceContext?.SpanStartTime?.ToDateTimeOffset(),
575592
};
576593

594+
if (this.subOrchestrationChildren.TryGetValue(request.InstanceId, out HashSet<string>? children))
595+
{
596+
foreach (P.OrchestratorAction action in request.Actions)
597+
{
598+
if (action.OrchestratorActionTypeCase == P.OrchestratorAction.OrchestratorActionTypeOneofCase.CreateSubOrchestration &&
599+
!string.IsNullOrEmpty(action.CreateSubOrchestration.InstanceId))
600+
{
601+
children.Add(action.CreateSubOrchestration.InstanceId);
602+
}
603+
}
604+
}
605+
577606
tcs.TrySetResult(result);
578607

579608
return EmptyCompleteTaskResponse;
@@ -730,27 +759,46 @@ async Task<GrpcOrchestratorExecutionResult> ITaskExecutor.ExecuteOrchestrator(
730759

731760
try
732761
{
762+
List<P.HistoryEvent> protoNewEvents = newEvents.Select(ProtobufUtils.ToHistoryEventProto).ToList();
763+
List<P.HistoryEvent> protoPastEvents = pastEvents.Select(ProtobufUtils.ToHistoryEventProto).ToList();
764+
List<P.HistoryEvent> allEvents = protoPastEvents.Concat(protoNewEvents).ToList();
765+
766+
HashSet<string> children = this.subOrchestrationChildren.GetOrAdd(instance.InstanceId, _ => new(StringComparer.OrdinalIgnoreCase));
767+
foreach (P.HistoryEvent e in protoNewEvents)
768+
{
769+
if (e.SubOrchestrationInstanceCreated?.InstanceId is string subId && !string.IsNullOrEmpty(subId))
770+
{
771+
children.Add(subId);
772+
}
773+
}
774+
733775
var orkRequest = new P.OrchestratorRequest
734776
{
735777
InstanceId = instance.InstanceId,
736778
ExecutionId = instance.ExecutionId,
737-
NewEvents = { newEvents.Select(ProtobufUtils.ToHistoryEventProto) },
779+
NewEvents = { protoNewEvents },
738780
OrchestrationTraceContext = orchestrationTraceContext,
739781
};
740782

741783
// Decide whether to stream based on total size of past events (> 1MiB)
742-
List<P.HistoryEvent> protoPastEvents = pastEvents.Select(ProtobufUtils.ToHistoryEventProto).ToList();
743784
int totalBytes = 0;
744-
foreach (P.HistoryEvent ev in protoPastEvents)
785+
foreach (P.HistoryEvent ev in allEvents)
745786
{
746787
totalBytes += ev.CalculateSize();
747788
}
748789

749-
if (this.supportsHistoryStreaming && totalBytes > (1024))
790+
this.streamingPastEvents[instance.InstanceId] = allEvents;
791+
792+
if (this.supportsHistoryStreaming)
750793
{
751-
orkRequest.RequiresHistoryStreaming = true;
752-
// Store past events to serve via StreamInstanceHistory
753-
this.streamingPastEvents[instance.InstanceId] = protoPastEvents;
794+
if (totalBytes > (1024))
795+
{
796+
orkRequest.RequiresHistoryStreaming = true;
797+
}
798+
else
799+
{
800+
orkRequest.PastEvents.AddRange(protoPastEvents);
801+
}
754802
}
755803
else
756804
{

test/Grpc.IntegrationTests/OrchestrationPatterns.cs

Lines changed: 80 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,14 @@
11
// Copyright (c) Microsoft Corporation.
22
// Licensed under the MIT License.
33

4-
using System.Text.Json;
5-
using System.Text.Json.Nodes;
6-
using Microsoft.DurableTask.Client;
7-
using Microsoft.DurableTask.Tests.Logging;
8-
using Microsoft.DurableTask.Worker;
9-
using Microsoft.Extensions.DependencyInjection;
10-
using Xunit.Abstractions;
4+
using System.Text.Json;
5+
using System.Text.Json.Nodes;
6+
using DurableTask.Core.History;
7+
using Microsoft.DurableTask.Client;
8+
using Microsoft.DurableTask.Tests.Logging;
9+
using Microsoft.DurableTask.Worker;
10+
using Microsoft.Extensions.DependencyInjection;
11+
using Xunit.Abstractions;
1112

1213
namespace Microsoft.DurableTask.Grpc.Tests;
1314

@@ -422,11 +423,11 @@ public async Task ExternalEventsInParallel(int eventCount)
422423
}
423424

424425
[Fact]
425-
public async Task Termination()
426-
{
427-
TaskName orchestrationName = nameof(Termination);
428-
await using HostTestLifetime server = await this.StartWorkerAsync(b =>
429-
{
426+
public async Task Termination()
427+
{
428+
TaskName orchestrationName = nameof(Termination);
429+
await using HostTestLifetime server = await this.StartWorkerAsync(b =>
430+
{
430431
b.AddTasks(tasks => tasks.AddOrchestratorFunc(
431432
orchestrationName, ctx => ctx.CreateTimer(TimeSpan.FromSeconds(3), CancellationToken.None)));
432433
});
@@ -446,13 +447,73 @@ public async Task Termination()
446447
JsonElement actualOutput = metadata.ReadOutputAs<JsonElement>();
447448
string? actualQuote = actualOutput.GetProperty("quote").GetString();
448449
Assert.NotNull(actualQuote);
449-
Assert.Equal(expectedOutput.quote, actualQuote);
450-
}
451-
452-
[Fact]
453-
public async Task ContinueAsNew()
454-
{
455-
TaskName orchestratorName = nameof(ContinueAsNew);
450+
Assert.Equal(expectedOutput.quote, actualQuote);
451+
}
452+
453+
[Fact]
454+
public async Task Termination_RecursiveStopsSubOrchestrations()
455+
{
456+
// Arrange
457+
TaskName parentName = nameof(Termination_RecursiveStopsSubOrchestrations) + "_Parent";
458+
TaskName childName = nameof(Termination_RecursiveStopsSubOrchestrations) + "_Child";
459+
string parentInstanceId = "parent-recursive-termination";
460+
string childInstanceIdOne = "child-recursive-termination-1";
461+
string childInstanceIdTwo = "child-recursive-termination-2";
462+
463+
await using HostTestLifetime server = await this.StartWorkerAsync(b =>
464+
{
465+
b.AddTasks(tasks => tasks
466+
.AddOrchestratorFunc(parentName, async ctx =>
467+
{
468+
_ = ctx.CallSubOrchestratorAsync<object>(
469+
childName,
470+
null,
471+
new SubOrchestrationOptions(instanceId: childInstanceIdOne));
472+
_ = ctx.CallSubOrchestratorAsync<object>(
473+
childName,
474+
null,
475+
new SubOrchestrationOptions(instanceId: childInstanceIdTwo));
476+
477+
await ctx.CreateTimer(TimeSpan.FromMinutes(1), CancellationToken.None);
478+
})
479+
.AddOrchestratorFunc(childName, async ctx =>
480+
{
481+
await ctx.CreateTimer(TimeSpan.FromMinutes(1), CancellationToken.None);
482+
}));
483+
});
484+
485+
await server.Client.ScheduleNewOrchestrationInstanceAsync(
486+
parentName,
487+
input: null,
488+
options: new StartOrchestrationOptions(parentInstanceId));
489+
490+
await server.Client.WaitForInstanceStartAsync(parentInstanceId, this.TimeoutToken);
491+
await server.Client.WaitForInstanceStartAsync(childInstanceIdOne, this.TimeoutToken);
492+
await server.Client.WaitForInstanceStartAsync(childInstanceIdTwo, this.TimeoutToken);
493+
494+
// Act
495+
await server.Client.TerminateInstanceAsync(
496+
parentInstanceId,
497+
new TerminateInstanceOptions(Recursive: true),
498+
this.TimeoutToken);
499+
500+
OrchestrationMetadata parentMetadata =
501+
await server.Client.WaitForInstanceCompletionAsync(parentInstanceId, this.TimeoutToken);
502+
OrchestrationMetadata childMetadataOne =
503+
await server.Client.WaitForInstanceCompletionAsync(childInstanceIdOne, this.TimeoutToken);
504+
OrchestrationMetadata childMetadataTwo =
505+
await server.Client.WaitForInstanceCompletionAsync(childInstanceIdTwo, this.TimeoutToken);
506+
507+
// Assert
508+
Assert.Equal(OrchestrationRuntimeStatus.Terminated, parentMetadata.RuntimeStatus);
509+
Assert.Equal(OrchestrationRuntimeStatus.Terminated, childMetadataOne.RuntimeStatus);
510+
Assert.Equal(OrchestrationRuntimeStatus.Terminated, childMetadataTwo.RuntimeStatus);
511+
}
512+
513+
[Fact]
514+
public async Task ContinueAsNew()
515+
{
516+
TaskName orchestratorName = nameof(ContinueAsNew);
456517

457518
await using HostTestLifetime server = await this.StartWorkerAsync(b =>
458519
{

0 commit comments

Comments
 (0)