Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 41 additions & 0 deletions src/Client/Core/DurableTaskClient.cs
Original file line number Diff line number Diff line change
Expand Up @@ -399,6 +399,47 @@ public virtual Task<PurgeResult> PurgeAllInstancesAsync(
throw new NotSupportedException($"{this.GetType()} does not support purging of orchestration instances.");
}

/// <summary>
/// Restarts an orchestration instance with the same or a new instance ID.
/// </summary>
/// <remarks>
/// <para>
/// This method restarts an existing orchestration instance. If <paramref name="restartWithNewInstanceId"/> is <c>true</c>,
/// a new instance ID will be generated for the restarted orchestration. If <c>false</c>, the original instance ID will be reused.
/// </para><para>
/// The restarted orchestration will use the same input data as the original instance. If the original orchestration
/// instance is not found, an <see cref="ArgumentException"/> will be thrown.
/// </para><para>
/// Note that this operation is backend-specific and may not be supported by all durable task backends.
/// If the backend does not support restart operations, a <see cref="NotSupportedException"/> will be thrown.
/// </para>
/// </remarks>
/// <param name="instanceId">The ID of the orchestration instance to restart.</param>
/// <param name="restartWithNewInstanceId">
/// If <c>true</c>, a new instance ID will be generated for the restarted orchestration.
/// If <c>false</c>, the original instance ID will be reused.
/// </param>
/// <param name="cancellation">
/// The cancellation token. This only cancels enqueueing the restart request to the backend.
/// Does not abort restarting the orchestration once enqueued.
/// </param>
/// <returns>
/// A task that completes when the orchestration instance is successfully restarted.
/// The value of this task is the instance ID of the restarted orchestration instance.
/// </returns>
/// <exception cref="ArgumentException">
/// Thrown if an orchestration with the specified <paramref name="instanceId"/> was not found. </exception>
/// <exception cref="InvalidOperationException">
/// Thrown when attempting to restart an instance using the same instance Id
/// while the instance has not yet reached a completed or terminal state. </exception>
/// <exception cref="NotSupportedException">
/// Thrown if the backend does not support restart operations. </exception>
public virtual Task<string> RestartAsync(
string instanceId,
bool restartWithNewInstanceId = false,
CancellationToken cancellation = default)
=> throw new NotSupportedException($"{this.GetType()} does not support orchestration restart.");

// TODO: Create task hub

// TODO: Delete task hub
Expand Down
36 changes: 36 additions & 0 deletions src/Client/Grpc/GrpcDurableTaskClient.cs
Original file line number Diff line number Diff line change
Expand Up @@ -395,6 +395,42 @@ public override Task<PurgeResult> PurgeAllInstancesAsync(
return this.PurgeInstancesCoreAsync(request, cancellation);
}

/// <inheritdoc/>
public override async Task<string> RestartAsync(
string instanceId,
bool restartWithNewInstanceId = false,
CancellationToken cancellation = default)
{
Check.NotNullOrEmpty(instanceId);
Check.NotEntity(this.options.EnableEntitySupport, instanceId);

var request = new P.RestartInstanceRequest
{
InstanceId = instanceId,
RestartWithNewInstanceId = restartWithNewInstanceId,
};

try
{
P.RestartInstanceResponse result = await this.sidecarClient.RestartInstanceAsync(
request, cancellationToken: cancellation);
return result.InstanceId;
}
catch (RpcException e) when (e.StatusCode == StatusCode.NotFound)
{
throw new ArgumentException($"An orchestration with the instanceId {instanceId} was not found.", e);
}
catch (RpcException e) when (e.StatusCode == StatusCode.FailedPrecondition)
{
throw new InvalidOperationException($"An orchestration with the instanceId {instanceId} cannot be restarted.", e);
}
catch (RpcException e) when (e.StatusCode == StatusCode.Cancelled)
{
throw new OperationCanceledException(
$"The {nameof(this.RestartAsync)} operation was canceled.", e, cancellation);
}
}

