diff --git a/src/Client/Core/DurableTaskClient.cs b/src/Client/Core/DurableTaskClient.cs index 16d6d8426..439670cfc 100644 --- a/src/Client/Core/DurableTaskClient.cs +++ b/src/Client/Core/DurableTaskClient.cs @@ -399,6 +399,47 @@ public virtual Task PurgeAllInstancesAsync( throw new NotSupportedException($"{this.GetType()} does not support purging of orchestration instances."); } + /// + /// Restarts an orchestration instance with the same or a new instance ID. + /// + /// + /// + /// This method restarts an existing orchestration instance. If is true, + /// a new instance ID will be generated for the restarted orchestration. If false, the original instance ID will be reused. + /// + /// The restarted orchestration will use the same input data as the original instance. If the original orchestration + /// instance is not found, an will be thrown. + /// + /// Note that this operation is backend-specific and may not be supported by all durable task backends. + /// If the backend does not support restart operations, a will be thrown. + /// + /// + /// The ID of the orchestration instance to restart. + /// + /// If true, a new instance ID will be generated for the restarted orchestration. + /// If false, the original instance ID will be reused. + /// + /// + /// The cancellation token. This only cancels enqueueing the restart request to the backend. + /// Does not abort restarting the orchestration once enqueued. + /// + /// + /// A task that completes when the orchestration instance is successfully restarted. + /// The value of this task is the instance ID of the restarted orchestration instance. + /// + /// + /// Thrown if an orchestration with the specified was not found. + /// + /// Thrown when attempting to restart an instance using the same instance Id + /// while the instance has not yet reached a completed or terminal state. + /// + /// Thrown if the backend does not support restart operations. + public virtual Task RestartAsync( + string instanceId, + bool restartWithNewInstanceId = false, + CancellationToken cancellation = default) + => throw new NotSupportedException($"{this.GetType()} does not support orchestration restart."); + // TODO: Create task hub // TODO: Delete task hub diff --git a/src/Client/Grpc/GrpcDurableTaskClient.cs b/src/Client/Grpc/GrpcDurableTaskClient.cs index c38682a3c..c57b22f96 100644 --- a/src/Client/Grpc/GrpcDurableTaskClient.cs +++ b/src/Client/Grpc/GrpcDurableTaskClient.cs @@ -395,6 +395,42 @@ public override Task PurgeAllInstancesAsync( return this.PurgeInstancesCoreAsync(request, cancellation); } + /// + public override async Task RestartAsync( + string instanceId, + bool restartWithNewInstanceId = false, + CancellationToken cancellation = default) + { + Check.NotNullOrEmpty(instanceId); + Check.NotEntity(this.options.EnableEntitySupport, instanceId); + + var request = new P.RestartInstanceRequest + { + InstanceId = instanceId, + RestartWithNewInstanceId = restartWithNewInstanceId, + }; + + try + { + P.RestartInstanceResponse result = await this.sidecarClient.RestartInstanceAsync( + request, cancellationToken: cancellation); + return result.InstanceId; + } + catch (RpcException e) when (e.StatusCode == StatusCode.NotFound) + { + throw new ArgumentException($"An orchestration with the instanceId {instanceId} was not found.", e); + } + catch (RpcException e) when (e.StatusCode == StatusCode.FailedPrecondition) + { + throw new InvalidOperationException($"An orchestration with the instanceId {instanceId} cannot be restarted.", e); + } + catch (RpcException e) when (e.StatusCode == StatusCode.Cancelled) + { + throw new OperationCanceledException( + $"The {nameof(this.RestartAsync)} operation was canceled.", e, cancellation); + } + } + static AsyncDisposable GetCallInvoker(GrpcDurableTaskClientOptions options, out CallInvoker callInvoker) { if (options.Channel is GrpcChannel c) diff --git a/src/Client/OrchestrationServiceClientShim/ShimDurableTaskClient.cs b/src/Client/OrchestrationServiceClientShim/ShimDurableTaskClient.cs index bb77aab21..4fbf828d0 100644 --- a/src/Client/OrchestrationServiceClientShim/ShimDurableTaskClient.cs +++ b/src/Client/OrchestrationServiceClientShim/ShimDurableTaskClient.cs @@ -254,6 +254,59 @@ public override async Task WaitForInstanceStartAsync( } } + /// + public override async Task RestartAsync( + string instanceId, + bool restartWithNewInstanceId = false, + CancellationToken cancellation = default) + { + Check.NotNullOrEmpty(instanceId); + cancellation.ThrowIfCancellationRequested(); + + // Get the current orchestration status to retrieve the name and input + OrchestrationMetadata? status = await this.GetInstanceAsync(instanceId, getInputsAndOutputs: true, cancellation); + + if (status == null) + { + throw new ArgumentException($"An orchestration with the instanceId {instanceId} was not found."); + } + + bool isInstaceNotCompleted = status.RuntimeStatus == OrchestrationRuntimeStatus.Running || + status.RuntimeStatus == OrchestrationRuntimeStatus.Pending || + status.RuntimeStatus == OrchestrationRuntimeStatus.Suspended; + + if (isInstaceNotCompleted && !restartWithNewInstanceId) + { + throw new InvalidOperationException($"Instance '{instanceId}' cannot be restarted while it is in state '{status.RuntimeStatus}'. " + + "Wait until it has completed, or restart with a new instance ID."); + } + + // Determine the instance ID for the restarted orchestration + string newInstanceId = restartWithNewInstanceId ? Guid.NewGuid().ToString("N") : instanceId; + + OrchestrationInstance instance = new() + { + InstanceId = newInstanceId, + ExecutionId = Guid.NewGuid().ToString("N"), + }; + + // Use the original serialized input directly to avoid double serialization + // TODO: OrchestrationMetada doesn't have version property so we don't support version here. + // Issue link: https://github.com/microsoft/durabletask-dotnet/issues/463 + TaskMessage message = new() + { + OrchestrationInstance = instance, + Event = new ExecutionStartedEvent(-1, status.SerializedInput) + { + Name = status.Name, + OrchestrationInstance = instance, + }, + }; + + await this.Client.CreateTaskOrchestrationAsync(message); + return newInstanceId; + } + [return: NotNullIfNotNull("state")] OrchestrationMetadata? ToMetadata(Core.OrchestrationState? state, bool getInputsAndOutputs) { diff --git a/src/Grpc/orchestrator_service.proto b/src/Grpc/orchestrator_service.proto index 95bfeedc8..3b9c4f408 100644 --- a/src/Grpc/orchestrator_service.proto +++ b/src/Grpc/orchestrator_service.proto @@ -482,6 +482,15 @@ message PurgeInstancesResponse { google.protobuf.BoolValue isComplete = 2; } +message RestartInstanceRequest { + string instanceId = 1; + bool restartWithNewInstanceId = 2; +} + +message RestartInstanceResponse { + string instanceId = 1; +} + message CreateTaskHubRequest { bool recreateIfExists = 1; } @@ -682,6 +691,9 @@ service TaskHubSidecarService { // Rewinds an orchestration instance to last known good state and replays from there. rpc RewindInstance(RewindInstanceRequest) returns (RewindInstanceResponse); + // Restarts an orchestration instance. + rpc RestartInstance(RestartInstanceRequest) returns (RestartInstanceResponse); + // Waits for an orchestration instance to reach a running or completion state. rpc WaitForInstanceStart(GetInstanceRequest) returns (GetInstanceResponse); diff --git a/src/Grpc/versions.txt b/src/Grpc/versions.txt index 5c0de3577..e9f651378 100644 --- a/src/Grpc/versions.txt +++ b/src/Grpc/versions.txt @@ -1,2 +1,2 @@ -# The following files were downloaded from branch main at 2025-08-08 16:46:11 UTC -https://raw.githubusercontent.com/microsoft/durabletask-protobuf/e88acbd07ae38b499dbe8c4e333e9e3feeb2a9cc/protos/orchestrator_service.proto +# The following files were downloaded from branch main at 2025-09-10 22:50:45 UTC +https://raw.githubusercontent.com/microsoft/durabletask-protobuf/985035a0890575ae18be0eb2a3ac93c10824498a/protos/orchestrator_service.proto diff --git a/test/Client/Core.Tests/DependencyInjection/DefaultDurableTaskClientBuilderTests.cs b/test/Client/Core.Tests/DependencyInjection/DefaultDurableTaskClientBuilderTests.cs index b2b184cad..d5b20a120 100644 --- a/test/Client/Core.Tests/DependencyInjection/DefaultDurableTaskClientBuilderTests.cs +++ b/test/Client/Core.Tests/DependencyInjection/DefaultDurableTaskClientBuilderTests.cs @@ -144,6 +144,14 @@ public override Task WaitForInstanceStartAsync( { throw new NotImplementedException(); } + + public override Task RestartAsync( + string instanceId, + bool restartWithNewInstanceId = false, + CancellationToken cancellation = default) + { + throw new NotImplementedException(); + } } class CustomDataConverter : DataConverter diff --git a/test/Client/Core.Tests/DependencyInjection/DurableTaskClientBuilderExtensionsTests.cs b/test/Client/Core.Tests/DependencyInjection/DurableTaskClientBuilderExtensionsTests.cs index 9af6469ec..f028fa13f 100644 --- a/test/Client/Core.Tests/DependencyInjection/DurableTaskClientBuilderExtensionsTests.cs +++ b/test/Client/Core.Tests/DependencyInjection/DurableTaskClientBuilderExtensionsTests.cs @@ -174,6 +174,14 @@ public override Task WaitForInstanceStartAsync( { throw new NotImplementedException(); } + + public override Task RestartAsync( + string instanceId, + bool restartWithNewInstanceId = false, + CancellationToken cancellation = default) + { + throw new NotImplementedException(); + } } class GoodBuildTargetOptions : DurableTaskClientOptions diff --git a/test/Client/OrchestrationServiceClientShim.Tests/ShimDurableTaskClientTests.cs b/test/Client/OrchestrationServiceClientShim.Tests/ShimDurableTaskClientTests.cs index 5eba17ede..588f59ac3 100644 --- a/test/Client/OrchestrationServiceClientShim.Tests/ShimDurableTaskClientTests.cs +++ b/test/Client/OrchestrationServiceClientShim.Tests/ShimDurableTaskClientTests.cs @@ -326,6 +326,79 @@ public async Task ScheduleNewOrchestrationInstance_IdProvided_TagsProvided() await this.RunScheduleNewOrchestrationInstanceAsync("test", "input", options); } + [Theory] + [InlineData(false)] + [InlineData(true)] + public async Task RestartAsync_EndToEnd(bool restartWithNewInstanceId) + { + string originalInstanceId = "test-instance-id"; + string orchestratorName = "TestOrchestrator"; + object input = "test-input"; + string serializedInput = "\"test-input\""; + + // Create a completed orchestration state + Core.OrchestrationState originalState = CreateState(input, "test-output"); + originalState.OrchestrationInstance.InstanceId = originalInstanceId; + originalState.Name = orchestratorName; + originalState.OrchestrationStatus = Core.OrchestrationStatus.Completed; + + // Setup the mock to return the original orchestration state + this.orchestrationClient + .Setup(x => x.GetOrchestrationStateAsync(originalInstanceId, false)) + .ReturnsAsync(new List { originalState }); + + // Capture the TaskMessage for verification becasue we will create this message at RestartAsync. + TaskMessage? capturedMessage = null; + this.orchestrationClient + .Setup(x => x.CreateTaskOrchestrationAsync(It.IsAny())) + .Callback(msg => capturedMessage = msg) + .Returns(Task.CompletedTask); + + string restartedInstanceId = await this.client.RestartAsync(originalInstanceId, restartWithNewInstanceId); + + if (restartWithNewInstanceId) + { + restartedInstanceId.Should().NotBe(originalInstanceId); + } + else + { + restartedInstanceId.Should().Be(originalInstanceId); + } + + // Verify that CreateTaskOrchestrationAsync was called + this.orchestrationClient.Verify( + x => x.CreateTaskOrchestrationAsync(It.IsAny()), + Times.Once); + + // Verify the captured message details + capturedMessage.Should().NotBeNull(); + capturedMessage!.Event.Should().BeOfType(); + + var startedEvent = (ExecutionStartedEvent)capturedMessage.Event; + startedEvent.Name.Should().Be(orchestratorName); + startedEvent.Input.Should().Be(serializedInput); + // TODO: once we support version at ShimDurableTaskClient, we should check version here. + startedEvent.OrchestrationInstance.InstanceId.Should().Be(restartedInstanceId); + startedEvent.OrchestrationInstance.ExecutionId.Should().NotBeNullOrEmpty(); + startedEvent.OrchestrationInstance.ExecutionId.Should().NotBe(originalState.OrchestrationInstance.ExecutionId); + } + + [Fact] + public async Task RestartAsync_InstanceNotFound_ThrowsArgumentException() + { + string nonExistentInstanceId = "non-existent-instance-id"; + + // Setup the mock to client return empty orchestration state (instance not found) + this.orchestrationClient + .Setup(x => x.GetOrchestrationStateAsync(nonExistentInstanceId, false)) + .ReturnsAsync(new List()); + + // RestartAsync should throw an ArgumentException since the instance is not found + Func restartAction = () => this.client.RestartAsync(nonExistentInstanceId); + + await restartAction.Should().ThrowAsync() + .WithMessage($"*An orchestration with the instanceId {nonExistentInstanceId} was not found*"); + } static Core.OrchestrationState CreateState( object input, object? output = null, DateTimeOffset start = default) diff --git a/test/Grpc.IntegrationTests/GrpcDurableTaskClientIntegrationTests.cs b/test/Grpc.IntegrationTests/GrpcDurableTaskClientIntegrationTests.cs index 50fe6e070..63f92ef31 100644 --- a/test/Grpc.IntegrationTests/GrpcDurableTaskClientIntegrationTests.cs +++ b/test/Grpc.IntegrationTests/GrpcDurableTaskClientIntegrationTests.cs @@ -224,6 +224,70 @@ public async Task PurgeInstances_WithFilter_EndToEnd() } } + [Theory] + [InlineData(false)] + [InlineData(true)] + public async Task RestartAsync_EndToEnd(bool restartWithNewInstanceId) + { + using var cts = new CancellationTokenSource(TimeSpan.FromMinutes(1)); + await using HostTestLifetime server = await this.StartAsync(); + + // Start an initial orchestration with shouldThrow = false to ensure it completes successfully + string originalInstanceId = await server.Client.ScheduleNewOrchestrationInstanceAsync( + OrchestrationName, input: false); + + // Wait for it to start and then complete + await server.Client.WaitForInstanceStartAsync(originalInstanceId, default); + await server.Client.RaiseEventAsync(originalInstanceId, "event", default); + await server.Client.WaitForInstanceCompletionAsync(originalInstanceId, cts.Token); + + // Verify the original orchestration completed + OrchestrationMetadata? originalMetadata = await server.Client.GetInstanceAsync(originalInstanceId, true); + originalMetadata.Should().NotBeNull(); + originalMetadata!.RuntimeStatus.Should().Be(OrchestrationRuntimeStatus.Completed); + + // Restart the orchestration + string restartedInstanceId = await server.Client.RestartAsync(originalInstanceId, restartWithNewInstanceId); + + // Verify the restart behavior + if (restartWithNewInstanceId) + { + restartedInstanceId.Should().NotBe(originalInstanceId); + } + else + { + restartedInstanceId.Should().Be(originalInstanceId); + } + + // Complete the restarted orchestration + await server.Client.RaiseEventAsync(restartedInstanceId, "event"); + + using var completionCts = new CancellationTokenSource(TimeSpan.FromSeconds(30)); + await server.Client.WaitForInstanceCompletionAsync(restartedInstanceId, completionCts.Token); + + // Verify the restarted orchestration completed. + // Also verify input and orchestrator name are matched. + var restartedMetadata = await server.Client.GetInstanceAsync(restartedInstanceId, true); + restartedMetadata.Should().NotBeNull(); + restartedMetadata!.Name.Should().Be(OrchestrationName); + restartedMetadata.SerializedInput.Should().Be("false"); + restartedMetadata!.RuntimeStatus.Should().Be(OrchestrationRuntimeStatus.Completed); + } + + [Fact] + public async Task RestartAsync_InstanceNotFound_ThrowsArgumentException() + { + using var cts = new CancellationTokenSource(TimeSpan.FromMinutes(1)); // 1-minute timeout + await using HostTestLifetime server = await this.StartAsync(); + + // Try to restart a non-existent orchestration + Func restartAction = () => server.Client.RestartAsync("non-existent-instance-id", cancellation: cts.Token); + + // Should throw ArgumentException + await restartAction.Should().ThrowAsync() + .WithMessage("*An orchestration with the instanceId non-existent-instance-id was not found*"); + } + Task StartAsync() { static async Task Orchestration(TaskOrchestrationContext context, bool shouldThrow) diff --git a/test/Grpc.IntegrationTests/GrpcSidecar/Grpc/TaskHubGrpcServer.cs b/test/Grpc.IntegrationTests/GrpcSidecar/Grpc/TaskHubGrpcServer.cs index b035ed53f..e3a320f79 100644 --- a/test/Grpc.IntegrationTests/GrpcSidecar/Grpc/TaskHubGrpcServer.cs +++ b/test/Grpc.IntegrationTests/GrpcSidecar/Grpc/TaskHubGrpcServer.cs @@ -336,6 +336,68 @@ static P.GetInstanceResponse CreateGetInstanceResponse(OrchestrationState state, return new P.ResumeResponse(); } + public override async Task RestartInstance(P.RestartInstanceRequest request, ServerCallContext context) + { + try + { + // Get the original orchestration state + IList states = await this.client.GetOrchestrationStateAsync(request.InstanceId, false); + + if (states == null || states.Count == 0) + { + throw new RpcException(new Status(StatusCode.NotFound, $"An orchestration with the instanceId {request.InstanceId} was not found.")); + } + + OrchestrationState state = states[0]; + + // Check if the state is null + if (state == null) + { + throw new RpcException(new Status(StatusCode.NotFound, $"An orchestration with the instanceId {request.InstanceId} was not found.")); + } + + string newInstanceId = request.RestartWithNewInstanceId ? Guid.NewGuid().ToString("N") : request.InstanceId; + + // Create a new orchestration instance + OrchestrationInstance newInstance = new() + { + InstanceId = newInstanceId, + ExecutionId = Guid.NewGuid().ToString("N"), + }; + + // Create an ExecutionStartedEvent with the original input + ExecutionStartedEvent startedEvent = new(-1, state.Input) + { + Name = state.Name, + Version = state.Version ?? string.Empty, + OrchestrationInstance = newInstance, + }; + + TaskMessage taskMessage = new() + { + OrchestrationInstance = newInstance, + Event = startedEvent, + }; + + await this.client.CreateTaskOrchestrationAsync(taskMessage); + + return new P.RestartInstanceResponse + { + InstanceId = newInstanceId, + }; + } + catch (RpcException) + { + // Re-throw RpcException as-is + throw; + } + catch (Exception ex) + { + this.log.LogError(ex, "Error restarting orchestration instance {InstanceId}", request.InstanceId); + throw new RpcException(new Status(StatusCode.Internal, ex.Message)); + } + } + static P.TaskFailureDetails? GetFailureDetails(FailureDetails? failureDetails) { if (failureDetails == null) diff --git a/test/Grpc.IntegrationTests/GrpcSidecar/InMemoryOrchestrationService.cs b/test/Grpc.IntegrationTests/GrpcSidecar/InMemoryOrchestrationService.cs index 481fdc3ff..3494c64b2 100644 --- a/test/Grpc.IntegrationTests/GrpcSidecar/InMemoryOrchestrationService.cs +++ b/test/Grpc.IntegrationTests/GrpcSidecar/InMemoryOrchestrationService.cs @@ -444,8 +444,17 @@ public void AddMessage(TaskMessage message) SerializedInstanceState state = this.store.GetOrAdd(instanceId, id => new SerializedInstanceState(id, executionId)); lock (state) { + bool isRestart = state.ExecutionId != null && state.ExecutionId != executionId; + if (message.Event is ExecutionStartedEvent startEvent) { + // For restart scenarios, clear the history and reset the state + if (isRestart && state.IsCompleted) + { + state.ExecutionId = executionId; + state.IsLoaded = false; + } + OrchestrationState newStatusRecord = new() { OrchestrationInstance = startEvent.OrchestrationInstance, @@ -666,7 +675,10 @@ public void Schedule(SerializedInstanceState state) // to update the readyToRunQueue and the orchestration will get stuck. if (this.readyInstances.TryAdd(state.InstanceId, state)) { - this.readyToRunQueue.Writer.TryWrite(state); + if (!this.readyToRunQueue.Writer.TryWrite(state)) + { + throw new InvalidOperationException($"unable to write to queue for {state.InstanceId}"); + } } } }