static AsyncDisposable GetCallInvoker(GrpcDurableTaskClientOptions options, out CallInvoker callInvoker)
{
if (options.Channel is GrpcChannel c)
Expand Down
53 changes: 53 additions & 0 deletions src/Client/OrchestrationServiceClientShim/ShimDurableTaskClient.cs
Original file line number Diff line number Diff line change
Expand Up @@ -254,6 +254,59 @@ public override async Task<OrchestrationMetadata> WaitForInstanceStartAsync(
}
}

/// <inheritdoc/>
public override async Task<string> RestartAsync(
string instanceId,
bool restartWithNewInstanceId = false,
CancellationToken cancellation = default)
{
Check.NotNullOrEmpty(instanceId);
cancellation.ThrowIfCancellationRequested();

// Get the current orchestration status to retrieve the name and input
OrchestrationMetadata? status = await this.GetInstanceAsync(instanceId, getInputsAndOutputs: true, cancellation);

if (status == null)
{
throw new ArgumentException($"An orchestration with the instanceId {instanceId} was not found.");
}

bool isInstaceNotCompleted = status.RuntimeStatus == OrchestrationRuntimeStatus.Running ||
status.RuntimeStatus == OrchestrationRuntimeStatus.Pending ||
status.RuntimeStatus == OrchestrationRuntimeStatus.Suspended;

if (isInstaceNotCompleted && !restartWithNewInstanceId)
{
throw new InvalidOperationException($"Instance '{instanceId}' cannot be restarted while it is in state '{status.RuntimeStatus}'. " +
"Wait until it has completed, or restart with a new instance ID.");
}

// Determine the instance ID for the restarted orchestration
string newInstanceId = restartWithNewInstanceId ? Guid.NewGuid().ToString("N") : instanceId;

OrchestrationInstance instance = new()
{
InstanceId = newInstanceId,
ExecutionId = Guid.NewGuid().ToString("N"),
};

// Use the original serialized input directly to avoid double serialization
// TODO: OrchestrationMetada doesn't have version property so we don't support version here.
// Issue link: https://github.com/microsoft/durabletask-dotnet/issues/463
TaskMessage message = new()
{
OrchestrationInstance = instance,
Event = new ExecutionStartedEvent(-1, status.SerializedInput)
{
Name = status.Name,
OrchestrationInstance = instance,
},
};

await this.Client.CreateTaskOrchestrationAsync(message);
return newInstanceId;
}

[return: NotNullIfNotNull("state")]
OrchestrationMetadata? ToMetadata(Core.OrchestrationState? state, bool getInputsAndOutputs)
{
Expand Down
12 changes: 12 additions & 0 deletions src/Grpc/orchestrator_service.proto
Original file line number Diff line number Diff line change
Expand Up @@ -482,6 +482,15 @@ message PurgeInstancesResponse {
google.protobuf.BoolValue isComplete = 2;
}

message RestartInstanceRequest {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should be updated once we merge PR microsoft/durabletask-protobuf#46. But we will merge these related PR together, so I just manually updated it

string instanceId = 1;
bool restartWithNewInstanceId = 2;
}

message RestartInstanceResponse {
string instanceId = 1;
}

message CreateTaskHubRequest {
bool recreateIfExists = 1;
}
Expand Down Expand Up @@ -682,6 +691,9 @@ service TaskHubSidecarService {
// Rewinds an orchestration instance to last known good state and replays from there.
rpc RewindInstance(RewindInstanceRequest) returns (RewindInstanceResponse);

// Restarts an orchestration instance.
rpc RestartInstance(RestartInstanceRequest) returns (RestartInstanceResponse);

// Waits for an orchestration instance to reach a running or completion state.
rpc WaitForInstanceStart(GetInstanceRequest) returns (GetInstanceResponse);

Expand Down
4 changes: 2 additions & 2 deletions src/Grpc/versions.txt
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
# The following files were downloaded from branch main at 2025-08-08 16:46:11 UTC
https://raw.githubusercontent.com/microsoft/durabletask-protobuf/e88acbd07ae38b499dbe8c4e333e9e3feeb2a9cc/protos/orchestrator_service.proto
# The following files were downloaded from branch main at 2025-09-10 22:50:45 UTC
https://raw.githubusercontent.com/microsoft/durabletask-protobuf/985035a0890575ae18be0eb2a3ac93c10824498a/protos/orchestrator_service.proto
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,14 @@ public override Task<OrchestrationMetadata> WaitForInstanceStartAsync(
{
throw new NotImplementedException();
}

public override Task<string> RestartAsync(
string instanceId,
bool restartWithNewInstanceId = false,
CancellationToken cancellation = default)
{
throw new NotImplementedException();
}
}

class CustomDataConverter : DataConverter
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,14 @@ public override Task<OrchestrationMetadata> WaitForInstanceStartAsync(
{
throw new NotImplementedException();
}

public override Task<string> RestartAsync(
string instanceId,
bool restartWithNewInstanceId = false,
CancellationToken cancellation = default)
{
throw new NotImplementedException();
}
}

class GoodBuildTargetOptions : DurableTaskClientOptions
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -326,6 +326,79 @@ public async Task ScheduleNewOrchestrationInstance_IdProvided_TagsProvided()
await this.RunScheduleNewOrchestrationInstanceAsync("test", "input", options);
}

[Theory]
[InlineData(false)]
[InlineData(true)]
public async Task RestartAsync_EndToEnd(bool restartWithNewInstanceId)
{
string originalInstanceId = "test-instance-id";
string orchestratorName = "TestOrchestrator";
object input = "test-input";
string serializedInput = "\"test-input\"";

// Create a completed orchestration state
Core.OrchestrationState originalState = CreateState(input, "test-output");
originalState.OrchestrationInstance.InstanceId = originalInstanceId;
originalState.Name = orchestratorName;
originalState.OrchestrationStatus = Core.OrchestrationStatus.Completed;

// Setup the mock to return the original orchestration state
this.orchestrationClient
.Setup(x => x.GetOrchestrationStateAsync(originalInstanceId, false))
.ReturnsAsync(new List<Core.OrchestrationState> { originalState });

// Capture the TaskMessage for verification becasue we will create this message at RestartAsync.
TaskMessage? capturedMessage = null;
this.orchestrationClient
.Setup(x => x.CreateTaskOrchestrationAsync(It.IsAny<TaskMessage>()))
.Callback<TaskMessage>(msg => capturedMessage = msg)
.Returns(Task.CompletedTask);

string restartedInstanceId = await this.client.RestartAsync(originalInstanceId, restartWithNewInstanceId);

if (restartWithNewInstanceId)
{
restartedInstanceId.Should().NotBe(originalInstanceId);
}
else
{
restartedInstanceId.Should().Be(originalInstanceId);
}

// Verify that CreateTaskOrchestrationAsync was called
this.orchestrationClient.Verify(
x => x.CreateTaskOrchestrationAsync(It.IsAny<TaskMessage>()),
Times.Once);

// Verify the captured message details
capturedMessage.Should().NotBeNull();
capturedMessage!.Event.Should().BeOfType<ExecutionStartedEvent>();

var startedEvent = (ExecutionStartedEvent)capturedMessage.Event;
startedEvent.Name.Should().Be(orchestratorName);
startedEvent.Input.Should().Be(serializedInput);
// TODO: once we support version at ShimDurableTaskClient, we should check version here.
startedEvent.OrchestrationInstance.InstanceId.Should().Be(restartedInstanceId);
startedEvent.OrchestrationInstance.ExecutionId.Should().NotBeNullOrEmpty();
startedEvent.OrchestrationInstance.ExecutionId.Should().NotBe(originalState.OrchestrationInstance.ExecutionId);
}

[Fact]
public async Task RestartAsync_InstanceNotFound_ThrowsArgumentException()
{
string nonExistentInstanceId = "non-existent-instance-id";

// Setup the mock to client return empty orchestration state (instance not found)
this.orchestrationClient
.Setup(x => x.GetOrchestrationStateAsync(nonExistentInstanceId, false))
.ReturnsAsync(new List<Core.OrchestrationState>());

// RestartAsync should throw an ArgumentException since the instance is not found
Func<Task> restartAction = () => this.client.RestartAsync(nonExistentInstanceId);

await restartAction.Should().ThrowAsync<ArgumentException>()
.WithMessage($"*An orchestration with the instanceId {nonExistentInstanceId} was not found*");
}

static Core.OrchestrationState CreateState(
object input, object? output = null, DateTimeOffset start = default)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,70 @@ public async Task PurgeInstances_WithFilter_EndToEnd()
}
}

[Theory]
[InlineData(false)]
[InlineData(true)]
public async Task RestartAsync_EndToEnd(bool restartWithNewInstanceId)
{
using var cts = new CancellationTokenSource(TimeSpan.FromMinutes(1));
await using HostTestLifetime server = await this.StartAsync();

// Start an initial orchestration with shouldThrow = false to ensure it completes successfully
string originalInstanceId = await server.Client.ScheduleNewOrchestrationInstanceAsync(
OrchestrationName, input: false);

// Wait for it to start and then complete
await server.Client.WaitForInstanceStartAsync(originalInstanceId, default);
await server.Client.RaiseEventAsync(originalInstanceId, "event", default);
await server.Client.WaitForInstanceCompletionAsync(originalInstanceId, cts.Token);

// Verify the original orchestration completed
OrchestrationMetadata? originalMetadata = await server.Client.GetInstanceAsync(originalInstanceId, true);
originalMetadata.Should().NotBeNull();
originalMetadata!.RuntimeStatus.Should().Be(OrchestrationRuntimeStatus.Completed);

// Restart the orchestration
string restartedInstanceId = await server.Client.RestartAsync(originalInstanceId, restartWithNewInstanceId);

// Verify the restart behavior
if (restartWithNewInstanceId)
{
restartedInstanceId.Should().NotBe(originalInstanceId);
}
else
{
restartedInstanceId.Should().Be(originalInstanceId);
}

// Complete the restarted orchestration
await server.Client.RaiseEventAsync(restartedInstanceId, "event");

using var completionCts = new CancellationTokenSource(TimeSpan.FromSeconds(30));
await server.Client.WaitForInstanceCompletionAsync(restartedInstanceId, completionCts.Token);

// Verify the restarted orchestration completed.
// Also verify input and orchestrator name are matched.
var restartedMetadata = await server.Client.GetInstanceAsync(restartedInstanceId, true);
restartedMetadata.Should().NotBeNull();
restartedMetadata!.Name.Should().Be(OrchestrationName);
restartedMetadata.SerializedInput.Should().Be("false");
restartedMetadata!.RuntimeStatus.Should().Be(OrchestrationRuntimeStatus.Completed);
}

[Fact]
public async Task RestartAsync_InstanceNotFound_ThrowsArgumentException()
{
using var cts = new CancellationTokenSource(TimeSpan.FromMinutes(1)); // 1-minute timeout
await using HostTestLifetime server = await this.StartAsync();

// Try to restart a non-existent orchestration
Func<Task> restartAction = () => server.Client.RestartAsync("non-existent-instance-id", cancellation: cts.Token);

// Should throw ArgumentException
await restartAction.Should().ThrowAsync<ArgumentException>()
.WithMessage("*An orchestration with the instanceId non-existent-instance-id was not found*");
}

Task<HostTestLifetime> StartAsync()
{
static async Task<string> Orchestration(TaskOrchestrationContext context, bool shouldThrow)
Expand Down
Loading
Loading