diff --git a/CHANGELOG.md b/CHANGELOG.md index f3ec402b8..02449a9e2 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,7 @@ - Introduce default version setting to DurableTaskClient and expose to orchestrator ([#393](https://github.com/microsoft/durabletask-dotnet/pull/393)) - Add support for local credential types in DTS libraries ([#396](https://github.com/microsoft/durabletask-dotnet/pull/396)) - Add utility for easier version comparison in orchestration context ([#394](https://github.com/microsoft/durabletask-dotnet/pull/394)) +- Add tags support for orchestrations ([#397])(https://github.com/microsoft/durabletask-dotnet/pull/397) ## v1.8.1 diff --git a/Directory.Packages.props b/Directory.Packages.props index a1e78379e..ee4e690d3 100644 --- a/Directory.Packages.props +++ b/Directory.Packages.props @@ -30,7 +30,6 @@ - @@ -77,6 +76,7 @@ + diff --git a/src/Abstractions/Abstractions.csproj b/src/Abstractions/Abstractions.csproj index e4c127b9e..db8be76ab 100644 --- a/src/Abstractions/Abstractions.csproj +++ b/src/Abstractions/Abstractions.csproj @@ -12,6 +12,7 @@ + diff --git a/src/Abstractions/TaskOptions.cs b/src/Abstractions/TaskOptions.cs index 63f943825..d3e06e5ca 100644 --- a/src/Abstractions/TaskOptions.cs +++ b/src/Abstractions/TaskOptions.cs @@ -1,103 +1,111 @@ -// Copyright (c) Microsoft Corporation. -// Licensed under the MIT License. - -namespace Microsoft.DurableTask; - -/// -/// Options that can be used to control the behavior of orchestrator task execution. -/// -public record TaskOptions -{ - /// - /// Initializes a new instance of the class. - /// - /// The task retry options. - public TaskOptions(TaskRetryOptions? retry = null) - { - this.Retry = retry; - } - - /// - /// Gets the task retry options. - /// - public TaskRetryOptions? Retry { get; init; } - - /// - /// Returns a new from the provided . - /// - /// The policy to convert from. - /// A built from the policy. - public static TaskOptions FromRetryPolicy(RetryPolicy policy) => new(policy); - - /// - /// Returns a new from the provided . - /// - /// The handler to convert from. - /// A built from the handler. - public static TaskOptions FromRetryHandler(AsyncRetryHandler handler) => new(handler); - - /// - /// Returns a new from the provided . - /// - /// The handler to convert from. - /// A built from the handler. - public static TaskOptions FromRetryHandler(RetryHandler handler) => new(handler); - - /// - /// Returns a new with the provided instance ID. This can be used when - /// starting a new sub-orchestration to specify the instance ID. - /// - /// The instance ID to use. - /// A new . - public SubOrchestrationOptions WithInstanceId(string instanceId) => new(this, instanceId); -} - -/// -/// Options that can be used to control the behavior of orchestrator task execution. This derived type can be used to -/// supply extra options for orchestrations. -/// -public record SubOrchestrationOptions : TaskOptions -{ - /// - /// Initializes a new instance of the class. - /// - /// The task retry options. - /// The orchestration instance ID. - public SubOrchestrationOptions(TaskRetryOptions? retry = null, string? instanceId = null) - : base(retry) - { - this.InstanceId = instanceId; - } - - /// - /// Initializes a new instance of the class. - /// - /// The task options to wrap. - /// The orchestration instance ID. - public SubOrchestrationOptions(TaskOptions options, string? instanceId = null) - : base(options) - { - this.InstanceId = instanceId; - if (instanceId is null && options is SubOrchestrationOptions derived) - { - this.InstanceId = derived.InstanceId; - } - } - - /// - /// Gets the orchestration instance ID. - /// - public string? InstanceId { get; init; } -} - -/// -/// Options for submitting new orchestrations via the client. -/// -/// -/// The unique ID of the orchestration instance to schedule. If not specified, a new GUID value is used. -/// -/// -/// The time when the orchestration instance should start executing. If not specified or if a date-time in the past -/// is specified, the orchestration instance will be scheduled immediately. -/// -public record StartOrchestrationOptions(string? InstanceId = null, DateTimeOffset? StartAt = null); +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using System.Collections.Immutable; + +namespace Microsoft.DurableTask; + +/// +/// Options that can be used to control the behavior of orchestrator task execution. +/// +public record TaskOptions +{ + /// + /// Initializes a new instance of the class. + /// + /// The task retry options. + public TaskOptions(TaskRetryOptions? retry = null) + { + this.Retry = retry; + } + + /// + /// Gets the task retry options. + /// + public TaskRetryOptions? Retry { get; init; } + + /// + /// Returns a new from the provided . + /// + /// The policy to convert from. + /// A built from the policy. + public static TaskOptions FromRetryPolicy(RetryPolicy policy) => new(policy); + + /// + /// Returns a new from the provided . + /// + /// The handler to convert from. + /// A built from the handler. + public static TaskOptions FromRetryHandler(AsyncRetryHandler handler) => new(handler); + + /// + /// Returns a new from the provided . + /// + /// The handler to convert from. + /// A built from the handler. + public static TaskOptions FromRetryHandler(RetryHandler handler) => new(handler); + + /// + /// Returns a new with the provided instance ID. This can be used when + /// starting a new sub-orchestration to specify the instance ID. + /// + /// The instance ID to use. + /// A new . + public SubOrchestrationOptions WithInstanceId(string instanceId) => new(this, instanceId); +} + +/// +/// Options that can be used to control the behavior of orchestrator task execution. This derived type can be used to +/// supply extra options for orchestrations. +/// +public record SubOrchestrationOptions : TaskOptions +{ + /// + /// Initializes a new instance of the class. + /// + /// The task retry options. + /// The orchestration instance ID. + public SubOrchestrationOptions(TaskRetryOptions? retry = null, string? instanceId = null) + : base(retry) + { + this.InstanceId = instanceId; + } + + /// + /// Initializes a new instance of the class. + /// + /// The task options to wrap. + /// The orchestration instance ID. + public SubOrchestrationOptions(TaskOptions options, string? instanceId = null) + : base(options) + { + this.InstanceId = instanceId; + if (instanceId is null && options is SubOrchestrationOptions derived) + { + this.InstanceId = derived.InstanceId; + } + } + + /// + /// Gets the orchestration instance ID. + /// + public string? InstanceId { get; init; } +} + +/// +/// Options for submitting new orchestrations via the client. +/// +/// +/// The unique ID of the orchestration instance to schedule. If not specified, a new GUID value is used. +/// +/// +/// The time when the orchestration instance should start executing. If not specified or if a date-time in the past +/// is specified, the orchestration instance will be scheduled immediately. +/// +public record StartOrchestrationOptions(string? InstanceId = null, DateTimeOffset? StartAt = null) +{ + /// + /// Gets the tags to associate with the orchestration instance. + /// + public IReadOnlyDictionary Tags { get; init; } = ImmutableDictionary.Create(); +} diff --git a/src/Client/Core/OrchestrationMetadata.cs b/src/Client/Core/OrchestrationMetadata.cs index 3fc43b0a3..a1cf3d9fa 100644 --- a/src/Client/Core/OrchestrationMetadata.cs +++ b/src/Client/Core/OrchestrationMetadata.cs @@ -1,233 +1,239 @@ -// Copyright (c) Microsoft Corporation. -// Licensed under the MIT License. - -using System.Diagnostics.CodeAnalysis; -using System.Text; - -namespace Microsoft.DurableTask.Client; - -/// -/// Represents a snapshot of an orchestration instance's current state, including metadata. -/// -/// -/// Instances of this class are produced by methods in the class, such as -/// , -/// and -/// . -/// -public sealed class OrchestrationMetadata -{ - /// - /// Initializes a new instance of the class. - /// - /// The name of the orchestration. - /// The instance ID of the orchestration. - public OrchestrationMetadata(string name, string instanceId) - { - this.Name = name; - this.InstanceId = instanceId; - } - - /// Gets the name of the orchestration. - /// The name of the orchestration. - public string Name { get; } - - /// Gets the unique ID of the orchestration instance. - /// The unique ID of the orchestration instance. - public string InstanceId { get; } - - /// - /// Gets the data converter used to deserialized the serialized data on this instance. - /// This will only be present when inputs and outputs are requested, null otherwise. - /// - /// The optional data converter. - public DataConverter? DataConverter { get; init; } - - /// - /// Gets the current runtime status of the orchestration instance at the time this object was fetched. - /// - /// The runtime status of the orchestration instance at the time this object was fetched. - public OrchestrationRuntimeStatus RuntimeStatus { get; init; } - - /// - /// Gets the orchestration instance's creation time in UTC. - /// - /// The orchestration instance's creation time in UTC. - public DateTimeOffset CreatedAt { get; init; } - - /// - /// Gets the orchestration instance's last updated time in UTC. - /// - /// The orchestration instance's last updated time in UTC. - public DateTimeOffset LastUpdatedAt { get; init; } - - /// - /// Gets the orchestration instance's serialized input, if any, as a string value. - /// - /// The serialized orchestration input or null. - public string? SerializedInput { get; init; } - - /// - /// Gets the orchestration instance's serialized output, if any, as a string value. - /// - /// The serialized orchestration output or null. - public string? SerializedOutput { get; init; } - - /// - /// Gets the orchestration instance's serialized custom status, if any, as a string value. - /// - /// The serialized custom status or null. - public string? SerializedCustomStatus { get; init; } - - /// - /// Gets the failure details, if any, for the orchestration instance. - /// - /// - /// This property contains data only if the orchestration is in the - /// state, and only if this instance metadata was fetched with the option to include output data. - /// - /// The failure details if the orchestration was in a failed state; null otherwise. - public TaskFailureDetails? FailureDetails { get; init; } - - /// - /// Gets a value indicating whether the orchestration instance was running at the time this object was fetched. - /// - /// true if the orchestration was in a running state; false otherwise. - public bool IsRunning => this.RuntimeStatus == OrchestrationRuntimeStatus.Running; - - /// - /// Gets a value indicating whether the orchestration instance was completed at the time this object was fetched. - /// - /// - /// An orchestration instance is considered completed when its value is - /// , , - /// or . - /// - /// true if the orchestration was in a terminal state; false otherwise. - public bool IsCompleted => - this.RuntimeStatus == OrchestrationRuntimeStatus.Completed || - this.RuntimeStatus == OrchestrationRuntimeStatus.Failed || - this.RuntimeStatus == OrchestrationRuntimeStatus.Terminated; - - [MemberNotNullWhen(true, nameof(DataConverter))] - bool RequestedInputsAndOutputs => this.DataConverter is not null; - - /// - /// Deserializes the orchestration's input into an object of the specified type. - /// - /// - /// This method can only be used when inputs and outputs are explicitly requested from the - /// or - /// method that produced - /// this object. - /// - /// The type to deserialize the orchestration input into. - /// Returns the deserialized input value. - /// - /// Thrown if this metadata object was fetched without the option to read inputs and outputs. - /// - public T? ReadInputAs() - { - if (!this.RequestedInputsAndOutputs) - { - throw new InvalidOperationException( - $"The {nameof(this.ReadInputAs)} method can only be used on {nameof(OrchestrationMetadata)} objects " + - "that are fetched with the option to include input data."); - } - - return this.DataConverter.Deserialize(this.SerializedInput); - } - - /// - /// Deserializes the orchestration's output into an object of the specified type. - /// - /// - /// This method can only be used when inputs and outputs are explicitly requested from the - /// or - /// method that produced - /// this object. - /// - /// The type to deserialize the orchestration output into. - /// Returns the deserialized output value. - /// - /// Thrown if this metadata object was fetched without the option to read inputs and outputs. - /// - public T? ReadOutputAs() - { - if (!this.RequestedInputsAndOutputs) - { - throw new InvalidOperationException( - $"The {nameof(this.ReadOutputAs)} method can only be used on {nameof(OrchestrationMetadata)} objects " + - "that are fetched with the option to include output data."); - } - - return this.DataConverter.Deserialize(this.SerializedOutput); - } - - /// - /// Deserializes the orchestration's custom status value into an object of the specified type. - /// - /// - /// This method can only be used when inputs and outputs are explicitly requested from the - /// or - /// method that produced - /// this object. - /// - /// The type to deserialize the orchestration' custom status into. - /// Returns the deserialized custom status value. - /// - /// Thrown if this metadata object was fetched without the option to read inputs and outputs. - /// - public T? ReadCustomStatusAs() - { - if (!this.RequestedInputsAndOutputs) - { - throw new InvalidOperationException( - $"The {nameof(this.ReadCustomStatusAs)} method can only be used on {nameof(OrchestrationMetadata)}" - + " objects that are fetched with the option to include input and output data."); - } - - return this.DataConverter.Deserialize(this.SerializedCustomStatus); - } - - /// - /// Generates a user-friendly string representation of the current metadata object. - /// - /// A user-friendly string representation of the current metadata object. - public override string ToString() - { - StringBuilder sb = new($"[Name: '{this.Name}', ID: '{this.InstanceId}', RuntimeStatus: {this.RuntimeStatus}," - + $" CreatedAt: {this.CreatedAt:s}, LastUpdatedAt: {this.LastUpdatedAt:s}"); - if (this.SerializedInput != null) - { - sb.Append(", Input: '").Append(GetTrimmedPayload(this.SerializedInput)).Append('\''); - } - - if (this.SerializedOutput != null) - { - sb.Append(", Output: '").Append(GetTrimmedPayload(this.SerializedOutput)).Append('\''); - } - - if (this.FailureDetails != null) - { - sb.Append(", FailureDetails: '") - .Append(this.FailureDetails.ErrorType) - .Append(" - ") - .Append(GetTrimmedPayload(this.FailureDetails.ErrorMessage)) - .Append('\''); - } - - return sb.Append(']').ToString(); - } - - static string GetTrimmedPayload(string payload) - { - const int MaxLength = 50; - if (payload.Length > MaxLength) - { - return string.Concat(payload[..MaxLength], "..."); - } - - return payload; - } -} +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using System.Collections.Immutable; +using System.Diagnostics.CodeAnalysis; +using System.Text; + +namespace Microsoft.DurableTask.Client; + +/// +/// Represents a snapshot of an orchestration instance's current state, including metadata. +/// +/// +/// Instances of this class are produced by methods in the class, such as +/// , +/// and +/// . +/// +public sealed class OrchestrationMetadata +{ + /// + /// Initializes a new instance of the class. + /// + /// The name of the orchestration. + /// The instance ID of the orchestration. + public OrchestrationMetadata(string name, string instanceId) + { + this.Name = name; + this.InstanceId = instanceId; + } + + /// Gets the name of the orchestration. + /// The name of the orchestration. + public string Name { get; } + + /// Gets the unique ID of the orchestration instance. + /// The unique ID of the orchestration instance. + public string InstanceId { get; } + + /// + /// Gets the data converter used to deserialized the serialized data on this instance. + /// This will only be present when inputs and outputs are requested, null otherwise. + /// + /// The optional data converter. + public DataConverter? DataConverter { get; init; } + + /// + /// Gets the current runtime status of the orchestration instance at the time this object was fetched. + /// + /// The runtime status of the orchestration instance at the time this object was fetched. + public OrchestrationRuntimeStatus RuntimeStatus { get; init; } + + /// + /// Gets the orchestration instance's creation time in UTC. + /// + /// The orchestration instance's creation time in UTC. + public DateTimeOffset CreatedAt { get; init; } + + /// + /// Gets the orchestration instance's last updated time in UTC. + /// + /// The orchestration instance's last updated time in UTC. + public DateTimeOffset LastUpdatedAt { get; init; } + + /// + /// Gets the orchestration instance's serialized input, if any, as a string value. + /// + /// The serialized orchestration input or null. + public string? SerializedInput { get; init; } + + /// + /// Gets the orchestration instance's serialized output, if any, as a string value. + /// + /// The serialized orchestration output or null. + public string? SerializedOutput { get; init; } + + /// + /// Gets the orchestration instance's serialized custom status, if any, as a string value. + /// + /// The serialized custom status or null. + public string? SerializedCustomStatus { get; init; } + + /// + /// Gets the tags associated with the orchestration instance. + /// + public IReadOnlyDictionary Tags { get; init; } = ImmutableDictionary.Create(); + + /// + /// Gets the failure details, if any, for the orchestration instance. + /// + /// + /// This property contains data only if the orchestration is in the + /// state, and only if this instance metadata was fetched with the option to include output data. + /// + /// The failure details if the orchestration was in a failed state; null otherwise. + public TaskFailureDetails? FailureDetails { get; init; } + + /// + /// Gets a value indicating whether the orchestration instance was running at the time this object was fetched. + /// + /// true if the orchestration was in a running state; false otherwise. + public bool IsRunning => this.RuntimeStatus == OrchestrationRuntimeStatus.Running; + + /// + /// Gets a value indicating whether the orchestration instance was completed at the time this object was fetched. + /// + /// + /// An orchestration instance is considered completed when its value is + /// , , + /// or . + /// + /// true if the orchestration was in a terminal state; false otherwise. + public bool IsCompleted => + this.RuntimeStatus == OrchestrationRuntimeStatus.Completed || + this.RuntimeStatus == OrchestrationRuntimeStatus.Failed || + this.RuntimeStatus == OrchestrationRuntimeStatus.Terminated; + + [MemberNotNullWhen(true, nameof(DataConverter))] + bool RequestedInputsAndOutputs => this.DataConverter is not null; + + /// + /// Deserializes the orchestration's input into an object of the specified type. + /// + /// + /// This method can only be used when inputs and outputs are explicitly requested from the + /// or + /// method that produced + /// this object. + /// + /// The type to deserialize the orchestration input into. + /// Returns the deserialized input value. + /// + /// Thrown if this metadata object was fetched without the option to read inputs and outputs. + /// + public T? ReadInputAs() + { + if (!this.RequestedInputsAndOutputs) + { + throw new InvalidOperationException( + $"The {nameof(this.ReadInputAs)} method can only be used on {nameof(OrchestrationMetadata)} objects " + + "that are fetched with the option to include input data."); + } + + return this.DataConverter.Deserialize(this.SerializedInput); + } + + /// + /// Deserializes the orchestration's output into an object of the specified type. + /// + /// + /// This method can only be used when inputs and outputs are explicitly requested from the + /// or + /// method that produced + /// this object. + /// + /// The type to deserialize the orchestration output into. + /// Returns the deserialized output value. + /// + /// Thrown if this metadata object was fetched without the option to read inputs and outputs. + /// + public T? ReadOutputAs() + { + if (!this.RequestedInputsAndOutputs) + { + throw new InvalidOperationException( + $"The {nameof(this.ReadOutputAs)} method can only be used on {nameof(OrchestrationMetadata)} objects " + + "that are fetched with the option to include output data."); + } + + return this.DataConverter.Deserialize(this.SerializedOutput); + } + + /// + /// Deserializes the orchestration's custom status value into an object of the specified type. + /// + /// + /// This method can only be used when inputs and outputs are explicitly requested from the + /// or + /// method that produced + /// this object. + /// + /// The type to deserialize the orchestration' custom status into. + /// Returns the deserialized custom status value. + /// + /// Thrown if this metadata object was fetched without the option to read inputs and outputs. + /// + public T? ReadCustomStatusAs() + { + if (!this.RequestedInputsAndOutputs) + { + throw new InvalidOperationException( + $"The {nameof(this.ReadCustomStatusAs)} method can only be used on {nameof(OrchestrationMetadata)}" + + " objects that are fetched with the option to include input and output data."); + } + + return this.DataConverter.Deserialize(this.SerializedCustomStatus); + } + + /// + /// Generates a user-friendly string representation of the current metadata object. + /// + /// A user-friendly string representation of the current metadata object. + public override string ToString() + { + StringBuilder sb = new($"[Name: '{this.Name}', ID: '{this.InstanceId}', RuntimeStatus: {this.RuntimeStatus}," + + $" CreatedAt: {this.CreatedAt:s}, LastUpdatedAt: {this.LastUpdatedAt:s}"); + if (this.SerializedInput != null) + { + sb.Append(", Input: '").Append(GetTrimmedPayload(this.SerializedInput)).Append('\''); + } + + if (this.SerializedOutput != null) + { + sb.Append(", Output: '").Append(GetTrimmedPayload(this.SerializedOutput)).Append('\''); + } + + if (this.FailureDetails != null) + { + sb.Append(", FailureDetails: '") + .Append(this.FailureDetails.ErrorType) + .Append(" - ") + .Append(GetTrimmedPayload(this.FailureDetails.ErrorMessage)) + .Append('\''); + } + + return sb.Append(']').ToString(); + } + + static string GetTrimmedPayload(string payload) + { + const int MaxLength = 50; + if (payload.Length > MaxLength) + { + return string.Concat(payload[..MaxLength], "..."); + } + + return payload; + } +} diff --git a/src/Client/Grpc/GrpcDurableTaskClient.cs b/src/Client/Grpc/GrpcDurableTaskClient.cs index e7fd1da92..32c95b64b 100644 --- a/src/Client/Grpc/GrpcDurableTaskClient.cs +++ b/src/Client/Grpc/GrpcDurableTaskClient.cs @@ -1,463 +1,475 @@ -// Copyright (c) Microsoft Corporation. -// Licensed under the MIT License. - -using System.Diagnostics; -using System.Text; -using Google.Protobuf.WellKnownTypes; -using Microsoft.DurableTask.Client.Entities; -using Microsoft.Extensions.DependencyInjection; -using Microsoft.Extensions.Logging; -using Microsoft.Extensions.Options; -using static Microsoft.DurableTask.Protobuf.TaskHubSidecarService; -using P = Microsoft.DurableTask.Protobuf; - -namespace Microsoft.DurableTask.Client.Grpc; - -/// -/// Durable Task client implementation that uses gRPC to connect to a remote "sidecar" process. -/// -public sealed class GrpcDurableTaskClient : DurableTaskClient -{ - readonly ILogger logger; - readonly TaskHubSidecarServiceClient sidecarClient; - readonly GrpcDurableTaskClientOptions options; - readonly DurableEntityClient? entityClient; - AsyncDisposable asyncDisposable; - - /// - /// Initializes a new instance of the class. - /// - /// The name of the client. - /// The gRPC client options. - /// The logger. - [ActivatorUtilitiesConstructor] - public GrpcDurableTaskClient( - string name, IOptionsMonitor options, ILogger logger) - : this(name, Check.NotNull(options).Get(name), logger) - { - } - - /// - /// Initializes a new instance of the class. - /// - /// The name of the client. - /// The gRPC client options. - /// The logger. - public GrpcDurableTaskClient(string name, GrpcDurableTaskClientOptions options, ILogger logger) - : base(name) - { - this.logger = Check.NotNull(logger); - this.options = Check.NotNull(options); - this.asyncDisposable = GetCallInvoker(options, out CallInvoker callInvoker); - this.sidecarClient = new TaskHubSidecarServiceClient(callInvoker); - - if (this.options.EnableEntitySupport) - { - this.entityClient = new GrpcDurableEntityClient(this.Name, this.DataConverter, this.sidecarClient, logger); - } - } - - /// - public override DurableEntityClient Entities => this.entityClient - ?? throw new NotSupportedException($"Durable entities are disabled because {nameof(DurableTaskClientOptions)}.{nameof(DurableTaskClientOptions.EnableEntitySupport)}=false"); - - DataConverter DataConverter => this.options.DataConverter; - - /// - public override ValueTask DisposeAsync() - { - return this.asyncDisposable.DisposeAsync(); - } - - /// - public override async Task ScheduleNewOrchestrationInstanceAsync( - TaskName orchestratorName, - object? input = null, - StartOrchestrationOptions? options = null, - CancellationToken cancellation = default) - { - Check.NotEntity(this.options.EnableEntitySupport, options?.InstanceId); - - string version = string.Empty; - if (!string.IsNullOrEmpty(orchestratorName.Version)) - { - version = orchestratorName.Version; - } - else if (!string.IsNullOrEmpty(this.options.DefaultVersion)) - { - version = this.options.DefaultVersion; - } - - var request = new P.CreateInstanceRequest - { - Name = orchestratorName.Name, - Version = version, - InstanceId = options?.InstanceId ?? Guid.NewGuid().ToString("N"), - Input = this.DataConverter.Serialize(input), - }; - - if (Activity.Current?.Id != null || Activity.Current?.TraceStateString != null) - { - if (request.ParentTraceContext == null) - { - request.ParentTraceContext = new P.TraceContext(); - } - - if (Activity.Current?.Id != null) - { - request.ParentTraceContext.TraceParent = Activity.Current?.Id; - } - - if (Activity.Current?.TraceStateString != null) - { - request.ParentTraceContext.TraceState = Activity.Current?.TraceStateString; - } - } - - DateTimeOffset? startAt = options?.StartAt; - this.logger.SchedulingOrchestration( - request.InstanceId, - orchestratorName, - sizeInBytes: request.Input != null ? Encoding.UTF8.GetByteCount(request.Input) : 0, - startAt.GetValueOrDefault(DateTimeOffset.UtcNow)); - - if (startAt.HasValue) - { - // Convert timestamps to UTC if not already UTC - request.ScheduledStartTimestamp = Timestamp.FromDateTimeOffset(startAt.Value.ToUniversalTime()); - } - - P.CreateInstanceResponse? result = await this.sidecarClient.StartInstanceAsync( - request, cancellationToken: cancellation); - return result.InstanceId; - } - - /// - public override async Task RaiseEventAsync( - string instanceId, string eventName, object? eventPayload = null, CancellationToken cancellation = default) - { - Check.NotNullOrEmpty(instanceId); - Check.NotNullOrEmpty(eventName); - - Check.NotEntity(this.options.EnableEntitySupport, instanceId); - - P.RaiseEventRequest request = new() - { - InstanceId = instanceId, - Name = eventName, - Input = this.DataConverter.Serialize(eventPayload), - }; - - await this.sidecarClient.RaiseEventAsync(request, cancellationToken: cancellation); - } - - /// - public override async Task TerminateInstanceAsync( - string instanceId, TerminateInstanceOptions? options = null, CancellationToken cancellation = default) - { - object? output = options?.Output; - bool recursive = options?.Recursive ?? false; - - Check.NotNullOrEmpty(instanceId); - Check.NotEntity(this.options.EnableEntitySupport, instanceId); - - this.logger.TerminatingInstance(instanceId); - - string? serializedOutput = this.DataConverter.Serialize(output); - await this.sidecarClient.TerminateInstanceAsync( - new P.TerminateRequest - { - InstanceId = instanceId, - Output = serializedOutput, - Recursive = recursive, - }, - cancellationToken: cancellation); - } - - /// - public override async Task SuspendInstanceAsync( - string instanceId, string? reason = null, CancellationToken cancellation = default) - { - Check.NotEntity(this.options.EnableEntitySupport, instanceId); - - P.SuspendRequest request = new() - { - InstanceId = instanceId, - Reason = reason, - }; - - try - { - await this.sidecarClient.SuspendInstanceAsync(request, cancellationToken: cancellation); - } - catch (RpcException e) when (e.StatusCode == StatusCode.Cancelled) - { - throw new OperationCanceledException( - $"The {nameof(this.SuspendInstanceAsync)} operation was canceled.", e, cancellation); - } - } - - /// - public override async Task ResumeInstanceAsync( - string instanceId, string? reason = null, CancellationToken cancellation = default) - { - Check.NotEntity(this.options.EnableEntitySupport, instanceId); - - P.ResumeRequest request = new() - { - InstanceId = instanceId, - Reason = reason, - }; - - try - { - await this.sidecarClient.ResumeInstanceAsync(request, cancellationToken: cancellation); - } - catch (RpcException e) when (e.StatusCode == StatusCode.Cancelled) - { - throw new OperationCanceledException( - $"The {nameof(this.ResumeInstanceAsync)} operation was canceled.", e, cancellation); - } - } - - /// - public override async Task GetInstancesAsync( - string instanceId, bool getInputsAndOutputs = false, CancellationToken cancellation = default) - { - Check.NotEntity(this.options.EnableEntitySupport, instanceId); - - if (string.IsNullOrEmpty(instanceId)) - { - throw new ArgumentNullException(nameof(instanceId)); - } - - P.GetInstanceResponse response = await this.sidecarClient.GetInstanceAsync( - new P.GetInstanceRequest - { - InstanceId = instanceId, - GetInputsAndOutputs = getInputsAndOutputs, - }, - cancellationToken: cancellation); - - // REVIEW: Should we return a non-null value instead of !exists? - if (!response.Exists) - { - return null; - } - - return this.CreateMetadata(response.OrchestrationState, getInputsAndOutputs); - } - - /// - public override AsyncPageable GetAllInstancesAsync(OrchestrationQuery? filter = null) - { - Check.NotEntity(this.options.EnableEntitySupport, filter?.InstanceIdPrefix); - - return Pageable.Create(async (continuation, pageSize, cancellation) => - { - P.QueryInstancesRequest request = new() - { - Query = new P.InstanceQuery - { - CreatedTimeFrom = filter?.CreatedFrom?.ToTimestamp(), - CreatedTimeTo = filter?.CreatedTo?.ToTimestamp(), - FetchInputsAndOutputs = filter?.FetchInputsAndOutputs ?? false, - InstanceIdPrefix = filter?.InstanceIdPrefix, - MaxInstanceCount = pageSize ?? filter?.PageSize ?? OrchestrationQuery.DefaultPageSize, - ContinuationToken = continuation ?? filter?.ContinuationToken, - }, - }; - - if (filter?.Statuses is not null) - { - request.Query.RuntimeStatus.AddRange(filter.Statuses.Select(x => x.ToGrpcStatus())); - } - - if (filter?.TaskHubNames is not null) - { - request.Query.TaskHubNames.AddRange(filter.TaskHubNames); - } - - try - { - P.QueryInstancesResponse response = await this.sidecarClient.QueryInstancesAsync( - request, cancellationToken: cancellation); - - bool getInputsAndOutputs = filter?.FetchInputsAndOutputs ?? false; - IReadOnlyList values = response.OrchestrationState - .Select(x => this.CreateMetadata(x, getInputsAndOutputs)) - .ToList(); - - return new Page(values, response.ContinuationToken); - } - catch (RpcException e) when (e.StatusCode == StatusCode.Cancelled) - { - throw new OperationCanceledException( - $"The {nameof(this.GetInstancesAsync)} operation was canceled.", e, cancellation); - } - }); - } - - /// - public override async Task WaitForInstanceStartAsync( - string instanceId, bool getInputsAndOutputs = false, CancellationToken cancellation = default) - { - Check.NotEntity(this.options.EnableEntitySupport, instanceId); - - this.logger.WaitingForInstanceStart(instanceId, getInputsAndOutputs); - - P.GetInstanceRequest request = new() - { - InstanceId = instanceId, - GetInputsAndOutputs = getInputsAndOutputs, - }; - - try - { - P.GetInstanceResponse response = await this.sidecarClient.WaitForInstanceStartAsync( - request, cancellationToken: cancellation); - return this.CreateMetadata(response.OrchestrationState, getInputsAndOutputs); - } - catch (RpcException e) when (e.StatusCode == StatusCode.Cancelled) - { - throw new OperationCanceledException( - $"The {nameof(this.WaitForInstanceStartAsync)} operation was canceled.", e, cancellation); - } - } - - /// - public override async Task WaitForInstanceCompletionAsync( - string instanceId, bool getInputsAndOutputs = false, CancellationToken cancellation = default) - { - Check.NotEntity(this.options.EnableEntitySupport, instanceId); - - this.logger.WaitingForInstanceCompletion(instanceId, getInputsAndOutputs); - - P.GetInstanceRequest request = new() - { - InstanceId = instanceId, - GetInputsAndOutputs = getInputsAndOutputs, - }; - - try - { - P.GetInstanceResponse response = await this.sidecarClient.WaitForInstanceCompletionAsync( - request, cancellationToken: cancellation); - return this.CreateMetadata(response.OrchestrationState, getInputsAndOutputs); - } - catch (RpcException e) when (e.StatusCode == StatusCode.Cancelled) - { - throw new OperationCanceledException( - $"The {nameof(this.WaitForInstanceCompletionAsync)} operation was canceled.", e, cancellation); - } - } - - /// - public override Task PurgeInstanceAsync( - string instanceId, PurgeInstanceOptions? options = null, CancellationToken cancellation = default) - { - bool recursive = options?.Recursive ?? false; - this.logger.PurgingInstanceMetadata(instanceId); - - P.PurgeInstancesRequest request = new() { InstanceId = instanceId, Recursive = recursive }; - return this.PurgeInstancesCoreAsync(request, cancellation); - } - - /// - public override Task PurgeAllInstancesAsync( - PurgeInstancesFilter filter, PurgeInstanceOptions? options = null, CancellationToken cancellation = default) - { - bool recursive = options?.Recursive ?? false; - this.logger.PurgingInstances(filter); - P.PurgeInstancesRequest request = new() - { - PurgeInstanceFilter = new() - { - CreatedTimeFrom = filter?.CreatedFrom.ToTimestamp(), - CreatedTimeTo = filter?.CreatedTo.ToTimestamp(), - }, - Recursive = recursive, - }; - - if (filter?.Statuses is not null) - { - request.PurgeInstanceFilter.RuntimeStatus.AddRange(filter.Statuses.Select(x => x.ToGrpcStatus())); - } - - return this.PurgeInstancesCoreAsync(request, cancellation); - } - - static AsyncDisposable GetCallInvoker(GrpcDurableTaskClientOptions options, out CallInvoker callInvoker) - { - if (options.Channel is GrpcChannel c) - { - callInvoker = c.CreateCallInvoker(); - return default; - } - - if (options.CallInvoker is CallInvoker invoker) - { - callInvoker = invoker; - return default; - } - - c = GetChannel(options.Address); - callInvoker = c.CreateCallInvoker(); - return new AsyncDisposable(() => new(c.ShutdownAsync())); - } - -#if NET6_0_OR_GREATER - static GrpcChannel GetChannel(string? address) - { - if (string.IsNullOrEmpty(address)) - { - address = "http://localhost:4001"; - } - - return GrpcChannel.ForAddress(address); - } -#endif - -#if NETSTANDARD2_0 - static GrpcChannel GetChannel(string? address) - { - if (string.IsNullOrEmpty(address)) - { - address = "localhost:4001"; - } - - return new(address, ChannelCredentials.Insecure); - } -#endif - - async Task PurgeInstancesCoreAsync( - P.PurgeInstancesRequest request, CancellationToken cancellation = default) - { - try - { - P.PurgeInstancesResponse response = await this.sidecarClient.PurgeInstancesAsync( - request, cancellationToken: cancellation); - return new PurgeResult(response.DeletedInstanceCount); - } - catch (RpcException e) when (e.StatusCode == StatusCode.Cancelled) - { - throw new OperationCanceledException( - $"The {nameof(this.PurgeAllInstancesAsync)} operation was canceled.", e, cancellation); - } - } - - OrchestrationMetadata CreateMetadata(P.OrchestrationState state, bool includeInputsAndOutputs) - { - return new(state.Name, state.InstanceId) - { - CreatedAt = state.CreatedTimestamp.ToDateTimeOffset(), - LastUpdatedAt = state.LastUpdatedTimestamp.ToDateTimeOffset(), - RuntimeStatus = (OrchestrationRuntimeStatus)state.OrchestrationStatus, - SerializedInput = state.Input, - SerializedOutput = state.Output, - SerializedCustomStatus = state.CustomStatus, - FailureDetails = state.FailureDetails.ToTaskFailureDetails(), - DataConverter = includeInputsAndOutputs ? this.DataConverter : null, - }; - } -} +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using System.Diagnostics; +using System.Text; +using Google.Protobuf.WellKnownTypes; +using Microsoft.DurableTask.Client.Entities; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Logging; +using Microsoft.Extensions.Options; +using static Microsoft.DurableTask.Protobuf.TaskHubSidecarService; +using P = Microsoft.DurableTask.Protobuf; + +namespace Microsoft.DurableTask.Client.Grpc; + +/// +/// Durable Task client implementation that uses gRPC to connect to a remote "sidecar" process. +/// +public sealed class GrpcDurableTaskClient : DurableTaskClient +{ + readonly ILogger logger; + readonly TaskHubSidecarServiceClient sidecarClient; + readonly GrpcDurableTaskClientOptions options; + readonly DurableEntityClient? entityClient; + AsyncDisposable asyncDisposable; + + /// + /// Initializes a new instance of the class. + /// + /// The name of the client. + /// The gRPC client options. + /// The logger. + [ActivatorUtilitiesConstructor] + public GrpcDurableTaskClient( + string name, IOptionsMonitor options, ILogger logger) + : this(name, Check.NotNull(options).Get(name), logger) + { + } + + /// + /// Initializes a new instance of the class. + /// + /// The name of the client. + /// The gRPC client options. + /// The logger. + public GrpcDurableTaskClient(string name, GrpcDurableTaskClientOptions options, ILogger logger) + : base(name) + { + this.logger = Check.NotNull(logger); + this.options = Check.NotNull(options); + this.asyncDisposable = GetCallInvoker(options, out CallInvoker callInvoker); + this.sidecarClient = new TaskHubSidecarServiceClient(callInvoker); + + if (this.options.EnableEntitySupport) + { + this.entityClient = new GrpcDurableEntityClient(this.Name, this.DataConverter, this.sidecarClient, logger); + } + } + + /// + public override DurableEntityClient Entities => this.entityClient + ?? throw new NotSupportedException($"Durable entities are disabled because {nameof(DurableTaskClientOptions)}.{nameof(DurableTaskClientOptions.EnableEntitySupport)}=false"); + + DataConverter DataConverter => this.options.DataConverter; + + /// + public override ValueTask DisposeAsync() + { + return this.asyncDisposable.DisposeAsync(); + } + + /// + public override async Task ScheduleNewOrchestrationInstanceAsync( + TaskName orchestratorName, + object? input = null, + StartOrchestrationOptions? options = null, + CancellationToken cancellation = default) + { + Check.NotEntity(this.options.EnableEntitySupport, options?.InstanceId); + + string version = string.Empty; + if (!string.IsNullOrEmpty(orchestratorName.Version)) + { + version = orchestratorName.Version; + } + else if (!string.IsNullOrEmpty(this.options.DefaultVersion)) + { + version = this.options.DefaultVersion; + } + + var request = new P.CreateInstanceRequest + { + Name = orchestratorName.Name, + Version = version, + InstanceId = options?.InstanceId ?? Guid.NewGuid().ToString("N"), + Input = this.DataConverter.Serialize(input), + }; + + // Add tags to the collection + if (request?.Tags != null && options?.Tags != null) + { + foreach (KeyValuePair tag in options.Tags) + { + request.Tags.Add(tag.Key, tag.Value); + } + } + + if (Activity.Current?.Id != null || Activity.Current?.TraceStateString != null) + { + if (request.ParentTraceContext == null) + { + request.ParentTraceContext = new P.TraceContext(); + } + + if (Activity.Current?.Id != null) + { + request.ParentTraceContext.TraceParent = Activity.Current?.Id; + } + + if (Activity.Current?.TraceStateString != null) + { + request.ParentTraceContext.TraceState = Activity.Current?.TraceStateString; + } + } + + DateTimeOffset? startAt = options?.StartAt; + this.logger.SchedulingOrchestration( + request.InstanceId, + orchestratorName, + sizeInBytes: request.Input != null ? Encoding.UTF8.GetByteCount(request.Input) : 0, + startAt.GetValueOrDefault(DateTimeOffset.UtcNow)); + + if (startAt.HasValue) + { + // Convert timestamps to UTC if not already UTC + request.ScheduledStartTimestamp = Timestamp.FromDateTimeOffset(startAt.Value.ToUniversalTime()); + } + + P.CreateInstanceResponse? result = await this.sidecarClient.StartInstanceAsync( + request, cancellationToken: cancellation); + return result.InstanceId; + } + + /// + public override async Task RaiseEventAsync( + string instanceId, string eventName, object? eventPayload = null, CancellationToken cancellation = default) + { + Check.NotNullOrEmpty(instanceId); + Check.NotNullOrEmpty(eventName); + + Check.NotEntity(this.options.EnableEntitySupport, instanceId); + + P.RaiseEventRequest request = new() + { + InstanceId = instanceId, + Name = eventName, + Input = this.DataConverter.Serialize(eventPayload), + }; + + await this.sidecarClient.RaiseEventAsync(request, cancellationToken: cancellation); + } + + /// + public override async Task TerminateInstanceAsync( + string instanceId, TerminateInstanceOptions? options = null, CancellationToken cancellation = default) + { + object? output = options?.Output; + bool recursive = options?.Recursive ?? false; + + Check.NotNullOrEmpty(instanceId); + Check.NotEntity(this.options.EnableEntitySupport, instanceId); + + this.logger.TerminatingInstance(instanceId); + + string? serializedOutput = this.DataConverter.Serialize(output); + await this.sidecarClient.TerminateInstanceAsync( + new P.TerminateRequest + { + InstanceId = instanceId, + Output = serializedOutput, + Recursive = recursive, + }, + cancellationToken: cancellation); + } + + /// + public override async Task SuspendInstanceAsync( + string instanceId, string? reason = null, CancellationToken cancellation = default) + { + Check.NotEntity(this.options.EnableEntitySupport, instanceId); + + P.SuspendRequest request = new() + { + InstanceId = instanceId, + Reason = reason, + }; + + try + { + await this.sidecarClient.SuspendInstanceAsync(request, cancellationToken: cancellation); + } + catch (RpcException e) when (e.StatusCode == StatusCode.Cancelled) + { + throw new OperationCanceledException( + $"The {nameof(this.SuspendInstanceAsync)} operation was canceled.", e, cancellation); + } + } + + /// + public override async Task ResumeInstanceAsync( + string instanceId, string? reason = null, CancellationToken cancellation = default) + { + Check.NotEntity(this.options.EnableEntitySupport, instanceId); + + P.ResumeRequest request = new() + { + InstanceId = instanceId, + Reason = reason, + }; + + try + { + await this.sidecarClient.ResumeInstanceAsync(request, cancellationToken: cancellation); + } + catch (RpcException e) when (e.StatusCode == StatusCode.Cancelled) + { + throw new OperationCanceledException( + $"The {nameof(this.ResumeInstanceAsync)} operation was canceled.", e, cancellation); + } + } + + /// + public override async Task GetInstancesAsync( + string instanceId, bool getInputsAndOutputs = false, CancellationToken cancellation = default) + { + Check.NotEntity(this.options.EnableEntitySupport, instanceId); + + if (string.IsNullOrEmpty(instanceId)) + { + throw new ArgumentNullException(nameof(instanceId)); + } + + P.GetInstanceResponse response = await this.sidecarClient.GetInstanceAsync( + new P.GetInstanceRequest + { + InstanceId = instanceId, + GetInputsAndOutputs = getInputsAndOutputs, + }, + cancellationToken: cancellation); + + // REVIEW: Should we return a non-null value instead of !exists? + if (!response.Exists) + { + return null; + } + + return this.CreateMetadata(response.OrchestrationState, getInputsAndOutputs); + } + + /// + public override AsyncPageable GetAllInstancesAsync(OrchestrationQuery? filter = null) + { + Check.NotEntity(this.options.EnableEntitySupport, filter?.InstanceIdPrefix); + + return Pageable.Create(async (continuation, pageSize, cancellation) => + { + P.QueryInstancesRequest request = new() + { + Query = new P.InstanceQuery + { + CreatedTimeFrom = filter?.CreatedFrom?.ToTimestamp(), + CreatedTimeTo = filter?.CreatedTo?.ToTimestamp(), + FetchInputsAndOutputs = filter?.FetchInputsAndOutputs ?? false, + InstanceIdPrefix = filter?.InstanceIdPrefix, + MaxInstanceCount = pageSize ?? filter?.PageSize ?? OrchestrationQuery.DefaultPageSize, + ContinuationToken = continuation ?? filter?.ContinuationToken, + }, + }; + + if (filter?.Statuses is not null) + { + request.Query.RuntimeStatus.AddRange(filter.Statuses.Select(x => x.ToGrpcStatus())); + } + + if (filter?.TaskHubNames is not null) + { + request.Query.TaskHubNames.AddRange(filter.TaskHubNames); + } + + try + { + P.QueryInstancesResponse response = await this.sidecarClient.QueryInstancesAsync( + request, cancellationToken: cancellation); + + bool getInputsAndOutputs = filter?.FetchInputsAndOutputs ?? false; + IReadOnlyList values = response.OrchestrationState + .Select(x => this.CreateMetadata(x, getInputsAndOutputs)) + .ToList(); + + return new Page(values, response.ContinuationToken); + } + catch (RpcException e) when (e.StatusCode == StatusCode.Cancelled) + { + throw new OperationCanceledException( + $"The {nameof(this.GetInstancesAsync)} operation was canceled.", e, cancellation); + } + }); + } + + /// + public override async Task WaitForInstanceStartAsync( + string instanceId, bool getInputsAndOutputs = false, CancellationToken cancellation = default) + { + Check.NotEntity(this.options.EnableEntitySupport, instanceId); + + this.logger.WaitingForInstanceStart(instanceId, getInputsAndOutputs); + + P.GetInstanceRequest request = new() + { + InstanceId = instanceId, + GetInputsAndOutputs = getInputsAndOutputs, + }; + + try + { + P.GetInstanceResponse response = await this.sidecarClient.WaitForInstanceStartAsync( + request, cancellationToken: cancellation); + return this.CreateMetadata(response.OrchestrationState, getInputsAndOutputs); + } + catch (RpcException e) when (e.StatusCode == StatusCode.Cancelled) + { + throw new OperationCanceledException( + $"The {nameof(this.WaitForInstanceStartAsync)} operation was canceled.", e, cancellation); + } + } + + /// + public override async Task WaitForInstanceCompletionAsync( + string instanceId, bool getInputsAndOutputs = false, CancellationToken cancellation = default) + { + Check.NotEntity(this.options.EnableEntitySupport, instanceId); + + this.logger.WaitingForInstanceCompletion(instanceId, getInputsAndOutputs); + + P.GetInstanceRequest request = new() + { + InstanceId = instanceId, + GetInputsAndOutputs = getInputsAndOutputs, + }; + + try + { + P.GetInstanceResponse response = await this.sidecarClient.WaitForInstanceCompletionAsync( + request, cancellationToken: cancellation); + return this.CreateMetadata(response.OrchestrationState, getInputsAndOutputs); + } + catch (RpcException e) when (e.StatusCode == StatusCode.Cancelled) + { + throw new OperationCanceledException( + $"The {nameof(this.WaitForInstanceCompletionAsync)} operation was canceled.", e, cancellation); + } + } + + /// + public override Task PurgeInstanceAsync( + string instanceId, PurgeInstanceOptions? options = null, CancellationToken cancellation = default) + { + bool recursive = options?.Recursive ?? false; + this.logger.PurgingInstanceMetadata(instanceId); + + P.PurgeInstancesRequest request = new() { InstanceId = instanceId, Recursive = recursive }; + return this.PurgeInstancesCoreAsync(request, cancellation); + } + + /// + public override Task PurgeAllInstancesAsync( + PurgeInstancesFilter filter, PurgeInstanceOptions? options = null, CancellationToken cancellation = default) + { + bool recursive = options?.Recursive ?? false; + this.logger.PurgingInstances(filter); + P.PurgeInstancesRequest request = new() + { + PurgeInstanceFilter = new() + { + CreatedTimeFrom = filter?.CreatedFrom.ToTimestamp(), + CreatedTimeTo = filter?.CreatedTo.ToTimestamp(), + }, + Recursive = recursive, + }; + + if (filter?.Statuses is not null) + { + request.PurgeInstanceFilter.RuntimeStatus.AddRange(filter.Statuses.Select(x => x.ToGrpcStatus())); + } + + return this.PurgeInstancesCoreAsync(request, cancellation); + } + + static AsyncDisposable GetCallInvoker(GrpcDurableTaskClientOptions options, out CallInvoker callInvoker) + { + if (options.Channel is GrpcChannel c) + { + callInvoker = c.CreateCallInvoker(); + return default; + } + + if (options.CallInvoker is CallInvoker invoker) + { + callInvoker = invoker; + return default; + } + + c = GetChannel(options.Address); + callInvoker = c.CreateCallInvoker(); + return new AsyncDisposable(() => new(c.ShutdownAsync())); + } + +#if NET6_0_OR_GREATER + static GrpcChannel GetChannel(string? address) + { + if (string.IsNullOrEmpty(address)) + { + address = "http://localhost:4001"; + } + + return GrpcChannel.ForAddress(address); + } +#endif + +#if NETSTANDARD2_0 + static GrpcChannel GetChannel(string? address) + { + if (string.IsNullOrEmpty(address)) + { + address = "localhost:4001"; + } + + return new(address, ChannelCredentials.Insecure); + } +#endif + + async Task PurgeInstancesCoreAsync( + P.PurgeInstancesRequest request, CancellationToken cancellation = default) + { + try + { + P.PurgeInstancesResponse response = await this.sidecarClient.PurgeInstancesAsync( + request, cancellationToken: cancellation); + return new PurgeResult(response.DeletedInstanceCount); + } + catch (RpcException e) when (e.StatusCode == StatusCode.Cancelled) + { + throw new OperationCanceledException( + $"The {nameof(this.PurgeAllInstancesAsync)} operation was canceled.", e, cancellation); + } + } + + OrchestrationMetadata CreateMetadata(P.OrchestrationState state, bool includeInputsAndOutputs) + { + var metadata = new OrchestrationMetadata(state.Name, state.InstanceId) + { + CreatedAt = state.CreatedTimestamp.ToDateTimeOffset(), + LastUpdatedAt = state.LastUpdatedTimestamp.ToDateTimeOffset(), + RuntimeStatus = (OrchestrationRuntimeStatus)state.OrchestrationStatus, + SerializedInput = state.Input, + SerializedOutput = state.Output, + SerializedCustomStatus = state.CustomStatus, + FailureDetails = state.FailureDetails.ToTaskFailureDetails(), + DataConverter = includeInputsAndOutputs ? this.DataConverter : null, + Tags = new Dictionary(state.Tags), + }; + + return metadata; + } +} diff --git a/src/Client/OrchestrationServiceClientShim/ShimDurableTaskClient.cs b/src/Client/OrchestrationServiceClientShim/ShimDurableTaskClient.cs index ce510fd0a..fe7625a88 100644 --- a/src/Client/OrchestrationServiceClientShim/ShimDurableTaskClient.cs +++ b/src/Client/OrchestrationServiceClientShim/ShimDurableTaskClient.cs @@ -1,290 +1,290 @@ -// Copyright (c) Microsoft Corporation. -// Licensed under the MIT License. - -using System.Diagnostics.CodeAnalysis; -using DurableTask.Core; -using DurableTask.Core.Entities; -using DurableTask.Core.History; -using DurableTask.Core.Query; -using Microsoft.DurableTask.Client.Entities; -using Microsoft.Extensions.DependencyInjection; -using Microsoft.Extensions.Options; -using Core = DurableTask.Core; -using CoreOrchestrationQuery = DurableTask.Core.Query.OrchestrationQuery; - -namespace Microsoft.DurableTask.Client.OrchestrationServiceClientShim; - -/// -/// A shim client for interacting with the backend via . -/// -/// -/// Initializes a new instance of the class. -/// -/// The name of the client. -/// The client options. -class ShimDurableTaskClient(string name, ShimDurableTaskClientOptions options) : DurableTaskClient(name) -{ - readonly ShimDurableTaskClientOptions options = Check.NotNull(options); - ShimDurableEntityClient? entities; - - /// - /// Initializes a new instance of the class. - /// - /// The name of this client. - /// The client options. - [ActivatorUtilitiesConstructor] - public ShimDurableTaskClient( - string name, IOptionsMonitor options) - : this(name, Check.NotNull(options).Get(name)) - { - } - - /// - public override DurableEntityClient Entities - { - get - { - if (!this.options.EnableEntitySupport) - { - throw new InvalidOperationException("Entity support is not enabled."); - } - - if (this.entities is null) - { - if (this.options.Entities.Queries is null) - { - throw new NotSupportedException( - "The configured IOrchestrationServiceClient does not support entities."); - } - - this.entities = new(this.Name, this.options); - } - - return this.entities; - } - } - - DataConverter DataConverter => this.options.DataConverter; - - IOrchestrationServiceClient Client => this.options.Client!; - - IOrchestrationServicePurgeClient PurgeClient => this.CastClient(); - - /// - public override ValueTask DisposeAsync() => default; - - /// - public override async Task GetInstancesAsync( - string instanceId, bool getInputsAndOutputs = false, CancellationToken cancellation = default) - { - cancellation.ThrowIfCancellationRequested(); - IList states = await this.Client.GetOrchestrationStateAsync(instanceId, false); - if (states is null or { Count: 0 }) - { - return null; - } - - return this.ToMetadata(states.First(), getInputsAndOutputs); - } - - /// - public override AsyncPageable GetAllInstancesAsync(OrchestrationQuery? query = null) - { - // Get this early to force an exception if not supported. - IOrchestrationServiceQueryClient queryClient = this.CastClient(); - return Pageable.Create(async (continuation, pageSize, cancellation) => - { - CoreOrchestrationQuery coreQuery = new() - { - RuntimeStatus = query?.Statuses?.Select(x => x.ConvertToCore()).ToList(), - CreatedTimeFrom = query?.CreatedFrom?.UtcDateTime, - CreatedTimeTo = query?.CreatedTo?.UtcDateTime, - TaskHubNames = query?.TaskHubNames?.ToList(), - PageSize = pageSize ?? query?.PageSize ?? OrchestrationQuery.DefaultPageSize, - ContinuationToken = continuation ?? query?.ContinuationToken, - InstanceIdPrefix = query?.InstanceIdPrefix, - FetchInputsAndOutputs = query?.FetchInputsAndOutputs ?? false, - }; - - OrchestrationQueryResult result = await queryClient.GetOrchestrationWithQueryAsync( - coreQuery, cancellation); - - var metadata = result.OrchestrationState.Select(x => this.ToMetadata(x, coreQuery.FetchInputsAndOutputs)) - .ToList(); - return new Page(metadata, result.ContinuationToken); - }); - } - - /// - public override async Task PurgeInstanceAsync( - string instanceId, PurgeInstanceOptions? options = null, CancellationToken cancellation = default) - { - Check.NotNullOrEmpty(instanceId); - cancellation.ThrowIfCancellationRequested(); - - // TODO: Support recursive purge of sub-orchestrations - Core.PurgeResult result = await this.PurgeClient.PurgeInstanceStateAsync(instanceId); - return result.ConvertFromCore(); - } - - /// - public override async Task PurgeAllInstancesAsync( - PurgeInstancesFilter filter, PurgeInstanceOptions? options = null, CancellationToken cancellation = default) - { - Check.NotNull(filter); - cancellation.ThrowIfCancellationRequested(); - - // TODO: Support recursive purge of sub-orchestrations - Core.PurgeResult result = await this.PurgeClient.PurgeInstanceStateAsync(filter.ConvertToCore()); - return result.ConvertFromCore(); - } - - /// - public override Task RaiseEventAsync( - string instanceId, string eventName, object? eventPayload = null, CancellationToken cancellation = default) - { - Check.NotNullOrEmpty(instanceId); - Check.NotNullOrEmpty(eventName); - - string? serializedInput = this.DataConverter.Serialize(eventPayload); - return this.SendInstanceMessageAsync( - instanceId, new EventRaisedEvent(-1, serializedInput) { Name = eventName }, cancellation); - } - - /// - public override async Task ScheduleNewOrchestrationInstanceAsync( - TaskName orchestratorName, - object? input = null, - StartOrchestrationOptions? options = null, - CancellationToken cancellation = default) - { - cancellation.ThrowIfCancellationRequested(); - string instanceId = options?.InstanceId ?? Guid.NewGuid().ToString("N"); - OrchestrationInstance instance = new() - { - InstanceId = instanceId, - ExecutionId = Guid.NewGuid().ToString("N"), - }; - - string? serializedInput = this.DataConverter.Serialize(input); - TaskMessage message = new() - { - OrchestrationInstance = instance, - Event = new ExecutionStartedEvent(-1, serializedInput) - { - Name = orchestratorName.Name, - Version = orchestratorName.Version, - OrchestrationInstance = instance, - ScheduledStartTime = options?.StartAt?.UtcDateTime, - }, - }; - - await this.Client.CreateTaskOrchestrationAsync(message); - return instanceId; - } - - /// - public override Task SuspendInstanceAsync( - string instanceId, string? reason = null, CancellationToken cancellation = default) - => this.SendInstanceMessageAsync(instanceId, new ExecutionSuspendedEvent(-1, reason), cancellation); - - /// - public override Task ResumeInstanceAsync( - string instanceId, string? reason = null, CancellationToken cancellation = default) - => this.SendInstanceMessageAsync(instanceId, new ExecutionResumedEvent(-1, reason), cancellation); - - /// - public override Task TerminateInstanceAsync( - string instanceId, TerminateInstanceOptions? options = null, CancellationToken cancellation = default) - { - object? output = options?.Output; - Check.NotNullOrEmpty(instanceId); - cancellation.ThrowIfCancellationRequested(); - string? reason = this.DataConverter.Serialize(output); - - // TODO: Support recursive termination of sub-orchestrations - return this.Client.ForceTerminateTaskOrchestrationAsync(instanceId, reason); - } - - /// - public override async Task WaitForInstanceCompletionAsync( - string instanceId, bool getInputsAndOutputs = false, CancellationToken cancellation = default) - { - Check.NotNullOrEmpty(instanceId); - OrchestrationState state = await this.Client.WaitForOrchestrationAsync( - instanceId, null, TimeSpan.MaxValue, cancellation); - return this.ToMetadata(state, getInputsAndOutputs); - } - - /// - public override async Task WaitForInstanceStartAsync( - string instanceId, bool getInputsAndOutputs = false, CancellationToken cancellation = default) - { - Check.NotNullOrEmpty(instanceId); - - while (true) - { - OrchestrationMetadata? metadata = await this.GetInstancesAsync( - instanceId, getInputsAndOutputs, cancellation); - if (metadata is null) - { - throw new InvalidOperationException($"Orchestration with instanceId '{instanceId}' does not exist"); - } - - if (metadata.RuntimeStatus != OrchestrationRuntimeStatus.Pending) - { - // TODO: Evaluate what to do with "Suspended" state. Do we wait on that? - return metadata; - } - - await Task.Delay(TimeSpan.FromSeconds(1), cancellation); - } - } - - [return: NotNullIfNotNull("state")] - OrchestrationMetadata? ToMetadata(Core.OrchestrationState? state, bool getInputsAndOutputs) - { - if (state is null) - { - return null; - } - - return new OrchestrationMetadata(state.Name, state.OrchestrationInstance.InstanceId) - { - DataConverter = getInputsAndOutputs ? this.DataConverter : null, - RuntimeStatus = state.OrchestrationStatus.ConvertFromCore(), - CreatedAt = state.CreatedTime, - LastUpdatedAt = state.LastUpdatedTime, - SerializedInput = state.Input, - SerializedOutput = state.Output, - SerializedCustomStatus = state.Status, - FailureDetails = state.FailureDetails?.ConvertFromCore(), - }; - } - - T CastClient() - { - if (this.Client is T t) - { - return t; - } - - throw new NotSupportedException($"Provided IOrchestrationServiceClient does not implement {typeof(T)}."); - } - - Task SendInstanceMessageAsync(string instanceId, HistoryEvent @event, CancellationToken cancellation) - { - Check.NotNullOrEmpty(instanceId); - Check.NotNull(@event); - - cancellation.ThrowIfCancellationRequested(); - - TaskMessage message = new() - { - OrchestrationInstance = new() { InstanceId = instanceId }, - Event = @event, - }; - - return this.Client.SendTaskOrchestrationMessageAsync(message); - } -} +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using System.Diagnostics.CodeAnalysis; +using DurableTask.Core; +using DurableTask.Core.History; +using DurableTask.Core.Query; +using Microsoft.DurableTask.Client.Entities; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Options; +using Core = DurableTask.Core; +using CoreOrchestrationQuery = DurableTask.Core.Query.OrchestrationQuery; + +namespace Microsoft.DurableTask.Client.OrchestrationServiceClientShim; + +/// +/// A shim client for interacting with the backend via . +/// +/// +/// Initializes a new instance of the class. +/// +/// The name of the client. +/// The client options. +class ShimDurableTaskClient(string name, ShimDurableTaskClientOptions options) : DurableTaskClient(name) +{ + readonly ShimDurableTaskClientOptions options = Check.NotNull(options); + ShimDurableEntityClient? entities; + + /// + /// Initializes a new instance of the class. + /// + /// The name of this client. + /// The client options. + [ActivatorUtilitiesConstructor] + public ShimDurableTaskClient( + string name, IOptionsMonitor options) + : this(name, Check.NotNull(options).Get(name)) + { + } + + /// + public override DurableEntityClient Entities + { + get + { + if (!this.options.EnableEntitySupport) + { + throw new InvalidOperationException("Entity support is not enabled."); + } + + if (this.entities is null) + { + if (this.options.Entities.Queries is null) + { + throw new NotSupportedException( + "The configured IOrchestrationServiceClient does not support entities."); + } + + this.entities = new(this.Name, this.options); + } + + return this.entities; + } + } + + DataConverter DataConverter => this.options.DataConverter; + + IOrchestrationServiceClient Client => this.options.Client!; + + IOrchestrationServicePurgeClient PurgeClient => this.CastClient(); + + /// + public override ValueTask DisposeAsync() => default; + + /// + public override async Task GetInstancesAsync( + string instanceId, bool getInputsAndOutputs = false, CancellationToken cancellation = default) + { + cancellation.ThrowIfCancellationRequested(); + IList states = await this.Client.GetOrchestrationStateAsync(instanceId, false); + if (states is null or { Count: 0 }) + { + return null; + } + + return this.ToMetadata(states.First(), getInputsAndOutputs); + } + + /// + public override AsyncPageable GetAllInstancesAsync(OrchestrationQuery? query = null) + { + // Get this early to force an exception if not supported. + IOrchestrationServiceQueryClient queryClient = this.CastClient(); + return Pageable.Create(async (continuation, pageSize, cancellation) => + { + CoreOrchestrationQuery coreQuery = new() + { + RuntimeStatus = query?.Statuses?.Select(x => x.ConvertToCore()).ToList(), + CreatedTimeFrom = query?.CreatedFrom?.UtcDateTime, + CreatedTimeTo = query?.CreatedTo?.UtcDateTime, + TaskHubNames = query?.TaskHubNames?.ToList(), + PageSize = pageSize ?? query?.PageSize ?? OrchestrationQuery.DefaultPageSize, + ContinuationToken = continuation ?? query?.ContinuationToken, + InstanceIdPrefix = query?.InstanceIdPrefix, + FetchInputsAndOutputs = query?.FetchInputsAndOutputs ?? false, + }; + + OrchestrationQueryResult result = await queryClient.GetOrchestrationWithQueryAsync( + coreQuery, cancellation); + + var metadata = result.OrchestrationState.Select(x => this.ToMetadata(x, coreQuery.FetchInputsAndOutputs)) + .ToList(); + return new Page(metadata, result.ContinuationToken); + }); + } + + /// + public override async Task PurgeInstanceAsync( + string instanceId, PurgeInstanceOptions? options = null, CancellationToken cancellation = default) + { + Check.NotNullOrEmpty(instanceId); + cancellation.ThrowIfCancellationRequested(); + + // TODO: Support recursive purge of sub-orchestrations + Core.PurgeResult result = await this.PurgeClient.PurgeInstanceStateAsync(instanceId); + return result.ConvertFromCore(); + } + + /// + public override async Task PurgeAllInstancesAsync( + PurgeInstancesFilter filter, PurgeInstanceOptions? options = null, CancellationToken cancellation = default) + { + Check.NotNull(filter); + cancellation.ThrowIfCancellationRequested(); + + // TODO: Support recursive purge of sub-orchestrations + Core.PurgeResult result = await this.PurgeClient.PurgeInstanceStateAsync(filter.ConvertToCore()); + return result.ConvertFromCore(); + } + + /// + public override Task RaiseEventAsync( + string instanceId, string eventName, object? eventPayload = null, CancellationToken cancellation = default) + { + Check.NotNullOrEmpty(instanceId); + Check.NotNullOrEmpty(eventName); + + string? serializedInput = this.DataConverter.Serialize(eventPayload); + return this.SendInstanceMessageAsync( + instanceId, new EventRaisedEvent(-1, serializedInput) { Name = eventName }, cancellation); + } + + /// + public override async Task ScheduleNewOrchestrationInstanceAsync( + TaskName orchestratorName, + object? input = null, + StartOrchestrationOptions? options = null, + CancellationToken cancellation = default) + { + cancellation.ThrowIfCancellationRequested(); + string instanceId = options?.InstanceId ?? Guid.NewGuid().ToString("N"); + OrchestrationInstance instance = new() + { + InstanceId = instanceId, + ExecutionId = Guid.NewGuid().ToString("N"), + }; + + string? serializedInput = this.DataConverter.Serialize(input); + TaskMessage message = new() + { + OrchestrationInstance = instance, + Event = new ExecutionStartedEvent(-1, serializedInput) + { + Name = orchestratorName.Name, + Version = orchestratorName.Version, + OrchestrationInstance = instance, + ScheduledStartTime = options?.StartAt?.UtcDateTime, + Tags = options?.Tags != null ? options.Tags.ToDictionary(kvp => kvp.Key, kvp => kvp.Value) : null, + }, + }; + + await this.Client.CreateTaskOrchestrationAsync(message); + return instanceId; + } + + /// + public override Task SuspendInstanceAsync( + string instanceId, string? reason = null, CancellationToken cancellation = default) + => this.SendInstanceMessageAsync(instanceId, new ExecutionSuspendedEvent(-1, reason), cancellation); + + /// + public override Task ResumeInstanceAsync( + string instanceId, string? reason = null, CancellationToken cancellation = default) + => this.SendInstanceMessageAsync(instanceId, new ExecutionResumedEvent(-1, reason), cancellation); + + /// + public override Task TerminateInstanceAsync( + string instanceId, TerminateInstanceOptions? options = null, CancellationToken cancellation = default) + { + object? output = options?.Output; + Check.NotNullOrEmpty(instanceId); + cancellation.ThrowIfCancellationRequested(); + string? reason = this.DataConverter.Serialize(output); + + // TODO: Support recursive termination of sub-orchestrations + return this.Client.ForceTerminateTaskOrchestrationAsync(instanceId, reason); + } + + /// + public override async Task WaitForInstanceCompletionAsync( + string instanceId, bool getInputsAndOutputs = false, CancellationToken cancellation = default) + { + Check.NotNullOrEmpty(instanceId); + OrchestrationState state = await this.Client.WaitForOrchestrationAsync( + instanceId, null, TimeSpan.MaxValue, cancellation); + return this.ToMetadata(state, getInputsAndOutputs); + } + + /// + public override async Task WaitForInstanceStartAsync( + string instanceId, bool getInputsAndOutputs = false, CancellationToken cancellation = default) + { + Check.NotNullOrEmpty(instanceId); + + while (true) + { + OrchestrationMetadata? metadata = await this.GetInstancesAsync( + instanceId, getInputsAndOutputs, cancellation); + if (metadata is null) + { + throw new InvalidOperationException($"Orchestration with instanceId '{instanceId}' does not exist"); + } + + if (metadata.RuntimeStatus != OrchestrationRuntimeStatus.Pending) + { + // TODO: Evaluate what to do with "Suspended" state. Do we wait on that? + return metadata; + } + + await Task.Delay(TimeSpan.FromSeconds(1), cancellation); + } + } + + [return: NotNullIfNotNull("state")] + OrchestrationMetadata? ToMetadata(Core.OrchestrationState? state, bool getInputsAndOutputs) + { + if (state is null) + { + return null; + } + + return new OrchestrationMetadata(state.Name, state.OrchestrationInstance.InstanceId) + { + DataConverter = getInputsAndOutputs ? this.DataConverter : null, + RuntimeStatus = state.OrchestrationStatus.ConvertFromCore(), + CreatedAt = state.CreatedTime, + LastUpdatedAt = state.LastUpdatedTime, + SerializedInput = state.Input, + SerializedOutput = state.Output, + SerializedCustomStatus = state.Status, + FailureDetails = state.FailureDetails?.ConvertFromCore(), + }; + } + + T CastClient() + { + if (this.Client is T t) + { + return t; + } + + throw new NotSupportedException($"Provided IOrchestrationServiceClient does not implement {typeof(T)}."); + } + + Task SendInstanceMessageAsync(string instanceId, HistoryEvent @event, CancellationToken cancellation) + { + Check.NotNullOrEmpty(instanceId); + Check.NotNull(@event); + + cancellation.ThrowIfCancellationRequested(); + + TaskMessage message = new() + { + OrchestrationInstance = new() { InstanceId = instanceId }, + Event = @event, + }; + + return this.Client.SendTaskOrchestrationMessageAsync(message); + } +} diff --git a/src/Grpc/orchestrator_service.proto b/src/Grpc/orchestrator_service.proto index 0fa6b6595..64e752818 100644 --- a/src/Grpc/orchestrator_service.proto +++ b/src/Grpc/orchestrator_service.proto @@ -75,6 +75,7 @@ message ExecutionStartedEvent { google.protobuf.Timestamp scheduledStartTimestamp = 6; TraceContext parentTraceContext = 7; google.protobuf.StringValue orchestrationSpanID = 8; + map tags = 9; } message ExecutionCompletedEvent { @@ -343,14 +344,8 @@ message CreateInstanceRequest { } message OrchestrationIdReusePolicy { - repeated OrchestrationStatus operationStatus = 1; - CreateOrchestrationAction action = 2; -} - -enum CreateOrchestrationAction { - ERROR = 0; - IGNORE = 1; - TERMINATE = 2; + repeated OrchestrationStatus replaceableStatus = 1; + reserved 2; } message CreateInstanceResponse { @@ -391,6 +386,7 @@ message OrchestrationState { google.protobuf.StringValue executionId = 12; google.protobuf.Timestamp completedTimestamp = 13; google.protobuf.StringValue parentInstanceId = 14; + map tags = 15; } message RaiseEventRequest { @@ -731,4 +727,4 @@ message StreamInstanceHistoryRequest { message HistoryChunk { repeated HistoryEvent events = 1; -} \ No newline at end of file +} diff --git a/src/Grpc/versions.txt b/src/Grpc/versions.txt index ca514f29a..c016c8199 100644 --- a/src/Grpc/versions.txt +++ b/src/Grpc/versions.txt @@ -1,2 +1,2 @@ -# The following files were downloaded from branch main at 2025-02-19 06:25:02 UTC -https://raw.githubusercontent.com/microsoft/durabletask-protobuf/589cb5ecd9dd4b1fe463750defa3e2c84276b079/protos/orchestrator_service.proto +# The following files were downloaded from branch main at 2025-03-19 19:55:31 UTC +https://raw.githubusercontent.com/microsoft/durabletask-protobuf/4792f47019ab2b3e9ea979fb4af72427a4144c51/protos/orchestrator_service.proto diff --git a/src/Shared/Grpc/ProtoUtils.cs b/src/Shared/Grpc/ProtoUtils.cs index c3c0e45f7..9a18c524c 100644 --- a/src/Shared/Grpc/ProtoUtils.cs +++ b/src/Shared/Grpc/ProtoUtils.cs @@ -1,1042 +1,1044 @@ -// Copyright (c) Microsoft Corporation. -// Licensed under the MIT License. - -using System.Buffers; -using System.Buffers.Text; -using System.Diagnostics.CodeAnalysis; -using System.Runtime.CompilerServices; -using System.Text; -using DurableTask.Core; -using DurableTask.Core.Command; -using DurableTask.Core.Entities; -using DurableTask.Core.Entities.OperationFormat; -using DurableTask.Core.History; -using Google.Protobuf; -using Google.Protobuf.WellKnownTypes; -using DTCore = DurableTask.Core; -using P = Microsoft.DurableTask.Protobuf; - -namespace Microsoft.DurableTask; - -/// -/// Protobuf utilities and helpers. -/// -static class ProtoUtils -{ - /// - /// Converts a history event from to . - /// - /// The proto history event to converter. - /// The converted history event. - /// When the provided history event type is not supported. - internal static HistoryEvent ConvertHistoryEvent(P.HistoryEvent proto) - { - return ConvertHistoryEvent(proto, conversionState: null); - } - - /// - /// Converts a history event from to , and performs - /// stateful conversions of entity-related events. - /// - /// The proto history event to converter. - /// State needed for converting entity-related history entries and actions. - /// The converted history event. - /// When the provided history event type is not supported. - internal static HistoryEvent ConvertHistoryEvent(P.HistoryEvent proto, EntityConversionState? conversionState) - { - Check.NotNull(proto); - HistoryEvent historyEvent; - switch (proto.EventTypeCase) - { - case P.HistoryEvent.EventTypeOneofCase.ContinueAsNew: - historyEvent = new ContinueAsNewEvent(proto.EventId, proto.ContinueAsNew.Input); - break; - case P.HistoryEvent.EventTypeOneofCase.ExecutionStarted: - OrchestrationInstance instance = proto.ExecutionStarted.OrchestrationInstance.ToCore(); - conversionState?.SetOrchestrationInstance(instance); - historyEvent = new ExecutionStartedEvent(proto.EventId, proto.ExecutionStarted.Input) - { - Name = proto.ExecutionStarted.Name, - Version = proto.ExecutionStarted.Version, - OrchestrationInstance = instance, - ParentInstance = proto.ExecutionStarted.ParentInstance == null ? null : new ParentInstance - { - Name = proto.ExecutionStarted.ParentInstance.Name, - Version = proto.ExecutionStarted.ParentInstance.Version, - OrchestrationInstance = proto.ExecutionStarted.ParentInstance.OrchestrationInstance.ToCore(), - TaskScheduleId = proto.ExecutionStarted.ParentInstance.TaskScheduledId, - }, - ScheduledStartTime = proto.ExecutionStarted.ScheduledStartTimestamp?.ToDateTime(), - }; - break; - case P.HistoryEvent.EventTypeOneofCase.ExecutionCompleted: - historyEvent = new ExecutionCompletedEvent( - proto.EventId, - proto.ExecutionCompleted.Result, - proto.ExecutionCompleted.OrchestrationStatus.ToCore()); - break; - case P.HistoryEvent.EventTypeOneofCase.ExecutionTerminated: - historyEvent = new ExecutionTerminatedEvent(proto.EventId, proto.ExecutionTerminated.Input); - break; - case P.HistoryEvent.EventTypeOneofCase.ExecutionSuspended: - historyEvent = new ExecutionSuspendedEvent(proto.EventId, proto.ExecutionSuspended.Input); - break; - case P.HistoryEvent.EventTypeOneofCase.ExecutionResumed: - historyEvent = new ExecutionResumedEvent(proto.EventId, proto.ExecutionResumed.Input); - break; - case P.HistoryEvent.EventTypeOneofCase.TaskScheduled: - historyEvent = new TaskScheduledEvent( - proto.EventId, - proto.TaskScheduled.Name, - proto.TaskScheduled.Version, - proto.TaskScheduled.Input); - break; - case P.HistoryEvent.EventTypeOneofCase.TaskCompleted: - historyEvent = new TaskCompletedEvent( - proto.EventId, - proto.TaskCompleted.TaskScheduledId, - proto.TaskCompleted.Result); - break; - case P.HistoryEvent.EventTypeOneofCase.TaskFailed: - historyEvent = new TaskFailedEvent( - proto.EventId, - proto.TaskFailed.TaskScheduledId, - reason: null, /* not supported */ - details: null, /* not supported */ - proto.TaskFailed.FailureDetails.ToCore()); - break; - case P.HistoryEvent.EventTypeOneofCase.SubOrchestrationInstanceCreated: - historyEvent = new SubOrchestrationInstanceCreatedEvent(proto.EventId) - { - Input = proto.SubOrchestrationInstanceCreated.Input, - InstanceId = proto.SubOrchestrationInstanceCreated.InstanceId, - Name = proto.SubOrchestrationInstanceCreated.Name, - Version = proto.SubOrchestrationInstanceCreated.Version, - }; - break; - case P.HistoryEvent.EventTypeOneofCase.SubOrchestrationInstanceCompleted: - historyEvent = new SubOrchestrationInstanceCompletedEvent( - proto.EventId, - proto.SubOrchestrationInstanceCompleted.TaskScheduledId, - proto.SubOrchestrationInstanceCompleted.Result); - break; - case P.HistoryEvent.EventTypeOneofCase.SubOrchestrationInstanceFailed: - historyEvent = new SubOrchestrationInstanceFailedEvent( - proto.EventId, - proto.SubOrchestrationInstanceFailed.TaskScheduledId, - reason: null /* not supported */, - details: null /* not supported */, - proto.SubOrchestrationInstanceFailed.FailureDetails.ToCore()); - break; - case P.HistoryEvent.EventTypeOneofCase.TimerCreated: - historyEvent = new TimerCreatedEvent( - proto.EventId, - proto.TimerCreated.FireAt.ToDateTime()); - break; - case P.HistoryEvent.EventTypeOneofCase.TimerFired: - historyEvent = new TimerFiredEvent( - eventId: -1, - proto.TimerFired.FireAt.ToDateTime()) - { - TimerId = proto.TimerFired.TimerId, - }; - break; - case P.HistoryEvent.EventTypeOneofCase.OrchestratorStarted: - historyEvent = new OrchestratorStartedEvent(proto.EventId); - break; - case P.HistoryEvent.EventTypeOneofCase.OrchestratorCompleted: - historyEvent = new OrchestratorCompletedEvent(proto.EventId); - break; - case P.HistoryEvent.EventTypeOneofCase.EventSent: - historyEvent = new EventSentEvent(proto.EventId) - { - InstanceId = proto.EventSent.InstanceId, - Name = proto.EventSent.Name, - Input = proto.EventSent.Input, - }; - break; - case P.HistoryEvent.EventTypeOneofCase.EventRaised: - historyEvent = new EventRaisedEvent(proto.EventId, proto.EventRaised.Input) - { - Name = proto.EventRaised.Name, - }; - break; - case P.HistoryEvent.EventTypeOneofCase.EntityOperationCalled: - historyEvent = EntityConversions.EncodeOperationCalled(proto, conversionState!.CurrentInstance); - conversionState?.EntityRequestIds.Add(proto.EntityOperationCalled.RequestId); - break; - case P.HistoryEvent.EventTypeOneofCase.EntityOperationSignaled: - historyEvent = EntityConversions.EncodeOperationSignaled(proto); - conversionState?.EntityRequestIds.Add(proto.EntityOperationSignaled.RequestId); - break; - case P.HistoryEvent.EventTypeOneofCase.EntityLockRequested: - historyEvent = EntityConversions.EncodeLockRequested(proto, conversionState!.CurrentInstance); - conversionState?.AddUnlockObligations(proto.EntityLockRequested); - break; - case P.HistoryEvent.EventTypeOneofCase.EntityUnlockSent: - historyEvent = EntityConversions.EncodeUnlockSent(proto, conversionState!.CurrentInstance); - conversionState?.RemoveUnlockObligation(proto.EntityUnlockSent.TargetInstanceId); - break; - case P.HistoryEvent.EventTypeOneofCase.EntityLockGranted: - historyEvent = EntityConversions.EncodeLockGranted(proto); - break; - case P.HistoryEvent.EventTypeOneofCase.EntityOperationCompleted: - historyEvent = EntityConversions.EncodeOperationCompleted(proto); - break; - case P.HistoryEvent.EventTypeOneofCase.EntityOperationFailed: - historyEvent = EntityConversions.EncodeOperationFailed(proto); - break; - case P.HistoryEvent.EventTypeOneofCase.GenericEvent: - historyEvent = new GenericEvent(proto.EventId, proto.GenericEvent.Data); - break; - case P.HistoryEvent.EventTypeOneofCase.HistoryState: - historyEvent = new HistoryStateEvent( - proto.EventId, - new OrchestrationState - { - OrchestrationInstance = new OrchestrationInstance - { - InstanceId = proto.HistoryState.OrchestrationState.InstanceId, - }, - Name = proto.HistoryState.OrchestrationState.Name, - Version = proto.HistoryState.OrchestrationState.Version, - ScheduledStartTime = proto.HistoryState.OrchestrationState.ScheduledStartTimestamp.ToDateTime(), - CreatedTime = proto.HistoryState.OrchestrationState.CreatedTimestamp.ToDateTime(), - LastUpdatedTime = proto.HistoryState.OrchestrationState.LastUpdatedTimestamp.ToDateTime(), - Input = proto.HistoryState.OrchestrationState.Input, - Output = proto.HistoryState.OrchestrationState.Output, - Status = proto.HistoryState.OrchestrationState.CustomStatus, - }); - break; - default: - throw new NotSupportedException($"Deserialization of {proto.EventTypeCase} is not supported."); - } - - historyEvent.Timestamp = proto.Timestamp.ToDateTime(); - return historyEvent; - } - - /// - /// Converts a to a gRPC . - /// - /// The date-time to convert. - /// The gRPC timestamp. - internal static Timestamp ToTimestamp(this DateTime dateTime) - { - // The protobuf libraries require timestamps to be in UTC - if (dateTime.Kind == DateTimeKind.Unspecified) - { - dateTime = DateTime.SpecifyKind(dateTime, DateTimeKind.Utc); - } - else if (dateTime.Kind == DateTimeKind.Local) - { - dateTime = dateTime.ToUniversalTime(); - } - - return Timestamp.FromDateTime(dateTime); - } - - /// - /// Converts a to a gRPC . - /// - /// The date-time to convert. - /// The gRPC timestamp. - internal static Timestamp? ToTimestamp(this DateTime? dateTime) - => dateTime.HasValue ? dateTime.Value.ToTimestamp() : null; - - /// - /// Converts a to a gRPC . - /// - /// The date-time to convert. - /// The gRPC timestamp. - internal static Timestamp ToTimestamp(this DateTimeOffset dateTime) => Timestamp.FromDateTimeOffset(dateTime); - - /// - /// Converts a to a gRPC . - /// - /// The date-time to convert. - /// The gRPC timestamp. - internal static Timestamp? ToTimestamp(this DateTimeOffset? dateTime) - => dateTime.HasValue ? dateTime.Value.ToTimestamp() : null; - - /// - /// Constructs a . - /// - /// The orchestrator instance ID. - /// The orchestrator customer status or null if no custom status. - /// The orchestrator actions. - /// - /// The completion token for the work item. It must be the exact same - /// value that was provided by the corresponding that triggered the orchestrator execution. - /// - /// The entity conversion state, or null if no conversion is required. - /// The orchestrator response. - /// When an orchestrator action is unknown. - internal static P.OrchestratorResponse ConstructOrchestratorResponse( - string instanceId, - string? customStatus, - IEnumerable actions, - string completionToken, - EntityConversionState? entityConversionState) - { - Check.NotNull(actions); - var response = new P.OrchestratorResponse - { - InstanceId = instanceId, - CustomStatus = customStatus, - CompletionToken = completionToken, - }; - - foreach (OrchestratorAction action in actions) - { - var protoAction = new P.OrchestratorAction { Id = action.Id }; - - switch (action.OrchestratorActionType) - { - case OrchestratorActionType.ScheduleOrchestrator: - var scheduleTaskAction = (ScheduleTaskOrchestratorAction)action; - protoAction.ScheduleTask = new P.ScheduleTaskAction - { - Name = scheduleTaskAction.Name, - Version = scheduleTaskAction.Version, - Input = scheduleTaskAction.Input, - }; - break; - case OrchestratorActionType.CreateSubOrchestration: - var subOrchestrationAction = (CreateSubOrchestrationAction)action; - protoAction.CreateSubOrchestration = new P.CreateSubOrchestrationAction - { - Input = subOrchestrationAction.Input, - InstanceId = subOrchestrationAction.InstanceId, - Name = subOrchestrationAction.Name, - Version = subOrchestrationAction.Version, - }; - break; - case OrchestratorActionType.CreateTimer: - var createTimerAction = (CreateTimerOrchestratorAction)action; - protoAction.CreateTimer = new P.CreateTimerAction - { - FireAt = createTimerAction.FireAt.ToTimestamp(), - }; - break; - case OrchestratorActionType.SendEvent: - var sendEventAction = (SendEventOrchestratorAction)action; - if (sendEventAction.Instance == null) - { - throw new ArgumentException( - $"{nameof(SendEventOrchestratorAction)} cannot have a null Instance property!"); - } - - if (entityConversionState is not null - && DTCore.Common.Entities.IsEntityInstance(sendEventAction.Instance.InstanceId) - && sendEventAction.EventName is not null - && sendEventAction.EventData is not null) - { - P.SendEntityMessageAction sendAction = new P.SendEntityMessageAction(); - protoAction.SendEntityMessage = sendAction; - - EntityConversions.DecodeEntityMessageAction( - sendEventAction.EventName, - sendEventAction.EventData, - sendEventAction.Instance.InstanceId, - sendAction, - out string requestId); - - entityConversionState.EntityRequestIds.Add(requestId); - - switch (sendAction.EntityMessageTypeCase) - { - case P.SendEntityMessageAction.EntityMessageTypeOneofCase.EntityLockRequested: - entityConversionState.AddUnlockObligations(sendAction.EntityLockRequested); - break; - case P.SendEntityMessageAction.EntityMessageTypeOneofCase.EntityUnlockSent: - entityConversionState.RemoveUnlockObligation(sendAction.EntityUnlockSent.TargetInstanceId); - break; - default: - break; - } - } - else - { - protoAction.SendEvent = new P.SendEventAction - { - Instance = sendEventAction.Instance.ToProtobuf(), - Name = sendEventAction.EventName, - Data = sendEventAction.EventData, - }; - } - - break; - case OrchestratorActionType.OrchestrationComplete: - - if (entityConversionState is not null) - { - // as a precaution, unlock any entities that were not unlocked for some reason, before - // completing the orchestration. - foreach ((string target, string criticalSectionId) in entityConversionState.ResetObligations()) - { - response.Actions.Add(new P.OrchestratorAction - { - Id = action.Id, - SendEntityMessage = new P.SendEntityMessageAction - { - EntityUnlockSent = new P.EntityUnlockSentEvent - { - CriticalSectionId = criticalSectionId, - TargetInstanceId = target, - ParentInstanceId = entityConversionState.CurrentInstance?.InstanceId, - }, - }, - }); - } - } - - var completeAction = (OrchestrationCompleteOrchestratorAction)action; - protoAction.CompleteOrchestration = new P.CompleteOrchestrationAction - { - CarryoverEvents = - { - // TODO - }, - Details = completeAction.Details, - NewVersion = completeAction.NewVersion, - OrchestrationStatus = completeAction.OrchestrationStatus.ToProtobuf(), - Result = completeAction.Result, - }; - - if (completeAction.OrchestrationStatus == OrchestrationStatus.Failed) - { - protoAction.CompleteOrchestration.FailureDetails = completeAction.FailureDetails.ToProtobuf(); - } - - break; - default: - throw new NotSupportedException($"Unknown orchestrator action: {action.OrchestratorActionType}"); - } - - response.Actions.Add(protoAction); - } - - return response; - } - - /// - /// Converts a to a . - /// - /// The status to convert. - /// The converted status. - internal static OrchestrationStatus ToCore(this P.OrchestrationStatus status) - { - return (OrchestrationStatus)status; - } - - /// - /// Converts a to a . - /// - /// The status to convert. - /// The converted status. - [return: NotNullIfNotNull(nameof(status))] - internal static OrchestrationInstance? ToCore(this P.OrchestrationInstance? status) - { - if (status == null) - { - return null; - } - - return new OrchestrationInstance - { - InstanceId = status.InstanceId, - ExecutionId = status.ExecutionId, - }; - } - - /// - /// Converts a to a . - /// - /// The failure details to convert. - /// The converted failure details. - [return: NotNullIfNotNull(nameof(failureDetails))] - internal static TaskFailureDetails? ToTaskFailureDetails(this P.TaskFailureDetails? failureDetails) - { - if (failureDetails == null) - { - return null; - } - - return new TaskFailureDetails( - failureDetails.ErrorType, - failureDetails.ErrorMessage, - failureDetails.StackTrace, - failureDetails.InnerFailure.ToTaskFailureDetails()); - } - - /// - /// Converts a to . - /// - /// The exception to convert. - /// The task failure details. - [return: NotNullIfNotNull(nameof(e))] - internal static P.TaskFailureDetails? ToTaskFailureDetails(this Exception? e) - { - if (e == null) - { - return null; - } - - return new P.TaskFailureDetails - { - ErrorType = e.GetType().FullName, - ErrorMessage = e.Message, - StackTrace = e.StackTrace, - InnerFailure = e.InnerException.ToTaskFailureDetails(), - }; - } - - /// - /// Converts a to a . - /// - /// The entity batch request to convert. - /// The converted entity batch request. - [return: NotNullIfNotNull(nameof(entityBatchRequest))] - internal static EntityBatchRequest? ToEntityBatchRequest(this P.EntityBatchRequest? entityBatchRequest) - { - if (entityBatchRequest == null) - { - return null; - } - - return new EntityBatchRequest() - { - EntityState = entityBatchRequest.EntityState, - InstanceId = entityBatchRequest.InstanceId, - Operations = entityBatchRequest.Operations.Select(r => r.ToOperationRequest()).ToList(), - }; - } - - /// - /// Converts a to a . - /// - /// The entity request to convert. - /// The converted request. - /// Additional info about each operation, required by DTS. - internal static void ToEntityBatchRequest( - this P.EntityRequest entityRequest, - out EntityBatchRequest batchRequest, - out List operationInfos) - { - batchRequest = new EntityBatchRequest() - { - EntityState = entityRequest.EntityState, - InstanceId = entityRequest.InstanceId, - Operations = [], // operations are added to this collection below - }; - - operationInfos = new(entityRequest.OperationRequests.Count); - - foreach (P.HistoryEvent? op in entityRequest.OperationRequests) - { - if (op.EntityOperationSignaled is not null) - { - batchRequest.Operations.Add(new OperationRequest - { - Id = Guid.Parse(op.EntityOperationSignaled.RequestId), - Operation = op.EntityOperationSignaled.Operation, - Input = op.EntityOperationSignaled.Input, - }); - operationInfos.Add(new P.OperationInfo - { - RequestId = op.EntityOperationSignaled.RequestId, - ResponseDestination = null, // means we don't send back a response to the caller - }); - } - else if (op.EntityOperationCalled is not null) - { - batchRequest.Operations.Add(new OperationRequest - { - Id = Guid.Parse(op.EntityOperationCalled.RequestId), - Operation = op.EntityOperationCalled.Operation, - Input = op.EntityOperationCalled.Input, - }); - operationInfos.Add(new P.OperationInfo - { - RequestId = op.EntityOperationCalled.RequestId, - ResponseDestination = new P.OrchestrationInstance - { - InstanceId = op.EntityOperationCalled.ParentInstanceId, - ExecutionId = op.EntityOperationCalled.ParentExecutionId, - }, - }); - } - } - } - - /// - /// Converts a to a . - /// - /// The operation request to convert. - /// The converted operation request. - [return: NotNullIfNotNull(nameof(operationRequest))] - internal static OperationRequest? ToOperationRequest(this P.OperationRequest? operationRequest) - { - if (operationRequest == null) - { - return null; - } - - return new OperationRequest() - { - Operation = operationRequest.Operation, - Input = operationRequest.Input, - Id = Guid.Parse(operationRequest.RequestId), - }; - } - - /// - /// Converts a to a . - /// - /// The operation result to convert. - /// The converted operation result. - [return: NotNullIfNotNull(nameof(operationResult))] - internal static OperationResult? ToOperationResult(this P.OperationResult? operationResult) - { - if (operationResult == null) - { - return null; - } - - switch (operationResult.ResultTypeCase) - { - case P.OperationResult.ResultTypeOneofCase.Success: - return new OperationResult() - { - Result = operationResult.Success.Result, - }; - - case P.OperationResult.ResultTypeOneofCase.Failure: - return new OperationResult() - { - FailureDetails = operationResult.Failure.FailureDetails.ToCore(), - }; - - default: - throw new NotSupportedException($"Deserialization of {operationResult.ResultTypeCase} is not supported."); - } - } - - /// - /// Converts a to . - /// - /// The operation result to convert. - /// The converted operation result. - [return: NotNullIfNotNull(nameof(operationResult))] - internal static P.OperationResult? ToOperationResult(this OperationResult? operationResult) - { - if (operationResult == null) - { - return null; - } - - if (operationResult.FailureDetails == null) - { - return new P.OperationResult() - { - Success = new P.OperationResultSuccess() - { - Result = operationResult.Result, - }, - }; - } - else - { - return new P.OperationResult() - { - Failure = new P.OperationResultFailure() - { - FailureDetails = ToProtobuf(operationResult.FailureDetails), - }, - }; - } - } - - /// - /// Converts a to a . - /// - /// The operation action to convert. - /// The converted operation action. - [return: NotNullIfNotNull(nameof(operationAction))] - internal static OperationAction? ToOperationAction(this P.OperationAction? operationAction) - { - if (operationAction == null) - { - return null; - } - - switch (operationAction.OperationActionTypeCase) - { - case P.OperationAction.OperationActionTypeOneofCase.SendSignal: - - return new SendSignalOperationAction() - { - Name = operationAction.SendSignal.Name, - Input = operationAction.SendSignal.Input, - InstanceId = operationAction.SendSignal.InstanceId, - ScheduledTime = operationAction.SendSignal.ScheduledTime?.ToDateTime(), - }; - - case P.OperationAction.OperationActionTypeOneofCase.StartNewOrchestration: - - return new StartNewOrchestrationOperationAction() - { - Name = operationAction.StartNewOrchestration.Name, - Input = operationAction.StartNewOrchestration.Input, - InstanceId = operationAction.StartNewOrchestration.InstanceId, - Version = operationAction.StartNewOrchestration.Version, - ScheduledStartTime = operationAction.StartNewOrchestration.ScheduledTime?.ToDateTime(), - }; - default: - throw new NotSupportedException($"Deserialization of {operationAction.OperationActionTypeCase} is not supported."); - } - } - - /// - /// Converts a to . - /// - /// The operation action to convert. - /// The converted operation action. - [return: NotNullIfNotNull(nameof(operationAction))] - internal static P.OperationAction? ToOperationAction(this OperationAction? operationAction) - { - if (operationAction == null) - { - return null; - } - - var action = new P.OperationAction(); - - switch (operationAction) - { - case SendSignalOperationAction sendSignalAction: - - action.SendSignal = new P.SendSignalAction() - { - Name = sendSignalAction.Name, - Input = sendSignalAction.Input, - InstanceId = sendSignalAction.InstanceId, - ScheduledTime = sendSignalAction.ScheduledTime?.ToTimestamp(), - }; - break; - - case StartNewOrchestrationOperationAction startNewOrchestrationAction: - - action.StartNewOrchestration = new P.StartNewOrchestrationAction() - { - Name = startNewOrchestrationAction.Name, - Input = startNewOrchestrationAction.Input, - Version = startNewOrchestrationAction.Version, - InstanceId = startNewOrchestrationAction.InstanceId, - ScheduledTime = startNewOrchestrationAction.ScheduledStartTime?.ToTimestamp(), - }; - break; - } - - return action; - } - - /// - /// Converts a to a . - /// - /// The operation result to convert. - /// The converted operation result. - [return: NotNullIfNotNull(nameof(entityBatchResult))] - internal static EntityBatchResult? ToEntityBatchResult(this P.EntityBatchResult? entityBatchResult) - { - if (entityBatchResult == null) - { - return null; - } - - return new EntityBatchResult() - { - Actions = entityBatchResult.Actions.Select(operationAction => operationAction!.ToOperationAction()).ToList(), - EntityState = entityBatchResult.EntityState, - Results = entityBatchResult.Results.Select(operationResult => operationResult!.ToOperationResult()).ToList(), - FailureDetails = entityBatchResult.FailureDetails.ToCore(), - }; - } - - /// - /// Converts a to . - /// - /// The operation result to convert. - /// The completion token, or null for the older protocol. - /// Additional information about each operation, required by DTS. - /// The converted operation result. - [return: NotNullIfNotNull(nameof(entityBatchResult))] - internal static P.EntityBatchResult? ToEntityBatchResult( - this EntityBatchResult? entityBatchResult, - string? completionToken = null, - IEnumerable? operationInfos = null) - { - if (entityBatchResult == null) - { - return null; - } - - return new P.EntityBatchResult() - { - EntityState = entityBatchResult.EntityState, - FailureDetails = entityBatchResult.FailureDetails.ToProtobuf(), - Actions = { entityBatchResult.Actions?.Select(a => a.ToOperationAction()) ?? [] }, - Results = { entityBatchResult.Results?.Select(a => a.ToOperationResult()) ?? [] }, - CompletionToken = completionToken ?? string.Empty, - OperationInfos = { operationInfos ?? [] }, - }; - } - - /// - /// Converts the gRPC representation of orchestrator entity parameters to the DT.Core representation. - /// - /// The DT.Core representation. - /// The gRPC representation. - [return: NotNullIfNotNull(nameof(parameters))] - internal static TaskOrchestrationEntityParameters? ToCore(this P.OrchestratorEntityParameters? parameters) - { - if (parameters == null) - { - return null; - } - - return new TaskOrchestrationEntityParameters() - { - EntityMessageReorderWindow = parameters.EntityMessageReorderWindow.ToTimeSpan(), - }; - } - - /// - /// Gets the approximate byte count for a . - /// - /// The failure details. - /// The approximate byte count. - internal static int GetApproximateByteCount(this P.TaskFailureDetails failureDetails) - { - // Protobuf strings are always UTF-8: https://developers.google.com/protocol-buffers/docs/proto3#scalar - Encoding encoding = Encoding.UTF8; - - int byteCount = 0; - if (failureDetails.ErrorType != null) - { - byteCount += encoding.GetByteCount(failureDetails.ErrorType); - } - - if (failureDetails.ErrorMessage != null) - { - byteCount += encoding.GetByteCount(failureDetails.ErrorMessage); - } - - if (failureDetails.StackTrace != null) - { - byteCount += encoding.GetByteCount(failureDetails.StackTrace); - } - - if (failureDetails.InnerFailure != null) - { - byteCount += failureDetails.InnerFailure.GetApproximateByteCount(); - } - - return byteCount; - } - - /// - /// Decode a protobuf message from a base64 string. - /// - /// The type to decode to. - /// The message parser. - /// The base64 encoded message. - /// The decoded message. - /// If decoding fails. - internal static T Base64Decode(this MessageParser parser, string encodedMessage) where T : IMessage - { - // Decode the base64 in a way that doesn't allocate a byte[] on each request - int encodedByteCount = Encoding.UTF8.GetByteCount(encodedMessage); - byte[] buffer = ArrayPool.Shared.Rent(encodedByteCount); - try - { - // The Base64 APIs require first converting the string into UTF-8 bytes. We then - // do an in-place conversion from base64 UTF-8 bytes to protobuf bytes so that - // we can finally decode the protobuf request. - Encoding.UTF8.GetBytes(encodedMessage, 0, encodedMessage.Length, buffer, 0); - OperationStatus status = Base64.DecodeFromUtf8InPlace( - buffer.AsSpan(0, encodedByteCount), - out int bytesWritten); - if (status != OperationStatus.Done) - { - throw new ArgumentException( - $"Failed to base64-decode the '{typeof(T).Name}' payload: {status}", nameof(encodedMessage)); - } - - return (T)parser.ParseFrom(buffer, 0, bytesWritten); - } - finally - { - ArrayPool.Shared.Return(buffer); - } - } - - /// - /// Converts a grpc to a . - /// - /// The failure details to convert. - /// The converted failure details. - internal static FailureDetails? ToCore(this P.TaskFailureDetails? failureDetails) - { - if (failureDetails == null) - { - return null; - } - - return new FailureDetails( - failureDetails.ErrorType, - failureDetails.ErrorMessage, - failureDetails.StackTrace, - failureDetails.InnerFailure.ToCore(), - failureDetails.IsNonRetriable); - } - - /// - /// Converts a to a grpc . - /// - /// The failure details to convert. - /// The converted failure details. - static P.TaskFailureDetails? ToProtobuf(this FailureDetails? failureDetails) - { - if (failureDetails == null) - { - return null; - } - - return new P.TaskFailureDetails - { - ErrorType = failureDetails.ErrorType ?? "(unknown)", - ErrorMessage = failureDetails.ErrorMessage ?? "(unknown)", - StackTrace = failureDetails.StackTrace, - IsNonRetriable = failureDetails.IsNonRetriable, - InnerFailure = failureDetails.InnerFailure.ToProtobuf(), - }; - } - - static P.OrchestrationStatus ToProtobuf(this OrchestrationStatus status) - { - return (P.OrchestrationStatus)status; - } - - static P.OrchestrationInstance ToProtobuf(this OrchestrationInstance instance) - { - return new P.OrchestrationInstance - { - InstanceId = instance.InstanceId, - ExecutionId = instance.ExecutionId, - }; - } - - /// - /// Tracks state required for converting orchestration histories containing entity-related events. - /// - internal class EntityConversionState - { - readonly bool insertMissingEntityUnlocks; - - OrchestrationInstance? instance; - HashSet? entityRequestIds; - Dictionary? unlockObligations; - - /// - /// Initializes a new instance of the class. - /// - /// Whether to insert missing unlock events in to the history - /// when the orchestration completes. - public EntityConversionState(bool insertMissingEntityUnlocks) - { - this.ConvertFromProto = (P.HistoryEvent e) => ProtoUtils.ConvertHistoryEvent(e, this); - this.insertMissingEntityUnlocks = insertMissingEntityUnlocks; - } - - /// - /// Gets a function that converts a history event in protobuf format to a core history event. - /// - public Func ConvertFromProto { get; } - - /// - /// Gets the orchestration instance of this history. - /// - public OrchestrationInstance? CurrentInstance => this.instance; - - /// - /// Gets the set of guids that have been used as entity request ids in this history. - /// - public HashSet EntityRequestIds => this.entityRequestIds ??= new(); - - /// - /// Records the orchestration instance, which may be needed for some conversions. - /// - /// The orchestration instance. - public void SetOrchestrationInstance(OrchestrationInstance instance) - { - this.instance = instance; - } - - /// - /// Adds unlock obligations for all entities that are being locked by this request. - /// - /// The lock request. - public void AddUnlockObligations(P.EntityLockRequestedEvent request) - { - if (!this.insertMissingEntityUnlocks) - { - return; - } - - this.unlockObligations ??= new(); - - foreach (string target in request.LockSet) - { - this.unlockObligations[target] = request.CriticalSectionId; - } - } - - /// - /// Removes an unlock obligation. - /// - /// The target entity. - public void RemoveUnlockObligation(string target) - { - if (!this.insertMissingEntityUnlocks) - { - return; - } - - this.unlockObligations?.Remove(target); - } - - /// - /// Returns the remaining unlock obligations, and clears the list. - /// - /// The unlock obligations. - public IEnumerable<(string Target, string CriticalSectionId)> ResetObligations() - { - if (!this.insertMissingEntityUnlocks) - { - yield break; - } - - if (this.unlockObligations is not null) - { - foreach (var kvp in this.unlockObligations) - { - yield return (kvp.Key, kvp.Value); - } - - this.unlockObligations = null; - } - } - } -} +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using System.Buffers; +using System.Buffers.Text; +using System.Diagnostics.CodeAnalysis; +using System.Runtime.CompilerServices; +using System.Text; +using DurableTask.Core; +using DurableTask.Core.Command; +using DurableTask.Core.Entities; +using DurableTask.Core.Entities.OperationFormat; +using DurableTask.Core.History; +using Google.Protobuf; +using Google.Protobuf.WellKnownTypes; +using DTCore = DurableTask.Core; +using P = Microsoft.DurableTask.Protobuf; + +namespace Microsoft.DurableTask; + +/// +/// Protobuf utilities and helpers. +/// +static class ProtoUtils +{ + /// + /// Converts a history event from to . + /// + /// The proto history event to converter. + /// The converted history event. + /// When the provided history event type is not supported. + internal static HistoryEvent ConvertHistoryEvent(P.HistoryEvent proto) + { + return ConvertHistoryEvent(proto, conversionState: null); + } + + /// + /// Converts a history event from to , and performs + /// stateful conversions of entity-related events. + /// + /// The proto history event to converter. + /// State needed for converting entity-related history entries and actions. + /// The converted history event. + /// When the provided history event type is not supported. + internal static HistoryEvent ConvertHistoryEvent(P.HistoryEvent proto, EntityConversionState? conversionState) + { + Check.NotNull(proto); + HistoryEvent historyEvent; + switch (proto.EventTypeCase) + { + case P.HistoryEvent.EventTypeOneofCase.ContinueAsNew: + historyEvent = new ContinueAsNewEvent(proto.EventId, proto.ContinueAsNew.Input); + break; + case P.HistoryEvent.EventTypeOneofCase.ExecutionStarted: + OrchestrationInstance instance = proto.ExecutionStarted.OrchestrationInstance.ToCore(); + conversionState?.SetOrchestrationInstance(instance); + historyEvent = new ExecutionStartedEvent(proto.EventId, proto.ExecutionStarted.Input) + { + Name = proto.ExecutionStarted.Name, + Version = proto.ExecutionStarted.Version, + OrchestrationInstance = instance, + Tags = proto.ExecutionStarted.Tags, + ParentInstance = proto.ExecutionStarted.ParentInstance == null ? null : new ParentInstance + { + Name = proto.ExecutionStarted.ParentInstance.Name, + Version = proto.ExecutionStarted.ParentInstance.Version, + OrchestrationInstance = proto.ExecutionStarted.ParentInstance.OrchestrationInstance.ToCore(), + TaskScheduleId = proto.ExecutionStarted.ParentInstance.TaskScheduledId, + }, + ScheduledStartTime = proto.ExecutionStarted.ScheduledStartTimestamp?.ToDateTime(), + }; + break; + case P.HistoryEvent.EventTypeOneofCase.ExecutionCompleted: + historyEvent = new ExecutionCompletedEvent( + proto.EventId, + proto.ExecutionCompleted.Result, + proto.ExecutionCompleted.OrchestrationStatus.ToCore()); + break; + case P.HistoryEvent.EventTypeOneofCase.ExecutionTerminated: + historyEvent = new ExecutionTerminatedEvent(proto.EventId, proto.ExecutionTerminated.Input); + break; + case P.HistoryEvent.EventTypeOneofCase.ExecutionSuspended: + historyEvent = new ExecutionSuspendedEvent(proto.EventId, proto.ExecutionSuspended.Input); + break; + case P.HistoryEvent.EventTypeOneofCase.ExecutionResumed: + historyEvent = new ExecutionResumedEvent(proto.EventId, proto.ExecutionResumed.Input); + break; + case P.HistoryEvent.EventTypeOneofCase.TaskScheduled: + historyEvent = new TaskScheduledEvent( + proto.EventId, + proto.TaskScheduled.Name, + proto.TaskScheduled.Version, + proto.TaskScheduled.Input); + break; + case P.HistoryEvent.EventTypeOneofCase.TaskCompleted: + historyEvent = new TaskCompletedEvent( + proto.EventId, + proto.TaskCompleted.TaskScheduledId, + proto.TaskCompleted.Result); + break; + case P.HistoryEvent.EventTypeOneofCase.TaskFailed: + historyEvent = new TaskFailedEvent( + proto.EventId, + proto.TaskFailed.TaskScheduledId, + reason: null, /* not supported */ + details: null, /* not supported */ + proto.TaskFailed.FailureDetails.ToCore()); + break; + case P.HistoryEvent.EventTypeOneofCase.SubOrchestrationInstanceCreated: + historyEvent = new SubOrchestrationInstanceCreatedEvent(proto.EventId) + { + Input = proto.SubOrchestrationInstanceCreated.Input, + InstanceId = proto.SubOrchestrationInstanceCreated.InstanceId, + Name = proto.SubOrchestrationInstanceCreated.Name, + Version = proto.SubOrchestrationInstanceCreated.Version, + }; + break; + case P.HistoryEvent.EventTypeOneofCase.SubOrchestrationInstanceCompleted: + historyEvent = new SubOrchestrationInstanceCompletedEvent( + proto.EventId, + proto.SubOrchestrationInstanceCompleted.TaskScheduledId, + proto.SubOrchestrationInstanceCompleted.Result); + break; + case P.HistoryEvent.EventTypeOneofCase.SubOrchestrationInstanceFailed: + historyEvent = new SubOrchestrationInstanceFailedEvent( + proto.EventId, + proto.SubOrchestrationInstanceFailed.TaskScheduledId, + reason: null /* not supported */, + details: null /* not supported */, + proto.SubOrchestrationInstanceFailed.FailureDetails.ToCore()); + break; + case P.HistoryEvent.EventTypeOneofCase.TimerCreated: + historyEvent = new TimerCreatedEvent( + proto.EventId, + proto.TimerCreated.FireAt.ToDateTime()); + break; + case P.HistoryEvent.EventTypeOneofCase.TimerFired: + historyEvent = new TimerFiredEvent( + eventId: -1, + proto.TimerFired.FireAt.ToDateTime()) + { + TimerId = proto.TimerFired.TimerId, + }; + break; + case P.HistoryEvent.EventTypeOneofCase.OrchestratorStarted: + historyEvent = new OrchestratorStartedEvent(proto.EventId); + break; + case P.HistoryEvent.EventTypeOneofCase.OrchestratorCompleted: + historyEvent = new OrchestratorCompletedEvent(proto.EventId); + break; + case P.HistoryEvent.EventTypeOneofCase.EventSent: + historyEvent = new EventSentEvent(proto.EventId) + { + InstanceId = proto.EventSent.InstanceId, + Name = proto.EventSent.Name, + Input = proto.EventSent.Input, + }; + break; + case P.HistoryEvent.EventTypeOneofCase.EventRaised: + historyEvent = new EventRaisedEvent(proto.EventId, proto.EventRaised.Input) + { + Name = proto.EventRaised.Name, + }; + break; + case P.HistoryEvent.EventTypeOneofCase.EntityOperationCalled: + historyEvent = EntityConversions.EncodeOperationCalled(proto, conversionState!.CurrentInstance); + conversionState?.EntityRequestIds.Add(proto.EntityOperationCalled.RequestId); + break; + case P.HistoryEvent.EventTypeOneofCase.EntityOperationSignaled: + historyEvent = EntityConversions.EncodeOperationSignaled(proto); + conversionState?.EntityRequestIds.Add(proto.EntityOperationSignaled.RequestId); + break; + case P.HistoryEvent.EventTypeOneofCase.EntityLockRequested: + historyEvent = EntityConversions.EncodeLockRequested(proto, conversionState!.CurrentInstance); + conversionState?.AddUnlockObligations(proto.EntityLockRequested); + break; + case P.HistoryEvent.EventTypeOneofCase.EntityUnlockSent: + historyEvent = EntityConversions.EncodeUnlockSent(proto, conversionState!.CurrentInstance); + conversionState?.RemoveUnlockObligation(proto.EntityUnlockSent.TargetInstanceId); + break; + case P.HistoryEvent.EventTypeOneofCase.EntityLockGranted: + historyEvent = EntityConversions.EncodeLockGranted(proto); + break; + case P.HistoryEvent.EventTypeOneofCase.EntityOperationCompleted: + historyEvent = EntityConversions.EncodeOperationCompleted(proto); + break; + case P.HistoryEvent.EventTypeOneofCase.EntityOperationFailed: + historyEvent = EntityConversions.EncodeOperationFailed(proto); + break; + case P.HistoryEvent.EventTypeOneofCase.GenericEvent: + historyEvent = new GenericEvent(proto.EventId, proto.GenericEvent.Data); + break; + case P.HistoryEvent.EventTypeOneofCase.HistoryState: + historyEvent = new HistoryStateEvent( + proto.EventId, + new OrchestrationState + { + OrchestrationInstance = new OrchestrationInstance + { + InstanceId = proto.HistoryState.OrchestrationState.InstanceId, + }, + Name = proto.HistoryState.OrchestrationState.Name, + Version = proto.HistoryState.OrchestrationState.Version, + ScheduledStartTime = proto.HistoryState.OrchestrationState.ScheduledStartTimestamp.ToDateTime(), + CreatedTime = proto.HistoryState.OrchestrationState.CreatedTimestamp.ToDateTime(), + LastUpdatedTime = proto.HistoryState.OrchestrationState.LastUpdatedTimestamp.ToDateTime(), + Input = proto.HistoryState.OrchestrationState.Input, + Output = proto.HistoryState.OrchestrationState.Output, + Status = proto.HistoryState.OrchestrationState.CustomStatus, + Tags = proto.HistoryState.OrchestrationState.Tags, + }); + break; + default: + throw new NotSupportedException($"Deserialization of {proto.EventTypeCase} is not supported."); + } + + historyEvent.Timestamp = proto.Timestamp.ToDateTime(); + return historyEvent; + } + + /// + /// Converts a to a gRPC . + /// + /// The date-time to convert. + /// The gRPC timestamp. + internal static Timestamp ToTimestamp(this DateTime dateTime) + { + // The protobuf libraries require timestamps to be in UTC + if (dateTime.Kind == DateTimeKind.Unspecified) + { + dateTime = DateTime.SpecifyKind(dateTime, DateTimeKind.Utc); + } + else if (dateTime.Kind == DateTimeKind.Local) + { + dateTime = dateTime.ToUniversalTime(); + } + + return Timestamp.FromDateTime(dateTime); + } + + /// + /// Converts a to a gRPC . + /// + /// The date-time to convert. + /// The gRPC timestamp. + internal static Timestamp? ToTimestamp(this DateTime? dateTime) + => dateTime.HasValue ? dateTime.Value.ToTimestamp() : null; + + /// + /// Converts a to a gRPC . + /// + /// The date-time to convert. + /// The gRPC timestamp. + internal static Timestamp ToTimestamp(this DateTimeOffset dateTime) => Timestamp.FromDateTimeOffset(dateTime); + + /// + /// Converts a to a gRPC . + /// + /// The date-time to convert. + /// The gRPC timestamp. + internal static Timestamp? ToTimestamp(this DateTimeOffset? dateTime) + => dateTime.HasValue ? dateTime.Value.ToTimestamp() : null; + + /// + /// Constructs a . + /// + /// The orchestrator instance ID. + /// The orchestrator customer status or null if no custom status. + /// The orchestrator actions. + /// + /// The completion token for the work item. It must be the exact same + /// value that was provided by the corresponding that triggered the orchestrator execution. + /// + /// The entity conversion state, or null if no conversion is required. + /// The orchestrator response. + /// When an orchestrator action is unknown. + internal static P.OrchestratorResponse ConstructOrchestratorResponse( + string instanceId, + string? customStatus, + IEnumerable actions, + string completionToken, + EntityConversionState? entityConversionState) + { + Check.NotNull(actions); + var response = new P.OrchestratorResponse + { + InstanceId = instanceId, + CustomStatus = customStatus, + CompletionToken = completionToken, + }; + + foreach (OrchestratorAction action in actions) + { + var protoAction = new P.OrchestratorAction { Id = action.Id }; + + switch (action.OrchestratorActionType) + { + case OrchestratorActionType.ScheduleOrchestrator: + var scheduleTaskAction = (ScheduleTaskOrchestratorAction)action; + protoAction.ScheduleTask = new P.ScheduleTaskAction + { + Name = scheduleTaskAction.Name, + Version = scheduleTaskAction.Version, + Input = scheduleTaskAction.Input, + }; + break; + case OrchestratorActionType.CreateSubOrchestration: + var subOrchestrationAction = (CreateSubOrchestrationAction)action; + protoAction.CreateSubOrchestration = new P.CreateSubOrchestrationAction + { + Input = subOrchestrationAction.Input, + InstanceId = subOrchestrationAction.InstanceId, + Name = subOrchestrationAction.Name, + Version = subOrchestrationAction.Version, + }; + break; + case OrchestratorActionType.CreateTimer: + var createTimerAction = (CreateTimerOrchestratorAction)action; + protoAction.CreateTimer = new P.CreateTimerAction + { + FireAt = createTimerAction.FireAt.ToTimestamp(), + }; + break; + case OrchestratorActionType.SendEvent: + var sendEventAction = (SendEventOrchestratorAction)action; + if (sendEventAction.Instance == null) + { + throw new ArgumentException( + $"{nameof(SendEventOrchestratorAction)} cannot have a null Instance property!"); + } + + if (entityConversionState is not null + && DTCore.Common.Entities.IsEntityInstance(sendEventAction.Instance.InstanceId) + && sendEventAction.EventName is not null + && sendEventAction.EventData is not null) + { + P.SendEntityMessageAction sendAction = new P.SendEntityMessageAction(); + protoAction.SendEntityMessage = sendAction; + + EntityConversions.DecodeEntityMessageAction( + sendEventAction.EventName, + sendEventAction.EventData, + sendEventAction.Instance.InstanceId, + sendAction, + out string requestId); + + entityConversionState.EntityRequestIds.Add(requestId); + + switch (sendAction.EntityMessageTypeCase) + { + case P.SendEntityMessageAction.EntityMessageTypeOneofCase.EntityLockRequested: + entityConversionState.AddUnlockObligations(sendAction.EntityLockRequested); + break; + case P.SendEntityMessageAction.EntityMessageTypeOneofCase.EntityUnlockSent: + entityConversionState.RemoveUnlockObligation(sendAction.EntityUnlockSent.TargetInstanceId); + break; + default: + break; + } + } + else + { + protoAction.SendEvent = new P.SendEventAction + { + Instance = sendEventAction.Instance.ToProtobuf(), + Name = sendEventAction.EventName, + Data = sendEventAction.EventData, + }; + } + + break; + case OrchestratorActionType.OrchestrationComplete: + + if (entityConversionState is not null) + { + // as a precaution, unlock any entities that were not unlocked for some reason, before + // completing the orchestration. + foreach ((string target, string criticalSectionId) in entityConversionState.ResetObligations()) + { + response.Actions.Add(new P.OrchestratorAction + { + Id = action.Id, + SendEntityMessage = new P.SendEntityMessageAction + { + EntityUnlockSent = new P.EntityUnlockSentEvent + { + CriticalSectionId = criticalSectionId, + TargetInstanceId = target, + ParentInstanceId = entityConversionState.CurrentInstance?.InstanceId, + }, + }, + }); + } + } + + var completeAction = (OrchestrationCompleteOrchestratorAction)action; + protoAction.CompleteOrchestration = new P.CompleteOrchestrationAction + { + CarryoverEvents = + { + // TODO + }, + Details = completeAction.Details, + NewVersion = completeAction.NewVersion, + OrchestrationStatus = completeAction.OrchestrationStatus.ToProtobuf(), + Result = completeAction.Result, + }; + + if (completeAction.OrchestrationStatus == OrchestrationStatus.Failed) + { + protoAction.CompleteOrchestration.FailureDetails = completeAction.FailureDetails.ToProtobuf(); + } + + break; + default: + throw new NotSupportedException($"Unknown orchestrator action: {action.OrchestratorActionType}"); + } + + response.Actions.Add(protoAction); + } + + return response; + } + + /// + /// Converts a to a . + /// + /// The status to convert. + /// The converted status. + internal static OrchestrationStatus ToCore(this P.OrchestrationStatus status) + { + return (OrchestrationStatus)status; + } + + /// + /// Converts a to a . + /// + /// The status to convert. + /// The converted status. + [return: NotNullIfNotNull(nameof(status))] + internal static OrchestrationInstance? ToCore(this P.OrchestrationInstance? status) + { + if (status == null) + { + return null; + } + + return new OrchestrationInstance + { + InstanceId = status.InstanceId, + ExecutionId = status.ExecutionId, + }; + } + + /// + /// Converts a to a . + /// + /// The failure details to convert. + /// The converted failure details. + [return: NotNullIfNotNull(nameof(failureDetails))] + internal static TaskFailureDetails? ToTaskFailureDetails(this P.TaskFailureDetails? failureDetails) + { + if (failureDetails == null) + { + return null; + } + + return new TaskFailureDetails( + failureDetails.ErrorType, + failureDetails.ErrorMessage, + failureDetails.StackTrace, + failureDetails.InnerFailure.ToTaskFailureDetails()); + } + + /// + /// Converts a to . + /// + /// The exception to convert. + /// The task failure details. + [return: NotNullIfNotNull(nameof(e))] + internal static P.TaskFailureDetails? ToTaskFailureDetails(this Exception? e) + { + if (e == null) + { + return null; + } + + return new P.TaskFailureDetails + { + ErrorType = e.GetType().FullName, + ErrorMessage = e.Message, + StackTrace = e.StackTrace, + InnerFailure = e.InnerException.ToTaskFailureDetails(), + }; + } + + /// + /// Converts a to a . + /// + /// The entity batch request to convert. + /// The converted entity batch request. + [return: NotNullIfNotNull(nameof(entityBatchRequest))] + internal static EntityBatchRequest? ToEntityBatchRequest(this P.EntityBatchRequest? entityBatchRequest) + { + if (entityBatchRequest == null) + { + return null; + } + + return new EntityBatchRequest() + { + EntityState = entityBatchRequest.EntityState, + InstanceId = entityBatchRequest.InstanceId, + Operations = entityBatchRequest.Operations.Select(r => r.ToOperationRequest()).ToList(), + }; + } + + /// + /// Converts a to a . + /// + /// The entity request to convert. + /// The converted request. + /// Additional info about each operation, required by DTS. + internal static void ToEntityBatchRequest( + this P.EntityRequest entityRequest, + out EntityBatchRequest batchRequest, + out List operationInfos) + { + batchRequest = new EntityBatchRequest() + { + EntityState = entityRequest.EntityState, + InstanceId = entityRequest.InstanceId, + Operations = [], // operations are added to this collection below + }; + + operationInfos = new(entityRequest.OperationRequests.Count); + + foreach (P.HistoryEvent? op in entityRequest.OperationRequests) + { + if (op.EntityOperationSignaled is not null) + { + batchRequest.Operations.Add(new OperationRequest + { + Id = Guid.Parse(op.EntityOperationSignaled.RequestId), + Operation = op.EntityOperationSignaled.Operation, + Input = op.EntityOperationSignaled.Input, + }); + operationInfos.Add(new P.OperationInfo + { + RequestId = op.EntityOperationSignaled.RequestId, + ResponseDestination = null, // means we don't send back a response to the caller + }); + } + else if (op.EntityOperationCalled is not null) + { + batchRequest.Operations.Add(new OperationRequest + { + Id = Guid.Parse(op.EntityOperationCalled.RequestId), + Operation = op.EntityOperationCalled.Operation, + Input = op.EntityOperationCalled.Input, + }); + operationInfos.Add(new P.OperationInfo + { + RequestId = op.EntityOperationCalled.RequestId, + ResponseDestination = new P.OrchestrationInstance + { + InstanceId = op.EntityOperationCalled.ParentInstanceId, + ExecutionId = op.EntityOperationCalled.ParentExecutionId, + }, + }); + } + } + } + + /// + /// Converts a to a . + /// + /// The operation request to convert. + /// The converted operation request. + [return: NotNullIfNotNull(nameof(operationRequest))] + internal static OperationRequest? ToOperationRequest(this P.OperationRequest? operationRequest) + { + if (operationRequest == null) + { + return null; + } + + return new OperationRequest() + { + Operation = operationRequest.Operation, + Input = operationRequest.Input, + Id = Guid.Parse(operationRequest.RequestId), + }; + } + + /// + /// Converts a to a . + /// + /// The operation result to convert. + /// The converted operation result. + [return: NotNullIfNotNull(nameof(operationResult))] + internal static OperationResult? ToOperationResult(this P.OperationResult? operationResult) + { + if (operationResult == null) + { + return null; + } + + switch (operationResult.ResultTypeCase) + { + case P.OperationResult.ResultTypeOneofCase.Success: + return new OperationResult() + { + Result = operationResult.Success.Result, + }; + + case P.OperationResult.ResultTypeOneofCase.Failure: + return new OperationResult() + { + FailureDetails = operationResult.Failure.FailureDetails.ToCore(), + }; + + default: + throw new NotSupportedException($"Deserialization of {operationResult.ResultTypeCase} is not supported."); + } + } + + /// + /// Converts a to . + /// + /// The operation result to convert. + /// The converted operation result. + [return: NotNullIfNotNull(nameof(operationResult))] + internal static P.OperationResult? ToOperationResult(this OperationResult? operationResult) + { + if (operationResult == null) + { + return null; + } + + if (operationResult.FailureDetails == null) + { + return new P.OperationResult() + { + Success = new P.OperationResultSuccess() + { + Result = operationResult.Result, + }, + }; + } + else + { + return new P.OperationResult() + { + Failure = new P.OperationResultFailure() + { + FailureDetails = ToProtobuf(operationResult.FailureDetails), + }, + }; + } + } + + /// + /// Converts a to a . + /// + /// The operation action to convert. + /// The converted operation action. + [return: NotNullIfNotNull(nameof(operationAction))] + internal static OperationAction? ToOperationAction(this P.OperationAction? operationAction) + { + if (operationAction == null) + { + return null; + } + + switch (operationAction.OperationActionTypeCase) + { + case P.OperationAction.OperationActionTypeOneofCase.SendSignal: + + return new SendSignalOperationAction() + { + Name = operationAction.SendSignal.Name, + Input = operationAction.SendSignal.Input, + InstanceId = operationAction.SendSignal.InstanceId, + ScheduledTime = operationAction.SendSignal.ScheduledTime?.ToDateTime(), + }; + + case P.OperationAction.OperationActionTypeOneofCase.StartNewOrchestration: + + return new StartNewOrchestrationOperationAction() + { + Name = operationAction.StartNewOrchestration.Name, + Input = operationAction.StartNewOrchestration.Input, + InstanceId = operationAction.StartNewOrchestration.InstanceId, + Version = operationAction.StartNewOrchestration.Version, + ScheduledStartTime = operationAction.StartNewOrchestration.ScheduledTime?.ToDateTime(), + }; + default: + throw new NotSupportedException($"Deserialization of {operationAction.OperationActionTypeCase} is not supported."); + } + } + + /// + /// Converts a to . + /// + /// The operation action to convert. + /// The converted operation action. + [return: NotNullIfNotNull(nameof(operationAction))] + internal static P.OperationAction? ToOperationAction(this OperationAction? operationAction) + { + if (operationAction == null) + { + return null; + } + + var action = new P.OperationAction(); + + switch (operationAction) + { + case SendSignalOperationAction sendSignalAction: + + action.SendSignal = new P.SendSignalAction() + { + Name = sendSignalAction.Name, + Input = sendSignalAction.Input, + InstanceId = sendSignalAction.InstanceId, + ScheduledTime = sendSignalAction.ScheduledTime?.ToTimestamp(), + }; + break; + + case StartNewOrchestrationOperationAction startNewOrchestrationAction: + + action.StartNewOrchestration = new P.StartNewOrchestrationAction() + { + Name = startNewOrchestrationAction.Name, + Input = startNewOrchestrationAction.Input, + Version = startNewOrchestrationAction.Version, + InstanceId = startNewOrchestrationAction.InstanceId, + ScheduledTime = startNewOrchestrationAction.ScheduledStartTime?.ToTimestamp(), + }; + break; + } + + return action; + } + + /// + /// Converts a to a . + /// + /// The operation result to convert. + /// The converted operation result. + [return: NotNullIfNotNull(nameof(entityBatchResult))] + internal static EntityBatchResult? ToEntityBatchResult(this P.EntityBatchResult? entityBatchResult) + { + if (entityBatchResult == null) + { + return null; + } + + return new EntityBatchResult() + { + Actions = entityBatchResult.Actions.Select(operationAction => operationAction!.ToOperationAction()).ToList(), + EntityState = entityBatchResult.EntityState, + Results = entityBatchResult.Results.Select(operationResult => operationResult!.ToOperationResult()).ToList(), + FailureDetails = entityBatchResult.FailureDetails.ToCore(), + }; + } + + /// + /// Converts a to . + /// + /// The operation result to convert. + /// The completion token, or null for the older protocol. + /// Additional information about each operation, required by DTS. + /// The converted operation result. + [return: NotNullIfNotNull(nameof(entityBatchResult))] + internal static P.EntityBatchResult? ToEntityBatchResult( + this EntityBatchResult? entityBatchResult, + string? completionToken = null, + IEnumerable? operationInfos = null) + { + if (entityBatchResult == null) + { + return null; + } + + return new P.EntityBatchResult() + { + EntityState = entityBatchResult.EntityState, + FailureDetails = entityBatchResult.FailureDetails.ToProtobuf(), + Actions = { entityBatchResult.Actions?.Select(a => a.ToOperationAction()) ?? [] }, + Results = { entityBatchResult.Results?.Select(a => a.ToOperationResult()) ?? [] }, + CompletionToken = completionToken ?? string.Empty, + OperationInfos = { operationInfos ?? [] }, + }; + } + + /// + /// Converts the gRPC representation of orchestrator entity parameters to the DT.Core representation. + /// + /// The DT.Core representation. + /// The gRPC representation. + [return: NotNullIfNotNull(nameof(parameters))] + internal static TaskOrchestrationEntityParameters? ToCore(this P.OrchestratorEntityParameters? parameters) + { + if (parameters == null) + { + return null; + } + + return new TaskOrchestrationEntityParameters() + { + EntityMessageReorderWindow = parameters.EntityMessageReorderWindow.ToTimeSpan(), + }; + } + + /// + /// Gets the approximate byte count for a . + /// + /// The failure details. + /// The approximate byte count. + internal static int GetApproximateByteCount(this P.TaskFailureDetails failureDetails) + { + // Protobuf strings are always UTF-8: https://developers.google.com/protocol-buffers/docs/proto3#scalar + Encoding encoding = Encoding.UTF8; + + int byteCount = 0; + if (failureDetails.ErrorType != null) + { + byteCount += encoding.GetByteCount(failureDetails.ErrorType); + } + + if (failureDetails.ErrorMessage != null) + { + byteCount += encoding.GetByteCount(failureDetails.ErrorMessage); + } + + if (failureDetails.StackTrace != null) + { + byteCount += encoding.GetByteCount(failureDetails.StackTrace); + } + + if (failureDetails.InnerFailure != null) + { + byteCount += failureDetails.InnerFailure.GetApproximateByteCount(); + } + + return byteCount; + } + + /// + /// Decode a protobuf message from a base64 string. + /// + /// The type to decode to. + /// The message parser. + /// The base64 encoded message. + /// The decoded message. + /// If decoding fails. + internal static T Base64Decode(this MessageParser parser, string encodedMessage) where T : IMessage + { + // Decode the base64 in a way that doesn't allocate a byte[] on each request + int encodedByteCount = Encoding.UTF8.GetByteCount(encodedMessage); + byte[] buffer = ArrayPool.Shared.Rent(encodedByteCount); + try + { + // The Base64 APIs require first converting the string into UTF-8 bytes. We then + // do an in-place conversion from base64 UTF-8 bytes to protobuf bytes so that + // we can finally decode the protobuf request. + Encoding.UTF8.GetBytes(encodedMessage, 0, encodedMessage.Length, buffer, 0); + OperationStatus status = Base64.DecodeFromUtf8InPlace( + buffer.AsSpan(0, encodedByteCount), + out int bytesWritten); + if (status != OperationStatus.Done) + { + throw new ArgumentException( + $"Failed to base64-decode the '{typeof(T).Name}' payload: {status}", nameof(encodedMessage)); + } + + return (T)parser.ParseFrom(buffer, 0, bytesWritten); + } + finally + { + ArrayPool.Shared.Return(buffer); + } + } + + /// + /// Converts a grpc to a . + /// + /// The failure details to convert. + /// The converted failure details. + internal static FailureDetails? ToCore(this P.TaskFailureDetails? failureDetails) + { + if (failureDetails == null) + { + return null; + } + + return new FailureDetails( + failureDetails.ErrorType, + failureDetails.ErrorMessage, + failureDetails.StackTrace, + failureDetails.InnerFailure.ToCore(), + failureDetails.IsNonRetriable); + } + + /// + /// Converts a to a grpc . + /// + /// The failure details to convert. + /// The converted failure details. + static P.TaskFailureDetails? ToProtobuf(this FailureDetails? failureDetails) + { + if (failureDetails == null) + { + return null; + } + + return new P.TaskFailureDetails + { + ErrorType = failureDetails.ErrorType ?? "(unknown)", + ErrorMessage = failureDetails.ErrorMessage ?? "(unknown)", + StackTrace = failureDetails.StackTrace, + IsNonRetriable = failureDetails.IsNonRetriable, + InnerFailure = failureDetails.InnerFailure.ToProtobuf(), + }; + } + + static P.OrchestrationStatus ToProtobuf(this OrchestrationStatus status) + { + return (P.OrchestrationStatus)status; + } + + static P.OrchestrationInstance ToProtobuf(this OrchestrationInstance instance) + { + return new P.OrchestrationInstance + { + InstanceId = instance.InstanceId, + ExecutionId = instance.ExecutionId, + }; + } + + /// + /// Tracks state required for converting orchestration histories containing entity-related events. + /// + internal class EntityConversionState + { + readonly bool insertMissingEntityUnlocks; + + OrchestrationInstance? instance; + HashSet? entityRequestIds; + Dictionary? unlockObligations; + + /// + /// Initializes a new instance of the class. + /// + /// Whether to insert missing unlock events in to the history + /// when the orchestration completes. + public EntityConversionState(bool insertMissingEntityUnlocks) + { + this.ConvertFromProto = (P.HistoryEvent e) => ProtoUtils.ConvertHistoryEvent(e, this); + this.insertMissingEntityUnlocks = insertMissingEntityUnlocks; + } + + /// + /// Gets a function that converts a history event in protobuf format to a core history event. + /// + public Func ConvertFromProto { get; } + + /// + /// Gets the orchestration instance of this history. + /// + public OrchestrationInstance? CurrentInstance => this.instance; + + /// + /// Gets the set of guids that have been used as entity request ids in this history. + /// + public HashSet EntityRequestIds => this.entityRequestIds ??= new(); + + /// + /// Records the orchestration instance, which may be needed for some conversions. + /// + /// The orchestration instance. + public void SetOrchestrationInstance(OrchestrationInstance instance) + { + this.instance = instance; + } + + /// + /// Adds unlock obligations for all entities that are being locked by this request. + /// + /// The lock request. + public void AddUnlockObligations(P.EntityLockRequestedEvent request) + { + if (!this.insertMissingEntityUnlocks) + { + return; + } + + this.unlockObligations ??= new(); + + foreach (string target in request.LockSet) + { + this.unlockObligations[target] = request.CriticalSectionId; + } + } + + /// + /// Removes an unlock obligation. + /// + /// The target entity. + public void RemoveUnlockObligation(string target) + { + if (!this.insertMissingEntityUnlocks) + { + return; + } + + this.unlockObligations?.Remove(target); + } + + /// + /// Returns the remaining unlock obligations, and clears the list. + /// + /// The unlock obligations. + public IEnumerable<(string Target, string CriticalSectionId)> ResetObligations() + { + if (!this.insertMissingEntityUnlocks) + { + yield break; + } + + if (this.unlockObligations is not null) + { + foreach (var kvp in this.unlockObligations) + { + yield return (kvp.Key, kvp.Value); + } + + this.unlockObligations = null; + } + } + } +} diff --git a/test/Client/OrchestrationServiceClientShim.Tests/ShimDurableTaskClientTests.cs b/test/Client/OrchestrationServiceClientShim.Tests/ShimDurableTaskClientTests.cs index e8db1b574..cf67af96b 100644 --- a/test/Client/OrchestrationServiceClientShim.Tests/ShimDurableTaskClientTests.cs +++ b/test/Client/OrchestrationServiceClientShim.Tests/ShimDurableTaskClientTests.cs @@ -1,467 +1,483 @@ -// Copyright (c) Microsoft Corporation. -// Licensed under the MIT License. - -using DurableTask.Core; -using DurableTask.Core.Entities; -using DurableTask.Core.History; -using DurableTask.Core.Query; -using FluentAssertions.Specialized; -using Microsoft.DurableTask.Client.Entities; -using Microsoft.DurableTask.Converters; -using Microsoft.Extensions.Options; -using Core = DurableTask.Core; -using CoreOrchestrationQuery = DurableTask.Core.Query.OrchestrationQuery; -using PurgeInstanceFilter = Microsoft.DurableTask.Client.PurgeInstancesFilter; - -namespace Microsoft.DurableTask.Client.OrchestrationServiceClientShim.Tests; - -public class ShimDurableTaskClientTests -{ - readonly ShimDurableTaskClient client; - readonly Mock orchestrationClient = new(MockBehavior.Strict); - readonly Mock queryClient; - readonly Mock purgeClient; - - public ShimDurableTaskClientTests() - { - this.queryClient = this.orchestrationClient.As(); - this.purgeClient = this.orchestrationClient.As(); - this.client = new("test", new ShimDurableTaskClientOptions { Client = this.orchestrationClient.Object }); - } - - [Fact] - public void Ctor_NullOptions_Throws1() - { - IOptionsMonitor options = null!; - Func act = () => new ShimDurableTaskClient("test", options); - act.Should().ThrowExactly().WithParameterName("options"); - - options = Mock.Of>(); - act = () => new ShimDurableTaskClient("test", options); - act.Should().ThrowExactly().WithParameterName("options"); - } - - [Fact] - public void Ctor_NullOptions_Throws2() - { - IOptionsMonitor options = - Mock.Of>(); - Func act = () => new ShimDurableTaskClient("test", options); - act.Should().ThrowExactly().WithParameterName("options"); - } - - [Fact] - public void Ctor_NoEntitySupport_GetClientThrows() - { - IOrchestrationServiceClient client = Mock.Of(); - ShimDurableTaskClientOptions options = new() { Client = client }; - ShimDurableTaskClient shimClient = new("test", options); - - Func act = () => shimClient.Entities; - act.Should().ThrowExactly().WithMessage("Entity support is not enabled."); - } - - [Fact] - public void Ctor_InvalidEntityOptions_GetClientThrows() - { - IOrchestrationServiceClient client = Mock.Of(); - ShimDurableTaskClientOptions options = new() { Client = client, EnableEntitySupport = true }; - ShimDurableTaskClient shimClient = new("test", options); - - Func act = () => shimClient.Entities; - act.Should().ThrowExactly() - .WithMessage("The configured IOrchestrationServiceClient does not support entities."); - } - - [Fact] - public void Ctor_EntitiesConfigured_GetClientSuccess() - { - IOrchestrationServiceClient client = Mock.Of(); - EntityBackendQueries queries = Mock.Of(); - ShimDurableTaskClientOptions options = new() - { - Client = client, - EnableEntitySupport = true, - Entities = { Queries = queries }, - }; - - ShimDurableTaskClient shimClient = new("test", options); - DurableEntityClient entityClient = shimClient.Entities; - - entityClient.Should().BeOfType(); - } - - [Theory] - [InlineData(false)] - [InlineData(true)] - public async void GetInstanceMetadata_EmptyList_Null(bool isNull) - { - // arrange - List? states = isNull ? null : new(); - string instanceId = Guid.NewGuid().ToString(); - this.orchestrationClient.Setup(m => m.GetOrchestrationStateAsync(instanceId, false)).ReturnsAsync(states); - - // act - OrchestrationMetadata? result = await this.client.GetInstanceAsync(instanceId, false); - - // assert - result.Should().BeNull(); - } - - [Theory] - [InlineData(false)] - [InlineData(true)] - public async Task GetInstanceMetadata_Results(bool getInputs) - { - // arrange - List states = new() { CreateState("input") }; - string instanceId = states.First().OrchestrationInstance.InstanceId; - this.orchestrationClient.Setup(m => m.GetOrchestrationStateAsync(instanceId, false)).ReturnsAsync(states); - - // act - OrchestrationMetadata? result = await this.client.GetInstanceAsync(instanceId, getInputs); - - // assert - Validate(result, states.First(), getInputs); - } - - [Theory] - [InlineData(false)] - [InlineData(true)] - public async Task GetInstances_Results(bool getInputs) - { - // arrange - DateTimeOffset utcNow = DateTimeOffset.UtcNow; - List states = - [ - CreateState("input", start: utcNow.AddMinutes(-1)), - CreateState(10, "output", utcNow.AddMinutes(-5)), - ]; - - OrchestrationQueryResult queryResult = new(states, null); - string instanceId = states.First().OrchestrationInstance.InstanceId; - this.queryClient - .Setup(m => m.GetOrchestrationWithQueryAsync(It.IsNotNull(), default)) - .ReturnsAsync(queryResult); - - OrchestrationQuery query = new() - { - CreatedFrom = utcNow.AddMinutes(-10), - FetchInputsAndOutputs = getInputs, - }; - - // act - List result = await this.client.GetAllInstancesAsync(query).ToListAsync(); - - // assert - this.orchestrationClient.VerifyAll(); - foreach ((OrchestrationMetadata left, Core.OrchestrationState right) in result.Zip(states)) - { - Validate(left, right, getInputs); - } - } - - [Fact] - public async Task PurgeInstanceMetadata() - { - // arrange - string instanceId = Guid.NewGuid().ToString(); - this.purgeClient.Setup(m => m.PurgeInstanceStateAsync(instanceId)).ReturnsAsync(new Core.PurgeResult(1)); - - // act - PurgeResult result = await this.client.PurgeInstanceAsync(instanceId); - - // assert - this.orchestrationClient.VerifyAll(); - result.PurgedInstanceCount.Should().Be(1); - } - - [Fact] - public async Task PurgeInstances() - { - // arrange - PurgeInstanceFilter filter = new(CreatedTo: DateTimeOffset.UtcNow); - this.purgeClient.Setup(m => m.PurgeInstanceStateAsync(It.IsNotNull())) - .ReturnsAsync(new Core.PurgeResult(10)); - - // act - PurgeResult result = await this.client.PurgeAllInstancesAsync(filter); - - // assert - this.orchestrationClient.VerifyAll(); - result.PurgedInstanceCount.Should().Be(10); - } - - [Fact] - public async Task RaiseEvent() - { - // arrange - string instanceId = Guid.NewGuid().ToString(); - this.SetupClientTaskMessage(instanceId); - - // act - await this.client.RaiseEventAsync(instanceId, "test-event", null, default); - - // assert - this.orchestrationClient.VerifyAll(); - } - - [Fact] - public async Task SuspendInstance() - { - // arrange - string instanceId = Guid.NewGuid().ToString(); - this.SetupClientTaskMessage(instanceId); - - // act - await this.client.SuspendInstanceAsync(instanceId, null, default); - - // assert - this.orchestrationClient.VerifyAll(); - } - - [Fact] - public async Task ResumeInstance() - { - // arrange - string instanceId = Guid.NewGuid().ToString(); - this.SetupClientTaskMessage(instanceId); - - // act - await this.client.ResumeInstanceAsync(instanceId, null, default); - - // assert - this.orchestrationClient.VerifyAll(); - } - - [Fact] - public async Task TerminateInstance() - { - // arrange - string instanceId = Guid.NewGuid().ToString(); - this.orchestrationClient.Setup(m => m.ForceTerminateTaskOrchestrationAsync(instanceId, null)) - .Returns(Task.CompletedTask); - - // act - await this.client.TerminateInstanceAsync(instanceId, null, default); - - // assert - this.orchestrationClient.VerifyAll(); - } - - [Fact] - public async Task WaitForInstanceCompletion() - { - // arrange - Core.OrchestrationState state = CreateState("input", "output"); - this.orchestrationClient.Setup( - m => m.WaitForOrchestrationAsync(state.OrchestrationInstance.InstanceId, null, TimeSpan.MaxValue, default)) - .ReturnsAsync(state); - - // act - OrchestrationMetadata metadata = await this.client.WaitForInstanceCompletionAsync( - state.OrchestrationInstance.InstanceId, false, default); - - // assert - this.orchestrationClient.VerifyAll(); - Validate(metadata, state, false); - } - - [Fact] - public async Task WaitForInstanceStart() - { - // arrange - DateTimeOffset start = DateTimeOffset.UtcNow; - OrchestrationInstance instance = new() - { - InstanceId = Guid.NewGuid().ToString(), - ExecutionId = Guid.NewGuid().ToString(), - }; - - Core.OrchestrationState state1 = CreateState("input", start: start); - state1.OrchestrationInstance = instance; - state1.OrchestrationStatus = Core.OrchestrationStatus.Pending; - Core.OrchestrationState state2 = CreateState("input", start: start); - state1.OrchestrationInstance = instance; - this.orchestrationClient.SetupSequence(m => m.GetOrchestrationStateAsync(instance.InstanceId, false)) - .ReturnsAsync([state1]) - .ReturnsAsync([state2]); - - // act - OrchestrationMetadata metadata = await this.client.WaitForInstanceStartAsync( - instance.InstanceId, false, default); - - // assert - this.orchestrationClient.Verify( - m => m.GetOrchestrationStateAsync(instance.InstanceId, false), Times.Exactly(2)); - Validate(metadata, state2, false); - } - - [Fact] - public Task ScheduleNewOrchestrationInstance_IdGenerated_NoInput() - => this.RunScheduleNewOrchestrationInstanceAsync("test", null, null); - - [Fact] - public Task ScheduleNewOrchestrationInstance_IdProvided_InputProvided() - => this.RunScheduleNewOrchestrationInstanceAsync("test", "input", new() { InstanceId = "test-id" }); - - [Fact] - public Task ScheduleNewOrchestrationInstance_StartAt() - => this.RunScheduleNewOrchestrationInstanceAsync( - "test", null, new() { StartAt = DateTimeOffset.UtcNow.AddHours(1) }); - - static Core.OrchestrationState CreateState( - object input, object? output = null, DateTimeOffset start = default) - { - string? serializedOutput = null; - FailureDetails? failure = null; - OrchestrationStatus status = OrchestrationStatus.Running; - - if (start == default) - { - start = DateTimeOffset.UtcNow.AddMinutes(-10); - } - - switch (output) - { - case Exception ex: - status = OrchestrationStatus.Failed; - failure = new(ex.GetType().FullName!, ex.Message, ex.StackTrace, null, true); - break; - case not null: - status = OrchestrationStatus.Completed; - serializedOutput = JsonDataConverter.Default.Serialize(output); - break; - } - - return new() - { - CompletedTime = default, - CreatedTime = start.UtcDateTime, - LastUpdatedTime = start.AddMinutes(10).UtcDateTime, - Input = JsonDataConverter.Default.Serialize(input), - Output = serializedOutput, - FailureDetails = failure, - Name = "test-orchestration", - OrchestrationInstance = new() - { - InstanceId = Guid.NewGuid().ToString(), - ExecutionId = Guid.NewGuid().ToString(), - }, - OrchestrationStatus = status, - Status = JsonDataConverter.Default.Serialize("custom-status"), - Version = string.Empty, - }; - } - - static TaskMessage MatchStartExecutionMessage(TaskName name, object? input, StartOrchestrationOptions? options) - { - return Match.Create(m => - { - if (m.Event is not ExecutionStartedEvent @event) - { - return false; - } - - - if (options?.InstanceId is string str && m.OrchestrationInstance.InstanceId != str) - { - return false; - } - else if (options?.InstanceId is null && !Guid.TryParse(m.OrchestrationInstance.InstanceId, out _)) - { - return false; - } - - if (options?.StartAt is DateTimeOffset start && @event.ScheduledStartTime != start.UtcDateTime) - { - return false; - } - else if (options?.StartAt is null && @event.ScheduledStartTime is not null) - { - return false; - } - - return Guid.TryParse(m.OrchestrationInstance.ExecutionId, out _) - && @event.Name == name.Name && @event.Version == name.Version - && @event.OrchestrationInstance == m.OrchestrationInstance - && @event.EventId == -1 - && @event.Input == JsonDataConverter.Default.Serialize(input); - }); - } - - static void Validate(OrchestrationMetadata? metadata, Core.OrchestrationState? state, bool getInputs) - { - if (state is null) - { - metadata.Should().BeNull(); - return; - } - - metadata.Should().NotBeNull(); - metadata!.Name.Should().Be(state.Name); - metadata.InstanceId.Should().Be(state.OrchestrationInstance.InstanceId); - metadata.RuntimeStatus.Should().Be(state.OrchestrationStatus.ConvertFromCore()); - metadata.CreatedAt.Should().Be(new DateTimeOffset(state.CreatedTime)); - metadata.LastUpdatedAt.Should().Be(new DateTimeOffset(state.LastUpdatedTime)); - metadata.SerializedInput.Should().Be(state.Input); - metadata.SerializedOutput.Should().Be(state.Output); - metadata.SerializedCustomStatus.Should().Be(state.Status); - - if (getInputs) - { - metadata.Invoking(m => m.ReadInputAs()).Should().NotThrow(); - } - } - - static void Validate(TaskFailureDetails? left, FailureDetails? right) - { - if (right is null) - { - left.Should().BeNull(); - return; - } - - left.Should().NotBeNull(); - left!.ErrorType.Should().Be(right.ErrorType); - left.ErrorMessage.Should().Be(right.ErrorMessage); - left.StackTrace.Should().Be(right.StackTrace); - Validate(left.InnerFailure, right.InnerFailure); - } - - void SetupClientTaskMessage(string instanceId) - where TEvent : HistoryEvent - { - this.orchestrationClient - .Setup(m => m.SendTaskOrchestrationMessageAsync(It.Is(m => - m.OrchestrationInstance.InstanceId == instanceId && m.Event.GetType() == typeof(TEvent)) - )) - .Returns(Task.CompletedTask); - } - - async Task RunScheduleNewOrchestrationInstanceAsync( - TaskName name, object? input, StartOrchestrationOptions? options) - { - // arrange - this.orchestrationClient.Setup( - m => m.CreateTaskOrchestrationAsync(MatchStartExecutionMessage(name, input, options))) - .Returns(Task.CompletedTask); - - // act - string instanceId = await this.client.ScheduleNewOrchestrationInstanceAsync(name, input, options, default); - - // assert - this.orchestrationClient.Verify( - m => m.CreateTaskOrchestrationAsync(MatchStartExecutionMessage(name, input, options)), - Times.Once()); - - if (options?.InstanceId is string str) - { - instanceId.Should().Be(str); - } - else - { - Guid.TryParse(instanceId, out _).Should().BeTrue(); - } - } +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using DurableTask.Core; +using DurableTask.Core.Entities; +using DurableTask.Core.History; +using DurableTask.Core.Query; +using FluentAssertions.Specialized; +using Microsoft.DurableTask.Client.Entities; +using Microsoft.DurableTask.Converters; +using Microsoft.Extensions.Options; +using Core = DurableTask.Core; +using CoreOrchestrationQuery = DurableTask.Core.Query.OrchestrationQuery; +using PurgeInstanceFilter = Microsoft.DurableTask.Client.PurgeInstancesFilter; + +namespace Microsoft.DurableTask.Client.OrchestrationServiceClientShim.Tests; + +public class ShimDurableTaskClientTests +{ + readonly ShimDurableTaskClient client; + readonly Mock orchestrationClient = new(MockBehavior.Strict); + readonly Mock queryClient; + readonly Mock purgeClient; + + public ShimDurableTaskClientTests() + { + this.queryClient = this.orchestrationClient.As(); + this.purgeClient = this.orchestrationClient.As(); + this.client = new("test", new ShimDurableTaskClientOptions { Client = this.orchestrationClient.Object }); + } + + [Fact] + public void Ctor_NullOptions_Throws1() + { + IOptionsMonitor options = null!; + Func act = () => new ShimDurableTaskClient("test", options); + act.Should().ThrowExactly().WithParameterName("options"); + + options = Mock.Of>(); + act = () => new ShimDurableTaskClient("test", options); + act.Should().ThrowExactly().WithParameterName("options"); + } + + [Fact] + public void Ctor_NullOptions_Throws2() + { + IOptionsMonitor options = + Mock.Of>(); + Func act = () => new ShimDurableTaskClient("test", options); + act.Should().ThrowExactly().WithParameterName("options"); + } + + [Fact] + public void Ctor_NoEntitySupport_GetClientThrows() + { + IOrchestrationServiceClient client = Mock.Of(); + ShimDurableTaskClientOptions options = new() { Client = client }; + ShimDurableTaskClient shimClient = new("test", options); + + Func act = () => shimClient.Entities; + act.Should().ThrowExactly().WithMessage("Entity support is not enabled."); + } + + [Fact] + public void Ctor_InvalidEntityOptions_GetClientThrows() + { + IOrchestrationServiceClient client = Mock.Of(); + ShimDurableTaskClientOptions options = new() { Client = client, EnableEntitySupport = true }; + ShimDurableTaskClient shimClient = new("test", options); + + Func act = () => shimClient.Entities; + act.Should().ThrowExactly() + .WithMessage("The configured IOrchestrationServiceClient does not support entities."); + } + + [Fact] + public void Ctor_EntitiesConfigured_GetClientSuccess() + { + IOrchestrationServiceClient client = Mock.Of(); + EntityBackendQueries queries = Mock.Of(); + ShimDurableTaskClientOptions options = new() + { + Client = client, + EnableEntitySupport = true, + Entities = { Queries = queries }, + }; + + ShimDurableTaskClient shimClient = new("test", options); + DurableEntityClient entityClient = shimClient.Entities; + + entityClient.Should().BeOfType(); + } + + [Theory] + [InlineData(false)] + [InlineData(true)] + public async void GetInstanceMetadata_EmptyList_Null(bool isNull) + { + // arrange + List? states = isNull ? null : new(); + string instanceId = Guid.NewGuid().ToString(); + this.orchestrationClient.Setup(m => m.GetOrchestrationStateAsync(instanceId, false)).ReturnsAsync(states); + + // act + OrchestrationMetadata? result = await this.client.GetInstanceAsync(instanceId, false); + + // assert + result.Should().BeNull(); + } + + [Theory] + [InlineData(false)] + [InlineData(true)] + public async Task GetInstanceMetadata_Results(bool getInputs) + { + // arrange + List states = new() { CreateState("input") }; + string instanceId = states.First().OrchestrationInstance.InstanceId; + this.orchestrationClient.Setup(m => m.GetOrchestrationStateAsync(instanceId, false)).ReturnsAsync(states); + + // act + OrchestrationMetadata? result = await this.client.GetInstanceAsync(instanceId, getInputs); + + // assert + Validate(result, states.First(), getInputs); + } + + [Theory] + [InlineData(false)] + [InlineData(true)] + public async Task GetInstances_Results(bool getInputs) + { + // arrange + DateTimeOffset utcNow = DateTimeOffset.UtcNow; + List states = + [ + CreateState("input", start: utcNow.AddMinutes(-1)), + CreateState(10, "output", utcNow.AddMinutes(-5)), + ]; + + OrchestrationQueryResult queryResult = new(states, null); + string instanceId = states.First().OrchestrationInstance.InstanceId; + this.queryClient + .Setup(m => m.GetOrchestrationWithQueryAsync(It.IsNotNull(), default)) + .ReturnsAsync(queryResult); + + OrchestrationQuery query = new() + { + CreatedFrom = utcNow.AddMinutes(-10), + FetchInputsAndOutputs = getInputs, + }; + + // act + List result = await this.client.GetAllInstancesAsync(query).ToListAsync(); + + // assert + this.orchestrationClient.VerifyAll(); + foreach ((OrchestrationMetadata left, Core.OrchestrationState right) in result.Zip(states)) + { + Validate(left, right, getInputs); + } + } + + [Fact] + public async Task PurgeInstanceMetadata() + { + // arrange + string instanceId = Guid.NewGuid().ToString(); + this.purgeClient.Setup(m => m.PurgeInstanceStateAsync(instanceId)).ReturnsAsync(new Core.PurgeResult(1)); + + // act + PurgeResult result = await this.client.PurgeInstanceAsync(instanceId); + + // assert + this.orchestrationClient.VerifyAll(); + result.PurgedInstanceCount.Should().Be(1); + } + + [Fact] + public async Task PurgeInstances() + { + // arrange + PurgeInstanceFilter filter = new(CreatedTo: DateTimeOffset.UtcNow); + this.purgeClient.Setup(m => m.PurgeInstanceStateAsync(It.IsNotNull())) + .ReturnsAsync(new Core.PurgeResult(10)); + + // act + PurgeResult result = await this.client.PurgeAllInstancesAsync(filter); + + // assert + this.orchestrationClient.VerifyAll(); + result.PurgedInstanceCount.Should().Be(10); + } + + [Fact] + public async Task RaiseEvent() + { + // arrange + string instanceId = Guid.NewGuid().ToString(); + this.SetupClientTaskMessage(instanceId); + + // act + await this.client.RaiseEventAsync(instanceId, "test-event", null, default); + + // assert + this.orchestrationClient.VerifyAll(); + } + + [Fact] + public async Task SuspendInstance() + { + // arrange + string instanceId = Guid.NewGuid().ToString(); + this.SetupClientTaskMessage(instanceId); + + // act + await this.client.SuspendInstanceAsync(instanceId, null, default); + + // assert + this.orchestrationClient.VerifyAll(); + } + + [Fact] + public async Task ResumeInstance() + { + // arrange + string instanceId = Guid.NewGuid().ToString(); + this.SetupClientTaskMessage(instanceId); + + // act + await this.client.ResumeInstanceAsync(instanceId, null, default); + + // assert + this.orchestrationClient.VerifyAll(); + } + + [Fact] + public async Task TerminateInstance() + { + // arrange + string instanceId = Guid.NewGuid().ToString(); + this.orchestrationClient.Setup(m => m.ForceTerminateTaskOrchestrationAsync(instanceId, null)) + .Returns(Task.CompletedTask); + + // act + await this.client.TerminateInstanceAsync(instanceId, null, default); + + // assert + this.orchestrationClient.VerifyAll(); + } + + [Fact] + public async Task WaitForInstanceCompletion() + { + // arrange + Core.OrchestrationState state = CreateState("input", "output"); + this.orchestrationClient.Setup( + m => m.WaitForOrchestrationAsync(state.OrchestrationInstance.InstanceId, null, TimeSpan.MaxValue, default)) + .ReturnsAsync(state); + + // act + OrchestrationMetadata metadata = await this.client.WaitForInstanceCompletionAsync( + state.OrchestrationInstance.InstanceId, false, default); + + // assert + this.orchestrationClient.VerifyAll(); + Validate(metadata, state, false); + } + + [Fact] + public async Task WaitForInstanceStart() + { + // arrange + DateTimeOffset start = DateTimeOffset.UtcNow; + OrchestrationInstance instance = new() + { + InstanceId = Guid.NewGuid().ToString(), + ExecutionId = Guid.NewGuid().ToString(), + }; + + Core.OrchestrationState state1 = CreateState("input", start: start); + state1.OrchestrationInstance = instance; + state1.OrchestrationStatus = Core.OrchestrationStatus.Pending; + Core.OrchestrationState state2 = CreateState("input", start: start); + state1.OrchestrationInstance = instance; + this.orchestrationClient.SetupSequence(m => m.GetOrchestrationStateAsync(instance.InstanceId, false)) + .ReturnsAsync([state1]) + .ReturnsAsync([state2]); + + // act + OrchestrationMetadata metadata = await this.client.WaitForInstanceStartAsync( + instance.InstanceId, false, default); + + // assert + this.orchestrationClient.Verify( + m => m.GetOrchestrationStateAsync(instance.InstanceId, false), Times.Exactly(2)); + Validate(metadata, state2, false); + } + + [Fact] + public Task ScheduleNewOrchestrationInstance_IdGenerated_NoInput() + => this.RunScheduleNewOrchestrationInstanceAsync("test", null, null); + + [Fact] + public Task ScheduleNewOrchestrationInstance_IdProvided_InputProvided() + => this.RunScheduleNewOrchestrationInstanceAsync("test", "input", new() { InstanceId = "test-id" }); + + [Fact] + public Task ScheduleNewOrchestrationInstance_StartAt() + => this.RunScheduleNewOrchestrationInstanceAsync( + "test", null, new() { StartAt = DateTimeOffset.UtcNow.AddHours(1) }); + + [Fact] + public async Task ScheduleNewOrchestrationInstance_IdProvided_TagsProvided() + { + StartOrchestrationOptions options = new() + { + InstanceId = "test-id", + Tags = new Dictionary + { + { "tag1", "value1" }, + { "tag2", "value2" } + } + }; + await this.RunScheduleNewOrchestrationInstanceAsync("test", "input", options); + } + + + static Core.OrchestrationState CreateState( + object input, object? output = null, DateTimeOffset start = default) + { + string? serializedOutput = null; + FailureDetails? failure = null; + OrchestrationStatus status = OrchestrationStatus.Running; + + if (start == default) + { + start = DateTimeOffset.UtcNow.AddMinutes(-10); + } + + switch (output) + { + case Exception ex: + status = OrchestrationStatus.Failed; + failure = new(ex.GetType().FullName!, ex.Message, ex.StackTrace, null, true); + break; + case not null: + status = OrchestrationStatus.Completed; + serializedOutput = JsonDataConverter.Default.Serialize(output); + break; + } + + return new() + { + CompletedTime = default, + CreatedTime = start.UtcDateTime, + LastUpdatedTime = start.AddMinutes(10).UtcDateTime, + Input = JsonDataConverter.Default.Serialize(input), + Output = serializedOutput, + FailureDetails = failure, + Name = "test-orchestration", + OrchestrationInstance = new() + { + InstanceId = Guid.NewGuid().ToString(), + ExecutionId = Guid.NewGuid().ToString(), + }, + OrchestrationStatus = status, + Status = JsonDataConverter.Default.Serialize("custom-status"), + Version = string.Empty, + }; + } + + static TaskMessage MatchStartExecutionMessage(TaskName name, object? input, StartOrchestrationOptions? options) + { + return Match.Create(m => + { + if (m.Event is not ExecutionStartedEvent @event) + { + return false; + } + + + if (options?.InstanceId is string str && m.OrchestrationInstance.InstanceId != str) + { + return false; + } + else if (options?.InstanceId is null && !Guid.TryParse(m.OrchestrationInstance.InstanceId, out _)) + { + return false; + } + + if (options?.StartAt is DateTimeOffset start && @event.ScheduledStartTime != start.UtcDateTime) + { + return false; + } + else if (options?.StartAt is null && @event.ScheduledStartTime is not null) + { + return false; + } + + return Guid.TryParse(m.OrchestrationInstance.ExecutionId, out _) + && @event.Name == name.Name && @event.Version == name.Version + && @event.OrchestrationInstance == m.OrchestrationInstance + && @event.EventId == -1 + && @event.Input == JsonDataConverter.Default.Serialize(input); + }); + } + + static void Validate(OrchestrationMetadata? metadata, Core.OrchestrationState? state, bool getInputs) + { + if (state is null) + { + metadata.Should().BeNull(); + return; + } + + metadata.Should().NotBeNull(); + metadata!.Name.Should().Be(state.Name); + metadata.InstanceId.Should().Be(state.OrchestrationInstance.InstanceId); + metadata.RuntimeStatus.Should().Be(state.OrchestrationStatus.ConvertFromCore()); + metadata.CreatedAt.Should().Be(new DateTimeOffset(state.CreatedTime)); + metadata.LastUpdatedAt.Should().Be(new DateTimeOffset(state.LastUpdatedTime)); + metadata.SerializedInput.Should().Be(state.Input); + metadata.SerializedOutput.Should().Be(state.Output); + metadata.SerializedCustomStatus.Should().Be(state.Status); + + if (getInputs) + { + metadata.Invoking(m => m.ReadInputAs()).Should().NotThrow(); + } + } + + static void Validate(TaskFailureDetails? left, FailureDetails? right) + { + if (right is null) + { + left.Should().BeNull(); + return; + } + + left.Should().NotBeNull(); + left!.ErrorType.Should().Be(right.ErrorType); + left.ErrorMessage.Should().Be(right.ErrorMessage); + left.StackTrace.Should().Be(right.StackTrace); + Validate(left.InnerFailure, right.InnerFailure); + } + + void SetupClientTaskMessage(string instanceId) + where TEvent : HistoryEvent + { + this.orchestrationClient + .Setup(m => m.SendTaskOrchestrationMessageAsync(It.Is(m => + m.OrchestrationInstance.InstanceId == instanceId && m.Event.GetType() == typeof(TEvent)) + )) + .Returns(Task.CompletedTask); + } + + async Task RunScheduleNewOrchestrationInstanceAsync( + TaskName name, object? input, StartOrchestrationOptions? options) + { + // arrange + this.orchestrationClient.Setup( + m => m.CreateTaskOrchestrationAsync(MatchStartExecutionMessage(name, input, options))) + .Returns(Task.CompletedTask); + + // act + string instanceId = await this.client.ScheduleNewOrchestrationInstanceAsync(name, input, options, default); + + // assert + this.orchestrationClient.Verify( + m => m.CreateTaskOrchestrationAsync(MatchStartExecutionMessage(name, input, options)), + Times.Once()); + + if (options?.InstanceId is string str) + { + instanceId.Should().Be(str); + } + else + { + Guid.TryParse(instanceId, out _).Should().BeTrue(); + } + } } \ No newline at end of file diff --git a/test/Grpc.IntegrationTests/Grpc.IntegrationTests.csproj b/test/Grpc.IntegrationTests/Grpc.IntegrationTests.csproj index 2669b3730..e6b0aee76 100644 --- a/test/Grpc.IntegrationTests/Grpc.IntegrationTests.csproj +++ b/test/Grpc.IntegrationTests/Grpc.IntegrationTests.csproj @@ -11,7 +11,6 @@ - diff --git a/test/Grpc.IntegrationTests/GrpcSidecar/AsyncManualResetEvent.cs b/test/Grpc.IntegrationTests/GrpcSidecar/AsyncManualResetEvent.cs new file mode 100644 index 000000000..b7cf6dc4c --- /dev/null +++ b/test/Grpc.IntegrationTests/GrpcSidecar/AsyncManualResetEvent.cs @@ -0,0 +1,58 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +namespace Microsoft.DurableTask.Sidecar; + +class AsyncManualResetEvent +{ + readonly object mutex = new(); + TaskCompletionSource tcs = new(TaskCreationOptions.RunContinuationsAsynchronously); + + public AsyncManualResetEvent(bool isSignaled) + { + if (isSignaled) + { + this.tcs.TrySetCanceled(); + } + } + + public async Task WaitAsync(TimeSpan timeout, CancellationToken cancellationToken) + { + Task delayTask = Task.Delay(timeout, cancellationToken); + Task waitTask = this.tcs.Task; + + Task winner = await Task.WhenAny(waitTask, delayTask); + + // Await ensures we get a TaskCancelledException if there was a cancellation. + await winner; + + return winner == waitTask; + } + + public bool IsSignaled => this.tcs.Task.IsCompleted; + + /// + /// Puts the event in the signaled state, unblocking any waiting threads. + /// + public bool Set() + { + lock (this.mutex) + { + return this.tcs.TrySetResult(); + } + } + + /// + /// Puts this event into the unsignaled state, causing threads to block. + /// + public void Reset() + { + lock (this.mutex) + { + if (this.tcs.Task.IsCompleted) + { + this.tcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + } + } + } +} diff --git a/test/Grpc.IntegrationTests/GrpcSidecar/Dispatcher/ITaskExecutor.cs b/test/Grpc.IntegrationTests/GrpcSidecar/Dispatcher/ITaskExecutor.cs new file mode 100644 index 000000000..535b95021 --- /dev/null +++ b/test/Grpc.IntegrationTests/GrpcSidecar/Dispatcher/ITaskExecutor.cs @@ -0,0 +1,35 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using DurableTask.Core; +using DurableTask.Core.History; + +namespace Microsoft.DurableTask.Sidecar.Dispatcher; + +interface ITaskExecutor +{ + /// + /// When implemented by a concrete type, executes an orchestrator and returns the next set of orchestrator actions. + /// + /// The instance ID information of the orchestrator to execute. + /// The history events for previous executions of this orchestration instance. + /// The history events that have not yet been executed by this orchestration instance. + /// + /// Returns a task containing the result of the orchestrator execution. These are effectively the side-effects of the + /// orchestrator code, such as calling activities, scheduling timers, etc. + /// + Task ExecuteOrchestrator( + OrchestrationInstance instance, + IEnumerable pastEvents, + IEnumerable newEvents); + + /// + /// When implemented by a concreate type, executes an activity task and returns its results. + /// + /// The instance ID information of the orchestration that scheduled this activity task. + /// The metadata of the activity task execution, including the activity name and input. + /// Returns a task that contains the execution result of the activity. + Task ExecuteActivity( + OrchestrationInstance instance, + TaskScheduledEvent activityEvent); +} diff --git a/test/Grpc.IntegrationTests/GrpcSidecar/Dispatcher/ITrafficSignal.cs b/test/Grpc.IntegrationTests/GrpcSidecar/Dispatcher/ITrafficSignal.cs new file mode 100644 index 000000000..fa6a092b7 --- /dev/null +++ b/test/Grpc.IntegrationTests/GrpcSidecar/Dispatcher/ITrafficSignal.cs @@ -0,0 +1,29 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +namespace Microsoft.DurableTask.Sidecar.Dispatcher; + +/// +/// A simple primitive that can be used to block logical threads until some condition occurs. +/// +interface ITrafficSignal +{ + /// + /// Provides a human-friendly reason for why the signal is in the "wait" state. + /// + string WaitReason { get; } + + /// + /// Blocks the caller until the method is called. + /// + /// The amount of time to wait until the signal is unblocked. + /// A cancellation token that can be used to interrupt a waiting caller. + /// + /// Returns true if the traffic signal is all-clear; false if we timed-out waiting for the signal to clear. + /// + /// + /// Thrown if is triggered while waiting. + /// + Task WaitAsync(TimeSpan waitTime, CancellationToken cancellationToken); +} + diff --git a/test/Grpc.IntegrationTests/GrpcSidecar/Dispatcher/TaskActivityDispatcher.cs b/test/Grpc.IntegrationTests/GrpcSidecar/Dispatcher/TaskActivityDispatcher.cs new file mode 100644 index 000000000..282d6a9c9 --- /dev/null +++ b/test/Grpc.IntegrationTests/GrpcSidecar/Dispatcher/TaskActivityDispatcher.cs @@ -0,0 +1,58 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using DurableTask.Core; +using DurableTask.Core.History; +using Microsoft.Extensions.Logging; + +namespace Microsoft.DurableTask.Sidecar.Dispatcher; + +class TaskActivityDispatcher : WorkItemDispatcher +{ + readonly IOrchestrationService service; + readonly ITaskExecutor taskExecutor; + + public TaskActivityDispatcher(ILogger log, ITrafficSignal trafficSignal, IOrchestrationService service, ITaskExecutor taskExecutor) + : base(log, trafficSignal) + { + this.service = service; + this.taskExecutor = taskExecutor; + } + + public override int MaxWorkItems => this.service.MaxConcurrentTaskActivityWorkItems; + + public override Task AbandonWorkItemAsync(TaskActivityWorkItem workItem) => + this.service.AbandonTaskActivityWorkItemAsync(workItem); + + public override Task FetchWorkItemAsync(TimeSpan timeout, CancellationToken cancellationToken) => + this.service.LockNextTaskActivityWorkItem(timeout, cancellationToken); + + protected override async Task ExecuteWorkItemAsync(TaskActivityWorkItem workItem) + { + TaskScheduledEvent scheduledEvent = (TaskScheduledEvent)workItem.TaskMessage.Event; + + // TODO: Error handling for internal errors (user code exceptions are handled by the executor). + ActivityExecutionResult result = await this.taskExecutor.ExecuteActivity( + instance: workItem.TaskMessage.OrchestrationInstance, + activityEvent: scheduledEvent); + + TaskMessage responseMessage = new() + { + Event = result.ResponseEvent, + OrchestrationInstance = workItem.TaskMessage.OrchestrationInstance, + }; + + await this.service.CompleteTaskActivityWorkItemAsync(workItem, responseMessage); + } + + public override int GetDelayInSecondsOnFetchException(Exception ex) => + this.service.GetDelayInSecondsAfterOnFetchException(ex); + + public override string GetWorkItemId(TaskActivityWorkItem workItem) => workItem.Id; + + // No-op + public override Task ReleaseWorkItemAsync(TaskActivityWorkItem workItem) => Task.CompletedTask; + + public override Task RenewWorkItemAsync(TaskActivityWorkItem workItem) => + this.service.RenewTaskActivityWorkItemLockAsync(workItem); +} \ No newline at end of file diff --git a/test/Grpc.IntegrationTests/GrpcSidecar/Dispatcher/TaskHubDispatcherHost.cs b/test/Grpc.IntegrationTests/GrpcSidecar/Dispatcher/TaskHubDispatcherHost.cs new file mode 100644 index 000000000..89a54d025 --- /dev/null +++ b/test/Grpc.IntegrationTests/GrpcSidecar/Dispatcher/TaskHubDispatcherHost.cs @@ -0,0 +1,51 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using DurableTask.Core; +using Microsoft.Extensions.Logging; + +namespace Microsoft.DurableTask.Sidecar.Dispatcher; + +class TaskHubDispatcherHost +{ + readonly TaskOrchestrationDispatcher orchestrationDispatcher; + readonly TaskActivityDispatcher activityDispatcher; + + readonly IOrchestrationService orchestrationService; + readonly ILogger log; + + public TaskHubDispatcherHost( + ILoggerFactory loggerFactory, + ITrafficSignal trafficSignal, + IOrchestrationService orchestrationService, + ITaskExecutor taskExecutor) + { + this.orchestrationService = orchestrationService ?? throw new ArgumentNullException(nameof(orchestrationService)); + this.log = loggerFactory.CreateLogger("Microsoft.DurableTask.Sidecar"); + + this.orchestrationDispatcher = new TaskOrchestrationDispatcher(log, trafficSignal, orchestrationService, taskExecutor); + this.activityDispatcher = new TaskActivityDispatcher(log, trafficSignal, orchestrationService, taskExecutor); + } + + public async Task StartAsync(CancellationToken cancellationToken) + { + // Start any background processing in the orchestration service + await this.orchestrationService.StartAsync(); + + // Start the dispatchers, which will allow orchestrations/activities to execute + await Task.WhenAll( + this.orchestrationDispatcher.StartAsync(cancellationToken), + this.activityDispatcher.StartAsync(cancellationToken)); + } + + public async Task StopAsync(CancellationToken cancellationToken) + { + // Stop the dispatchers from polling the orchestration service + await Task.WhenAll( + this.orchestrationDispatcher.StopAsync(cancellationToken), + this.activityDispatcher.StopAsync(cancellationToken)); + + // Tell the storage provider to stop doing any background work. + await this.orchestrationService.StopAsync(); + } +} diff --git a/test/Grpc.IntegrationTests/GrpcSidecar/Dispatcher/TaskOrchestrationDispatcher.cs b/test/Grpc.IntegrationTests/GrpcSidecar/Dispatcher/TaskOrchestrationDispatcher.cs new file mode 100644 index 000000000..20cae03e3 --- /dev/null +++ b/test/Grpc.IntegrationTests/GrpcSidecar/Dispatcher/TaskOrchestrationDispatcher.cs @@ -0,0 +1,433 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using System.Text; +using DurableTask.Core; +using DurableTask.Core.Command; +using DurableTask.Core.History; +using Microsoft.Extensions.Logging; + +namespace Microsoft.DurableTask.Sidecar.Dispatcher; + +class TaskOrchestrationDispatcher : WorkItemDispatcher +{ + readonly ILogger log; + readonly IOrchestrationService service; + readonly ITaskExecutor taskExecutor; + + public TaskOrchestrationDispatcher(ILogger log, ITrafficSignal trafficSignal, IOrchestrationService service, ITaskExecutor taskExecutor) + : base(log, trafficSignal) + { + this.log = log; + this.service = service; + this.taskExecutor = taskExecutor; + } + + public override int MaxWorkItems => this.service.MaxConcurrentTaskOrchestrationWorkItems; + + public override Task AbandonWorkItemAsync(TaskOrchestrationWorkItem workItem) => + this.service.AbandonTaskOrchestrationWorkItemAsync(workItem); + + public override Task FetchWorkItemAsync(TimeSpan timeout, CancellationToken cancellationToken) => + this.service.LockNextTaskOrchestrationWorkItemAsync(timeout, cancellationToken); + + protected override async Task ExecuteWorkItemAsync(TaskOrchestrationWorkItem workItem) + { + // Convert the new messages into new history events + workItem.OrchestrationRuntimeState.AddEvent(new OrchestratorStartedEvent(-1)); + foreach (TaskMessage message in FilterAndSortMessages(workItem)) + { + workItem.OrchestrationRuntimeState.AddEvent(message.Event); + } + + OrchestrationInstance? instance = workItem.OrchestrationRuntimeState.OrchestrationInstance; + if (string.IsNullOrEmpty(instance?.InstanceId)) + { + throw new ArgumentException($"Could not find an orchestration instance ID in the work item's runtime state.", nameof(workItem)); + } + + // We loop for as long as the orchestrator does a ContinueAsNew + while (true) + { + if (this.log.IsEnabled(LogLevel.Debug)) + { + IList newEvents = workItem.OrchestrationRuntimeState.NewEvents; + string newEventSummary = GetEventSummaryForLogging(newEvents); + this.log.OrchestratorExecuting( + workItem.InstanceId, + workItem.OrchestrationRuntimeState.Name, + newEvents.Count, + newEventSummary); + } + + // Execute the orchestrator code and get back a set of new actions to take. + // IMPORTANT: This IEnumerable may be lazily evaluated and should only be enumerated once! + OrchestratorExecutionResult result = await this.taskExecutor.ExecuteOrchestrator( + instance, + workItem.OrchestrationRuntimeState.PastEvents, + workItem.OrchestrationRuntimeState.NewEvents); + + // Convert the actions into history events and messages. + // If the actions result in a continue-as-new state, + this.ApplyOrchestratorActions( + result, + ref workItem.OrchestrationRuntimeState, + out IList activityMessages, + out IList orchestratorMessages, + out IList timerMessages, + out OrchestrationState? updatedStatus, + out bool continueAsNew); + if (continueAsNew) + { + // Continue running the orchestration with a new history. + // Renew the lock if we're getting close to its expiration. + if (workItem.LockedUntilUtc != default && DateTime.UtcNow.AddMinutes(1) > workItem.LockedUntilUtc) + { + await this.service.RenewTaskOrchestrationWorkItemLockAsync(workItem); + } + + continue; + } + + // Commit the changes to the durable store + await this.service.CompleteTaskOrchestrationWorkItemAsync( + workItem, + workItem.OrchestrationRuntimeState, + activityMessages, + orchestratorMessages, + timerMessages, + continuedAsNewMessage: null /* not supported */, + updatedStatus); + + break; + } + } + + static string GetEventSummaryForLogging(IList actions) + { + if (actions.Count == 0) + { + return string.Empty; + } + else if (actions.Count == 1) + { + return actions[0].EventType.ToString(); + } + else + { + // Returns something like "TaskCompleted x5, TimerFired x1,..." + return string.Join(", ", actions + .GroupBy(a => a.EventType) + .Select(group => $"{group.Key} x{group.Count()}")); + } + } + + static IEnumerable FilterAndSortMessages(TaskOrchestrationWorkItem workItem) + { + // Group messages by their instance ID + static string GetGroupingKey(TaskMessage msg) => msg.OrchestrationInstance.InstanceId; + + // Within a group, put messages with a non-null execution ID first + static int GetSortOrderWithinGroup(TaskMessage msg) + { + if (msg.Event.EventType == EventType.ExecutionStarted) + { + // Prioritize ExecutionStarted messages + return 0; + } + else if (msg.OrchestrationInstance.ExecutionId != null) + { + // Prioritize messages with an execution ID + return 1; + } + else + { + return 2; + } + } + + string? executionId = workItem.OrchestrationRuntimeState?.OrchestrationInstance?.ExecutionId; + + foreach (var group in workItem.NewMessages.GroupBy(GetGroupingKey)) + { + // TODO: Filter out invalid messages (wrong execution ID, duplicate start/complete messages, etc.) + foreach (TaskMessage msg in group.OrderBy(GetSortOrderWithinGroup)) + { + yield return msg; + } + } + } + + void ApplyOrchestratorActions( + OrchestratorExecutionResult result, + ref OrchestrationRuntimeState runtimeState, + out IList activityMessages, + out IList orchestratorMessages, + out IList timerMessages, + out OrchestrationState? updatedStatus, + out bool continueAsNew) + { + if (string.IsNullOrEmpty(runtimeState.OrchestrationInstance?.InstanceId)) + { + throw new ArgumentException($"The provided {nameof(OrchestrationRuntimeState)} doesn't contain an instance ID!", nameof(runtimeState)); + } + + IList? newActivityMessages = null; + IList? newTimerMessages = null; + IList? newOrchestratorMessages = null; + FailureDetails? failureDetails = null; + continueAsNew = false; + + runtimeState.Status = result.CustomStatus; + + foreach (OrchestratorAction action in result.Actions) + { + // TODO: Determine how to handle remaining actions if the instance completed with ContinueAsNew. + // TODO: Validate each of these actions to make sure they have the appropriate data. + if (action is ScheduleTaskOrchestratorAction scheduleTaskAction) + { + if (string.IsNullOrEmpty(scheduleTaskAction.Name)) + { + throw new ArgumentException($"The provided {nameof(ScheduleTaskOrchestratorAction)} has no Name property specified!", nameof(result)); + } + + TaskScheduledEvent scheduledEvent = new( + scheduleTaskAction.Id, + scheduleTaskAction.Name, + scheduleTaskAction.Version, + scheduleTaskAction.Input); + + newActivityMessages ??= new List(); + newActivityMessages.Add(new TaskMessage + { + Event = scheduledEvent, + OrchestrationInstance = runtimeState.OrchestrationInstance, + }); + + runtimeState.AddEvent(scheduledEvent); + } + else if (action is CreateTimerOrchestratorAction timerAction) + { + TimerCreatedEvent timerEvent = new(timerAction.Id, timerAction.FireAt); + + newTimerMessages ??= new List(); + newTimerMessages.Add(new TaskMessage + { + Event = new TimerFiredEvent(-1, timerAction.FireAt) + { + TimerId = timerAction.Id, + }, + OrchestrationInstance = runtimeState.OrchestrationInstance, + }); + + runtimeState.AddEvent(timerEvent); + } + else if (action is CreateSubOrchestrationAction subOrchestrationAction) + { + runtimeState.AddEvent(new SubOrchestrationInstanceCreatedEvent(subOrchestrationAction.Id) + { + Name = subOrchestrationAction.Name, + Version = subOrchestrationAction.Version, + InstanceId = subOrchestrationAction.InstanceId, + Input = subOrchestrationAction.Input, + }); + + ExecutionStartedEvent startedEvent = new(-1, subOrchestrationAction.Input) + { + Name = subOrchestrationAction.Name, + Version = subOrchestrationAction.Version, + OrchestrationInstance = new OrchestrationInstance + { + InstanceId = subOrchestrationAction.InstanceId, + ExecutionId = Guid.NewGuid().ToString("N"), + }, + ParentInstance = new ParentInstance + { + OrchestrationInstance = runtimeState.OrchestrationInstance, + Name = runtimeState.Name, + Version = runtimeState.Version, + TaskScheduleId = subOrchestrationAction.Id, + }, + Tags = subOrchestrationAction.Tags, + }; + + newOrchestratorMessages ??= new List(); + newOrchestratorMessages.Add(new TaskMessage + { + Event = startedEvent, + OrchestrationInstance = startedEvent.OrchestrationInstance, + }); + } + else if (action is SendEventOrchestratorAction sendEventAction) + { + if (string.IsNullOrEmpty(sendEventAction.Instance?.InstanceId)) + { + throw new ArgumentException($"The provided {nameof(SendEventOrchestratorAction)} doesn't contain an instance ID!"); + } + + EventSentEvent sendEvent = new(sendEventAction.Id) + { + InstanceId = sendEventAction.Instance.InstanceId, + Name = sendEventAction.EventName, + Input = sendEventAction.EventData, + }; + + runtimeState.AddEvent(sendEvent); + + newOrchestratorMessages ??= new List(); + newOrchestratorMessages.Add(new TaskMessage + { + Event = sendEvent, + OrchestrationInstance = runtimeState.OrchestrationInstance, + }); + } + else if (action is OrchestrationCompleteOrchestratorAction completeAction) + { + if (completeAction.OrchestrationStatus == OrchestrationStatus.ContinuedAsNew) + { + // Replace the existing runtime state with a complete new runtime state. + OrchestrationRuntimeState newRuntimeState = new(); + newRuntimeState.AddEvent(new OrchestratorStartedEvent(-1)); + newRuntimeState.AddEvent(new ExecutionStartedEvent(-1, completeAction.Result) + { + OrchestrationInstance = new OrchestrationInstance + { + InstanceId = runtimeState.OrchestrationInstance.InstanceId, + ExecutionId = Guid.NewGuid().ToString("N"), + }, + Tags = runtimeState.Tags, + ParentInstance = runtimeState.ParentInstance, + Name = runtimeState.Name, + Version = completeAction.NewVersion ?? runtimeState.Version + }); + newRuntimeState.Status = runtimeState.Status; + + // The orchestration may have completed with some pending events that need to be carried + // over to the new generation, such as unprocessed external event messages. + if (completeAction.CarryoverEvents != null) + { + foreach (HistoryEvent carryoverEvent in completeAction.CarryoverEvents) + { + newRuntimeState.AddEvent(carryoverEvent); + } + } + + runtimeState = newRuntimeState; + activityMessages = Array.Empty(); + orchestratorMessages = Array.Empty(); + timerMessages = Array.Empty(); + continueAsNew = true; + updatedStatus = null; + return; + } + else + { + this.log.OrchestratorCompleted( + runtimeState.OrchestrationInstance.InstanceId, + runtimeState.Name, + completeAction.OrchestrationStatus, + Encoding.UTF8.GetByteCount(completeAction.Result ?? string.Empty)); + } + + if (completeAction.OrchestrationStatus == OrchestrationStatus.Failed) + { + failureDetails = completeAction.FailureDetails; + } + + // NOTE: Failure details aren't being stored in the orchestration history, currently. + runtimeState.AddEvent(new ExecutionCompletedEvent( + completeAction.Id, + completeAction.Result, + completeAction.OrchestrationStatus)); + + // CONSIDER: Add support for fire-and-forget sub-orchestrations where + // we don't notify the parent that the orchestration completed. + if (runtimeState.ParentInstance != null) + { + HistoryEvent subOrchestratorCompletedEvent; + if (completeAction.OrchestrationStatus == OrchestrationStatus.Completed) + { + subOrchestratorCompletedEvent = new SubOrchestrationInstanceCompletedEvent( + eventId: -1, + runtimeState.ParentInstance.TaskScheduleId, + completeAction.Result); + } + else + { + subOrchestratorCompletedEvent = new SubOrchestrationInstanceFailedEvent( + eventId: -1, + runtimeState.ParentInstance.TaskScheduleId, + completeAction.Result, + completeAction.Details, + completeAction.FailureDetails); + } + + newOrchestratorMessages ??= new List(); + newOrchestratorMessages.Add(new TaskMessage + { + Event = subOrchestratorCompletedEvent, + OrchestrationInstance = runtimeState.ParentInstance.OrchestrationInstance, + }); + } + } + else + { + this.log.IgnoringUnknownOrchestratorAction( + runtimeState.OrchestrationInstance.InstanceId, + action.OrchestratorActionType); + } + } + + runtimeState.AddEvent(new OrchestratorCompletedEvent(-1)); + + activityMessages = newActivityMessages ?? Array.Empty(); + timerMessages = newTimerMessages ?? Array.Empty(); + orchestratorMessages = newOrchestratorMessages ?? Array.Empty(); + + updatedStatus = new OrchestrationState + { + OrchestrationInstance = runtimeState.OrchestrationInstance, + ParentInstance = runtimeState.ParentInstance, + Name = runtimeState.Name, + Version = runtimeState.Version, + Status = runtimeState.Status, + Tags = runtimeState.Tags, + OrchestrationStatus = runtimeState.OrchestrationStatus, + CreatedTime = runtimeState.CreatedTime, + CompletedTime = runtimeState.CompletedTime, + LastUpdatedTime = DateTime.UtcNow, + Size = runtimeState.Size, + CompressedSize = runtimeState.CompressedSize, + Input = runtimeState.Input, + Output = runtimeState.Output, + ScheduledStartTime = runtimeState.ExecutionStartedEvent?.ScheduledStartTime, + FailureDetails = failureDetails, + }; + } + + static string GetShortHistoryEventDescription(HistoryEvent e) + { + if (Utils.TryGetTaskScheduledId(e, out int taskScheduledId)) + { + return $"{e.EventType}#{taskScheduledId}"; + } + else + { + return e.EventType.ToString(); + } + } + + public override int GetDelayInSecondsOnFetchException(Exception ex) => + this.service.GetDelayInSecondsAfterOnFetchException(ex); + + public override string GetWorkItemId(TaskOrchestrationWorkItem workItem) => workItem.InstanceId; + + public override Task ReleaseWorkItemAsync(TaskOrchestrationWorkItem workItem) => + this.service.ReleaseTaskOrchestrationWorkItemAsync(workItem); + + public override async Task RenewWorkItemAsync(TaskOrchestrationWorkItem workItem) + { + await this.service.RenewTaskOrchestrationWorkItemLockAsync(workItem); + return workItem; + } +} diff --git a/test/Grpc.IntegrationTests/GrpcSidecar/Dispatcher/WorkItemDispatcher.cs b/test/Grpc.IntegrationTests/GrpcSidecar/Dispatcher/WorkItemDispatcher.cs new file mode 100644 index 000000000..3a9039bb0 --- /dev/null +++ b/test/Grpc.IntegrationTests/GrpcSidecar/Dispatcher/WorkItemDispatcher.cs @@ -0,0 +1,259 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using System.Diagnostics; +using Microsoft.Extensions.Logging; + +namespace Microsoft.DurableTask.Sidecar.Dispatcher; + +abstract class WorkItemDispatcher where T : class +{ + static int nextDispatcherId = 0; + + readonly string name; + readonly ILogger log; + readonly ITrafficSignal trafficSignal; + + CancellationTokenSource? shutdownTcs; + Task? workItemExecuteLoop; + int currentWorkItems; + + public WorkItemDispatcher(ILogger log, ITrafficSignal trafficSignal) + { + this.log = log ?? throw new ArgumentNullException(nameof(log)); + this.trafficSignal = trafficSignal; + + this.name = $"{this.GetType().Name}-{Interlocked.Increment(ref nextDispatcherId)}"; + } + + public virtual int MaxWorkItems => 10; + + public abstract Task FetchWorkItemAsync(TimeSpan timeout, CancellationToken cancellationToken); + + protected abstract Task ExecuteWorkItemAsync(T workItem); + + public abstract Task ReleaseWorkItemAsync(T workItem); + + public abstract Task AbandonWorkItemAsync(T workItem); + + public abstract Task RenewWorkItemAsync(T workItem); + + public abstract string GetWorkItemId(T workItem); + + public abstract int GetDelayInSecondsOnFetchException(Exception ex); + + public virtual Task StartAsync(CancellationToken cancellationToken) + { + // Dispatchers can be stopped and started back up again + this.shutdownTcs?.Dispose(); + this.shutdownTcs = new CancellationTokenSource(); + + this.workItemExecuteLoop = Task.Run( + () => this.FetchAndExecuteLoop(this.shutdownTcs.Token), + CancellationToken.None); + + return Task.CompletedTask; + } + + public virtual async Task StopAsync(CancellationToken cancellationToken) + { + // Trigger the cancellation tokens being used for background processing. + this.shutdownTcs?.Cancel(); + + // Wait for the execution loop to complete to ensure we're not scheduling any new work + Task? executeLoop = this.workItemExecuteLoop; + if (executeLoop != null) + { + await executeLoop.WaitAsync(cancellationToken); + } + + // Wait for all outstanding work-item processing to complete for a fully graceful shutdown + await this.WaitForOutstandingWorkItems(cancellationToken); + } + + async Task WaitForAllClear(CancellationToken cancellationToken) + { + TimeSpan logInterval = TimeSpan.FromMinutes(1); + + // IMPORTANT: This logic assumes only a single logical "thread" is executing the receive loop, + // and that there's no possible race condition when comparing work-item counts. + DateTime nextLogTime = DateTime.MinValue; + while (this.currentWorkItems >= this.MaxWorkItems) + { + // Periodically log that we're waiting for available concurrency. + // No need to use UTC for this. Local time is a bit easier to debug. + DateTime now = DateTime.Now; + if (now >= nextLogTime) + { + this.log.FetchingThrottled( + dispatcher: this.name, + details: "The current active work-item count has reached the allowed maximum.", + this.currentWorkItems, + this.MaxWorkItems); + nextLogTime = now.Add(logInterval); + } + + // CONSIDER: Use a notification instead of polling. + await Task.Delay(TimeSpan.FromMilliseconds(500), cancellationToken); + } + + // The dispatcher can also be paused by external signals. + while (!await this.trafficSignal.WaitAsync(logInterval, cancellationToken)) + { + this.log.FetchingThrottled( + dispatcher: this.name, + details: this.trafficSignal.WaitReason, + this.currentWorkItems, + this.MaxWorkItems); + } + } + + async Task WaitForOutstandingWorkItems(CancellationToken cancellationToken) + { + DateTime nextLogTime = DateTime.MinValue; + while (this.currentWorkItems > 0) + { + // Periodically log that we're waiting for outstanding work items to complete. + // No need to use UTC for this. Local time is a bit easier to debug. + DateTime now = DateTime.Now; + if (now >= nextLogTime) + { + this.log.DispatcherStopping(this.name, this.currentWorkItems); + nextLogTime = now.AddMinutes(1); + } + + // CONSIDER: Use a notification instead of polling. + await Task.Delay(TimeSpan.FromMilliseconds(200), cancellationToken); + } + } + + // This method does not throw + async Task DelayOnException( + Exception exception, + string workItemId, + CancellationToken cancellationToken, + Func delayInSecondsPolicy) + { + try + { + int delaySeconds = delayInSecondsPolicy(exception); + if (delaySeconds > 0) + { + await Task.Delay(delaySeconds, cancellationToken); + } + } + catch (OperationCanceledException) + { + // Shutting down, do nothing + } + catch (Exception ex) + { + this.log.DispatchWorkItemFailure( + dispatcher: this.name, + action: "delay-on-exception", + workItemId, + details: ex.ToString()); + try + { + await Task.Delay(TimeSpan.FromSeconds(1), cancellationToken); + } + catch (OperationCanceledException) + { + // shutting down + } + } + } + + async Task FetchAndExecuteLoop(CancellationToken cancellationToken) + { + try + { + // The work-item receive loop feeds the execution loop + while (true) + { + T? workItem = null; + try + { + await this.WaitForAllClear(cancellationToken); + + this.log.FetchWorkItemStarting(this.name, this.currentWorkItems, this.MaxWorkItems); + Stopwatch sw = Stopwatch.StartNew(); + + workItem = await this.FetchWorkItemAsync(Timeout.InfiniteTimeSpan, cancellationToken); + + if (workItem != null) + { + this.currentWorkItems++; + this.log.FetchWorkItemCompleted( + this.name, + this.GetWorkItemId(workItem), + sw.ElapsedMilliseconds, + this.currentWorkItems, + this.MaxWorkItems); + + // Run the execution on a background thread, which must never be canceled. + _ = Task.Run(() => this.ExecuteWorkItem(workItem), CancellationToken.None); + } + } + catch (OperationCanceledException) when (cancellationToken.IsCancellationRequested) + { + // shutting down + break; + } + catch (Exception ex) + { + string unknownWorkItemId = "(unknown)"; + this.log.DispatchWorkItemFailure( + dispatcher: this.name, + action: "fetchWorkItem", + workItemId: unknownWorkItemId, + details: ex.ToString()); + await this.DelayOnException(ex, unknownWorkItemId, cancellationToken, this.GetDelayInSecondsOnFetchException); + continue; + } + } + } + catch (OperationCanceledException) when (cancellationToken.IsCancellationRequested) + { + // graceful shutdown + } + } + + async Task ExecuteWorkItem(T workItem) + { + try + { + // Execute the work item and wait for it to complete + await this.ExecuteWorkItemAsync(workItem); + } + catch (Exception ex) + { + this.log.DispatchWorkItemFailure( + dispatcher: this.name, + action: "execute", + workItemId: this.GetWorkItemId(workItem), + details: ex.ToString()); + + await this.AbandonWorkItemAsync(workItem); + } + finally + { + try + { + await this.ReleaseWorkItemAsync(workItem); + } + catch (Exception ex) + { + // Best effort + this.log.DispatchWorkItemFailure( + dispatcher: this.name, + action: "release", + workItemId: this.GetWorkItemId(workItem), + details: ex.ToString()); + } + + this.currentWorkItems--; + } + } +} + diff --git a/test/Grpc.IntegrationTests/GrpcSidecar/Grpc/ProtobufUtils.cs b/test/Grpc.IntegrationTests/GrpcSidecar/Grpc/ProtobufUtils.cs new file mode 100644 index 000000000..6c9ab0b95 --- /dev/null +++ b/test/Grpc.IntegrationTests/GrpcSidecar/Grpc/ProtobufUtils.cs @@ -0,0 +1,423 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using System.Buffers; +using DurableTask.Core; +using DurableTask.Core.Command; +using DurableTask.Core.History; +using DurableTask.Core.Query; +using Google.Protobuf; +using Google.Protobuf.WellKnownTypes; +using Proto = Microsoft.DurableTask.Protobuf; + +namespace Microsoft.DurableTask.Sidecar.Grpc; + +public static class ProtobufUtils +{ + public static Proto.HistoryEvent ToHistoryEventProto(HistoryEvent e) + { + var payload = new Proto.HistoryEvent() + { + EventId = e.EventId, + Timestamp = Timestamp.FromDateTime(e.Timestamp), + }; + + switch (e.EventType) + { + case EventType.ContinueAsNew: + var continueAsNew = (ContinueAsNewEvent)e; + payload.ContinueAsNew = new Proto.ContinueAsNewEvent + { + Input = continueAsNew.Result, + }; + break; + case EventType.EventRaised: + var eventRaised = (EventRaisedEvent)e; + payload.EventRaised = new Proto.EventRaisedEvent + { + Name = eventRaised.Name, + Input = eventRaised.Input, + }; + break; + case EventType.EventSent: + var eventSent = (EventSentEvent)e; + payload.EventSent = new Proto.EventSentEvent + { + Name = eventSent.Name, + Input = eventSent.Input, + InstanceId = eventSent.InstanceId, + }; + break; + case EventType.ExecutionCompleted: + var completedEvent = (ExecutionCompletedEvent)e; + payload.ExecutionCompleted = new Proto.ExecutionCompletedEvent + { + OrchestrationStatus = Proto.OrchestrationStatus.Completed, + Result = completedEvent.Result, + }; + break; + case EventType.ExecutionFailed: + var failedEvent = (ExecutionCompletedEvent)e; + payload.ExecutionCompleted = new Proto.ExecutionCompletedEvent + { + OrchestrationStatus = Proto.OrchestrationStatus.Failed, + Result = failedEvent.Result, + }; + break; + case EventType.ExecutionStarted: + // Start of a new orchestration instance + var startedEvent = (ExecutionStartedEvent)e; + startedEvent.Tags ??= new Dictionary(); + payload.ExecutionStarted = new Proto.ExecutionStartedEvent + { + Name = startedEvent.Name, + Version = startedEvent.Version, + Input = startedEvent.Input, + Tags = { startedEvent.Tags }, + OrchestrationInstance = new Proto.OrchestrationInstance + { + InstanceId = startedEvent.OrchestrationInstance.InstanceId, + ExecutionId = startedEvent.OrchestrationInstance.ExecutionId, + }, + ParentInstance = startedEvent.ParentInstance == null ? null : new Proto.ParentInstanceInfo + { + Name = startedEvent.ParentInstance.Name, + Version = startedEvent.ParentInstance.Version, + TaskScheduledId = startedEvent.ParentInstance.TaskScheduleId, + OrchestrationInstance = new Proto.OrchestrationInstance + { + InstanceId = startedEvent.ParentInstance.OrchestrationInstance.InstanceId, + ExecutionId = startedEvent.ParentInstance.OrchestrationInstance.ExecutionId, + }, + }, + ScheduledStartTimestamp = startedEvent.ScheduledStartTime == null ? null : Timestamp.FromDateTime(startedEvent.ScheduledStartTime.Value), + ParentTraceContext = startedEvent.ParentTraceContext is null ? null : new Proto.TraceContext + { + TraceParent = startedEvent.ParentTraceContext.TraceParent, + TraceState = startedEvent.ParentTraceContext.TraceState, + }, + }; + break; + case EventType.ExecutionTerminated: + var terminatedEvent = (ExecutionTerminatedEvent)e; + payload.ExecutionTerminated = new Proto.ExecutionTerminatedEvent + { + Input = terminatedEvent.Input, + }; + break; + case EventType.TaskScheduled: + var taskScheduledEvent = (TaskScheduledEvent)e; + payload.TaskScheduled = new Proto.TaskScheduledEvent + { + Name = taskScheduledEvent.Name, + Version = taskScheduledEvent.Version, + Input = taskScheduledEvent.Input, + ParentTraceContext = taskScheduledEvent.ParentTraceContext is null ? null : new Proto.TraceContext + { + TraceParent = taskScheduledEvent.ParentTraceContext.TraceParent, + TraceState = taskScheduledEvent.ParentTraceContext.TraceState, + }, + }; + break; + case EventType.TaskCompleted: + var taskCompletedEvent = (TaskCompletedEvent)e; + payload.TaskCompleted = new Proto.TaskCompletedEvent + { + Result = taskCompletedEvent.Result, + TaskScheduledId = taskCompletedEvent.TaskScheduledId, + }; + break; + case EventType.TaskFailed: + var taskFailedEvent = (TaskFailedEvent)e; + payload.TaskFailed = new Proto.TaskFailedEvent + { + FailureDetails = GetFailureDetails(taskFailedEvent.FailureDetails), + TaskScheduledId = taskFailedEvent.TaskScheduledId, + }; + break; + case EventType.SubOrchestrationInstanceCreated: + var subOrchestrationCreated = (SubOrchestrationInstanceCreatedEvent)e; + payload.SubOrchestrationInstanceCreated = new Proto.SubOrchestrationInstanceCreatedEvent + { + Input = subOrchestrationCreated.Input, + InstanceId = subOrchestrationCreated.InstanceId, + Name = subOrchestrationCreated.Name, + Version = subOrchestrationCreated.Version, + }; + break; + case EventType.SubOrchestrationInstanceCompleted: + var subOrchestrationCompleted = (SubOrchestrationInstanceCompletedEvent)e; + payload.SubOrchestrationInstanceCompleted = new Proto.SubOrchestrationInstanceCompletedEvent + { + Result = subOrchestrationCompleted.Result, + TaskScheduledId = subOrchestrationCompleted.TaskScheduledId, + }; + break; + case EventType.SubOrchestrationInstanceFailed: + var subOrchestrationFailed = (SubOrchestrationInstanceFailedEvent)e; + payload.SubOrchestrationInstanceFailed = new Proto.SubOrchestrationInstanceFailedEvent + { + FailureDetails = GetFailureDetails(subOrchestrationFailed.FailureDetails), + TaskScheduledId = subOrchestrationFailed.TaskScheduledId, + }; + break; + case EventType.TimerCreated: + var timerCreatedEvent = (TimerCreatedEvent)e; + payload.TimerCreated = new Proto.TimerCreatedEvent + { + FireAt = Timestamp.FromDateTime(timerCreatedEvent.FireAt), + }; + break; + case EventType.TimerFired: + var timerFiredEvent = (TimerFiredEvent)e; + payload.TimerFired = new Proto.TimerFiredEvent + { + FireAt = Timestamp.FromDateTime(timerFiredEvent.FireAt), + TimerId = timerFiredEvent.TimerId, + }; + break; + case EventType.OrchestratorStarted: + // This event has no data + payload.OrchestratorStarted = new Proto.OrchestratorStartedEvent(); + break; + case EventType.OrchestratorCompleted: + // This event has no data + payload.OrchestratorCompleted = new Proto.OrchestratorCompletedEvent(); + break; + case EventType.GenericEvent: + var genericEvent = (GenericEvent)e; + payload.GenericEvent = new Proto.GenericEvent + { + Data = genericEvent.Data, + }; + break; + case EventType.HistoryState: + var historyStateEvent = (HistoryStateEvent)e; + payload.HistoryState = new Proto.HistoryStateEvent + { + OrchestrationState = new Proto.OrchestrationState + { + InstanceId = historyStateEvent.State.OrchestrationInstance.InstanceId, + Name = historyStateEvent.State.Name, + Version = historyStateEvent.State.Version, + Input = historyStateEvent.State.Input, + Output = historyStateEvent.State.Output, + ScheduledStartTimestamp = historyStateEvent.State.ScheduledStartTime == null ? null : Timestamp.FromDateTime(historyStateEvent.State.ScheduledStartTime.Value), + CreatedTimestamp = Timestamp.FromDateTime(historyStateEvent.State.CreatedTime), + LastUpdatedTimestamp = Timestamp.FromDateTime(historyStateEvent.State.LastUpdatedTime), + OrchestrationStatus = (Proto.OrchestrationStatus)historyStateEvent.State.OrchestrationStatus, + CustomStatus = historyStateEvent.State.Status, + Tags = { historyStateEvent.State.Tags }, + }, + }; + break; + case EventType.ExecutionSuspended: + var suspendedEvent = (ExecutionSuspendedEvent)e; + payload.ExecutionSuspended = new Proto.ExecutionSuspendedEvent + { + Input = suspendedEvent.Reason, + }; + break; + case EventType.ExecutionResumed: + var resumedEvent = (ExecutionResumedEvent)e; + payload.ExecutionResumed = new Proto.ExecutionResumedEvent + { + Input = resumedEvent.Reason, + }; + break; + default: + throw new NotSupportedException($"Found unsupported history event '{e.EventType}'."); + } + + return payload; + } + + public static OrchestratorAction ToOrchestratorAction(Proto.OrchestratorAction a) + { + switch (a.OrchestratorActionTypeCase) + { + case Proto.OrchestratorAction.OrchestratorActionTypeOneofCase.ScheduleTask: + return new ScheduleTaskOrchestratorAction + { + Id = a.Id, + Input = a.ScheduleTask.Input, + Name = a.ScheduleTask.Name, + Version = a.ScheduleTask.Version, + }; + case Proto.OrchestratorAction.OrchestratorActionTypeOneofCase.CreateSubOrchestration: + return new CreateSubOrchestrationAction + { + Id = a.Id, + Input = a.CreateSubOrchestration.Input, + Name = a.CreateSubOrchestration.Name, + InstanceId = a.CreateSubOrchestration.InstanceId, + Tags = null, // TODO + Version = a.CreateSubOrchestration.Version, + }; + case Proto.OrchestratorAction.OrchestratorActionTypeOneofCase.CreateTimer: + return new CreateTimerOrchestratorAction + { + Id = a.Id, + FireAt = a.CreateTimer.FireAt.ToDateTime(), + }; + case Proto.OrchestratorAction.OrchestratorActionTypeOneofCase.SendEvent: + return new SendEventOrchestratorAction + { + Id = a.Id, + Instance = new OrchestrationInstance + { + InstanceId = a.SendEvent.Instance.InstanceId, + ExecutionId = a.SendEvent.Instance.ExecutionId, + }, + EventName = a.SendEvent.Name, + EventData = a.SendEvent.Data, + }; + case Proto.OrchestratorAction.OrchestratorActionTypeOneofCase.CompleteOrchestration: + var completedAction = a.CompleteOrchestration; + var action = new OrchestrationCompleteOrchestratorAction + { + Id = a.Id, + OrchestrationStatus = (OrchestrationStatus)completedAction.OrchestrationStatus, + Result = completedAction.Result, + Details = completedAction.Details, + FailureDetails = GetFailureDetails(completedAction.FailureDetails), + NewVersion = completedAction.NewVersion, + }; + + if (completedAction.CarryoverEvents?.Count > 0) + { + foreach (var e in completedAction.CarryoverEvents) + { + // Only raised events are supported for carryover + if (e.EventRaised is Proto.EventRaisedEvent eventRaised) + { + action.CarryoverEvents.Add(new EventRaisedEvent(e.EventId, eventRaised.Input) + { + Name = eventRaised.Name, + }); + } + + } + } + + return action; + default: + throw new NotSupportedException($"Received unsupported action type '{a.OrchestratorActionTypeCase}'."); + } + } + + public static string Base64Encode(IMessage message) + { + // Create a serialized payload using lower-level protobuf APIs. We do this to avoid allocating + // byte[] arrays for every request, which would otherwise put a heavy burden on the GC. Unfortunately + // the protobuf API version we're using doesn't currently have memory-efficient serialization APIs. + int messageSize = message.CalculateSize(); + byte[] rentedBuffer = ArrayPool.Shared.Rent(messageSize); + try + { + using MemoryStream intermediateBufferStream = new(rentedBuffer, 0, messageSize); + CodedOutputStream protobufOutputStream = new(intermediateBufferStream); + protobufOutputStream.WriteRawMessage(message); + protobufOutputStream.Flush(); + return Convert.ToBase64String(rentedBuffer, 0, messageSize); + } + finally + { + ArrayPool.Shared.Return(rentedBuffer); + } + } + + internal static FailureDetails? GetFailureDetails(Proto.TaskFailureDetails? failureDetails) + { + if (failureDetails == null) + { + return null; + } + + return new FailureDetails( + failureDetails.ErrorType, + failureDetails.ErrorMessage, + failureDetails.StackTrace, + GetFailureDetails(failureDetails.InnerFailure), + failureDetails.IsNonRetriable); + } + + internal static Proto.TaskFailureDetails? GetFailureDetails(FailureDetails? failureDetails) + { + if (failureDetails == null) + { + return null; + } + + return new Proto.TaskFailureDetails + { + ErrorType = failureDetails.ErrorType, + ErrorMessage = failureDetails.ErrorMessage, + StackTrace = failureDetails.StackTrace, + InnerFailure = GetFailureDetails(failureDetails.InnerFailure), + IsNonRetriable = failureDetails.IsNonRetriable, + }; + } + + internal static OrchestrationQuery ToOrchestrationQuery(Proto.QueryInstancesRequest request) + { + var query = new OrchestrationQuery() + { + RuntimeStatus = request.Query.RuntimeStatus?.Select(status => (OrchestrationStatus)status).ToList(), + CreatedTimeFrom = request.Query.CreatedTimeFrom?.ToDateTime(), + CreatedTimeTo = request.Query.CreatedTimeTo?.ToDateTime(), + TaskHubNames = request.Query.TaskHubNames, + PageSize = request.Query.MaxInstanceCount, + ContinuationToken = request.Query.ContinuationToken, + InstanceIdPrefix = request.Query.InstanceIdPrefix, + FetchInputsAndOutputs = request.Query.FetchInputsAndOutputs, + }; + + return query; + } + + internal static Proto.QueryInstancesResponse CreateQueryInstancesResponse(OrchestrationQueryResult result, Proto.QueryInstancesRequest request) + { + Proto.QueryInstancesResponse response = new Proto.QueryInstancesResponse + { + ContinuationToken = result.ContinuationToken + }; + foreach (OrchestrationState state in result.OrchestrationState) + { + var orchestrationState = new Proto.OrchestrationState + { + InstanceId = state.OrchestrationInstance.InstanceId, + Name = state.Name, + Version = state.Version, + Input = state.Input, + Output = state.Output, + ScheduledStartTimestamp = state.ScheduledStartTime == null ? null : Timestamp.FromDateTime(state.ScheduledStartTime.Value), + CreatedTimestamp = Timestamp.FromDateTime(state.CreatedTime), + LastUpdatedTimestamp = Timestamp.FromDateTime(state.LastUpdatedTime), + OrchestrationStatus = (Proto.OrchestrationStatus)state.OrchestrationStatus, + CustomStatus = state.Status, + }; + response.OrchestrationState.Add(orchestrationState); + } + return response; + } + + internal static PurgeInstanceFilter ToPurgeInstanceFilter(Proto.PurgeInstancesRequest request) + { + var purgeInstanceFilter = new PurgeInstanceFilter( + request.PurgeInstanceFilter.CreatedTimeFrom.ToDateTime(), + request.PurgeInstanceFilter.CreatedTimeTo?.ToDateTime(), + request.PurgeInstanceFilter.RuntimeStatus?.Select(status => (OrchestrationStatus)status).ToList() + ); + return purgeInstanceFilter; + } + + internal static Proto.PurgeInstancesResponse CreatePurgeInstancesResponse(PurgeResult result) + { + Proto.PurgeInstancesResponse response = new Proto.PurgeInstancesResponse + { + DeletedInstanceCount = result.DeletedInstanceCount + }; + return response; + } +} diff --git a/test/Grpc.IntegrationTests/GrpcSidecar/Grpc/TaskHubGrpcServer.cs b/test/Grpc.IntegrationTests/GrpcSidecar/Grpc/TaskHubGrpcServer.cs new file mode 100644 index 000000000..167a437fc --- /dev/null +++ b/test/Grpc.IntegrationTests/GrpcSidecar/Grpc/TaskHubGrpcServer.cs @@ -0,0 +1,626 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using System.Collections.Concurrent; +using System.Diagnostics; +using DurableTask.Core; +using DurableTask.Core.History; +using DurableTask.Core.Query; +using Microsoft.DurableTask.Sidecar.Dispatcher; +using Google.Protobuf.WellKnownTypes; +using Grpc.Core; +using Microsoft.Extensions.Hosting; +using Microsoft.Extensions.Logging; +using Microsoft.Extensions.Options; +using P = Microsoft.DurableTask.Protobuf; + +namespace Microsoft.DurableTask.Sidecar.Grpc; + +public class TaskHubGrpcServer : P.TaskHubSidecarService.TaskHubSidecarServiceBase, ITaskExecutor +{ + static readonly Task EmptyCompleteTaskResponse = Task.FromResult(new P.CompleteTaskResponse()); + + readonly ConcurrentDictionary> pendingOrchestratorTasks = new(StringComparer.OrdinalIgnoreCase); + readonly ConcurrentDictionary> pendingActivityTasks = new(StringComparer.OrdinalIgnoreCase); + + readonly ILogger log; + readonly IOrchestrationService service; + readonly IOrchestrationServiceClient client; + readonly IHostApplicationLifetime hostLifetime; + readonly IOptions options; + readonly TaskHubDispatcherHost dispatcherHost; + readonly IsConnectedSignal isConnectedSignal = new(); + readonly SemaphoreSlim sendWorkItemLock = new(initialCount: 1); + + // Initialized when a client connects to this service to receive work-item commands. + IServerStreamWriter? workerToClientStream; + + public TaskHubGrpcServer( + IHostApplicationLifetime hostApplicationLifetime, + ILoggerFactory loggerFactory, + IOrchestrationService service, + IOrchestrationServiceClient client, + IOptions options) + { + ArgumentNullException.ThrowIfNull(hostApplicationLifetime, nameof(hostApplicationLifetime)); + ArgumentNullException.ThrowIfNull(loggerFactory, nameof(loggerFactory)); + ArgumentNullException.ThrowIfNull(service, nameof(service)); + ArgumentNullException.ThrowIfNull(client, nameof(client)); + ArgumentNullException.ThrowIfNull(options, nameof(options)); + + this.service = service; + this.client = client; + this.log = loggerFactory.CreateLogger("Microsoft.DurableTask.Sidecar"); + this.dispatcherHost = new TaskHubDispatcherHost( + loggerFactory, + trafficSignal: this.isConnectedSignal, + orchestrationService: service, + taskExecutor: this); + + this.hostLifetime = hostApplicationLifetime; + this.options = options; + this.hostLifetime.ApplicationStarted.Register(this.OnApplicationStarted); + this.hostLifetime.ApplicationStopping.Register(this.OnApplicationStopping); + } + + async void OnApplicationStarted() + { + if (this.options.Value.Mode == TaskHubGrpcServerMode.ApiServerAndDispatcher) + { + // Wait for a client connection to be established before starting the dispatcher host. + // This ensures we don't do any wasteful polling of resources if no clients are available to process events. + await this.WaitForWorkItemClientConnection(); + await this.dispatcherHost.StartAsync(this.hostLifetime.ApplicationStopping); + } + } + + async void OnApplicationStopping() + { + if (this.options.Value.Mode == TaskHubGrpcServerMode.ApiServerAndDispatcher) + { + // Give a maximum of 60 minutes for outstanding tasks to complete. + // REVIEW: Is this enough? What if there's an activity job that takes 4 hours to complete? Should this be configurable? + using CancellationTokenSource timeoutCts = new(TimeSpan.FromMinutes(60)); + await this.dispatcherHost.StopAsync(timeoutCts.Token); + } + } + + /// + /// Blocks until a remote client calls the operation to start fetching work items. + /// + /// Returns a task that completes once a work-item client is connected. + async Task WaitForWorkItemClientConnection() + { + Stopwatch waitTimeStopwatch = Stopwatch.StartNew(); + TimeSpan logInterval = TimeSpan.FromMinutes(1); + + try + { + while (!await this.isConnectedSignal.WaitAsync(logInterval, this.hostLifetime.ApplicationStopping)) + { + this.log.WaitingForClientConnection(waitTimeStopwatch.Elapsed); + } + } + catch (OperationCanceledException) + { + // shutting down + } + } + + public override Task Hello(Empty request, ServerCallContext context) => Task.FromResult(new Empty()); + + public override Task CreateTaskHub(P.CreateTaskHubRequest request, ServerCallContext context) + { + this.service.CreateAsync(request.RecreateIfExists); + return Task.FromResult(new P.CreateTaskHubResponse()); + } + + public override Task DeleteTaskHub(P.DeleteTaskHubRequest request, ServerCallContext context) + { + this.service.DeleteAsync(); + return Task.FromResult(new P.DeleteTaskHubResponse()); + } + + public override async Task StartInstance(P.CreateInstanceRequest request, ServerCallContext context) + { + var instance = new OrchestrationInstance + { + InstanceId = request.InstanceId ?? Guid.NewGuid().ToString("N"), + ExecutionId = Guid.NewGuid().ToString(), + }; + + // TODO: Structured logging + this.log.LogInformation("Creating a new instance with ID = {instanceID}", instance.InstanceId); + + try + { + await this.client.CreateTaskOrchestrationAsync( + new TaskMessage + { + Event = new ExecutionStartedEvent(-1, request.Input) + { + Name = request.Name, + Version = request.Version, + OrchestrationInstance = instance, + Tags = request.Tags.ToDictionary(kvp => kvp.Key, kvp => kvp.Value), + }, + OrchestrationInstance = instance, + }); + } + catch (Exception e) + { + // TODO: Structured logging + this.log.LogError(e, "An error occurred when trying to create a new instance"); + throw; + } + + return new P.CreateInstanceResponse + { + InstanceId = instance.InstanceId, + }; + } + + public override async Task RaiseEvent(P.RaiseEventRequest request, ServerCallContext context) + { + try + { + await this.client.SendTaskOrchestrationMessageAsync( + new TaskMessage + { + Event = new EventRaisedEvent(-1, request.Input) + { + Name = request.Name, + }, + OrchestrationInstance = new OrchestrationInstance + { + InstanceId = request.InstanceId, + }, + }); + } + catch (Exception e) + { + // TODO: Structured logging + this.log.LogError(e, "An error occurred when trying to raise an event."); + throw; + } + + // No fields in the response + return new P.RaiseEventResponse(); + } + + public override async Task TerminateInstance(P.TerminateRequest request, ServerCallContext context) + { + try + { + await this.client.ForceTerminateTaskOrchestrationAsync( + request.InstanceId, + request.Output); + } + catch (Exception e) + { + // TODO: Structured logging + this.log.LogError(e, "An error occurred when trying to terminate an instance."); + throw; + } + + // No fields in the response + return new P.TerminateResponse(); + } + + public override async Task GetInstance(P.GetInstanceRequest request, ServerCallContext context) + { + OrchestrationState state = await this.client.GetOrchestrationStateAsync(request.InstanceId, executionId: null); + if (state == null) + { + return new P.GetInstanceResponse() { Exists = false }; + } + + return CreateGetInstanceResponse(state, request); + } + + public override async Task QueryInstances(P.QueryInstancesRequest request, ServerCallContext context) + { + if (this.client is IOrchestrationServiceQueryClient queryClient) + { + OrchestrationQuery query = ProtobufUtils.ToOrchestrationQuery(request); + OrchestrationQueryResult result = await queryClient.GetOrchestrationWithQueryAsync(query, context.CancellationToken); + return ProtobufUtils.CreateQueryInstancesResponse(result, request); + } + else + { + throw new NotSupportedException($"{this.client.GetType().Name} doesn't support query operations."); + } + } + + public override async Task PurgeInstances(P.PurgeInstancesRequest request, ServerCallContext context) + { + if (this.client is IOrchestrationServicePurgeClient purgeClient) + { + PurgeResult result; + switch (request.RequestCase) + { + case P.PurgeInstancesRequest.RequestOneofCase.InstanceId: + result = await purgeClient.PurgeInstanceStateAsync(request.InstanceId); + break; + + case P.PurgeInstancesRequest.RequestOneofCase.PurgeInstanceFilter: + PurgeInstanceFilter purgeInstanceFilter = ProtobufUtils.ToPurgeInstanceFilter(request); + result = await purgeClient.PurgeInstanceStateAsync(purgeInstanceFilter); + break; + + default: + throw new ArgumentException($"Unknown purge request type '{request.RequestCase}'."); + } + return ProtobufUtils.CreatePurgeInstancesResponse(result); + } + else + { + throw new NotSupportedException($"{this.client.GetType().Name} doesn't support purge operations."); + } + } + + public override async Task WaitForInstanceStart(P.GetInstanceRequest request, ServerCallContext context) + { + while (true) + { + // Keep fetching the status until we get one of the states we care about + OrchestrationState state = await this.client.GetOrchestrationStateAsync(request.InstanceId, executionId: null); + if (state != null && state.OrchestrationStatus != OrchestrationStatus.Pending) + { + return CreateGetInstanceResponse(state, request); + } + + // TODO: Backoff strategy if we're delaying for a long time. + // The cancellation token is what will break us out of this loop if the orchestration + // never leaves the "Pending" state. + await Task.Delay(TimeSpan.FromMilliseconds(500), context.CancellationToken); + } + } + + public override async Task WaitForInstanceCompletion(P.GetInstanceRequest request, ServerCallContext context) + { + OrchestrationState state = await this.client.WaitForOrchestrationAsync( + request.InstanceId, + executionId: null, + timeout: Timeout.InfiniteTimeSpan, + context.CancellationToken); + + return CreateGetInstanceResponse(state, request); + } + + static P.GetInstanceResponse CreateGetInstanceResponse(OrchestrationState state, P.GetInstanceRequest request) + { + return new P.GetInstanceResponse + { + Exists = true, + OrchestrationState = new P.OrchestrationState + { + InstanceId = state.OrchestrationInstance.InstanceId, + Name = state.Name, + OrchestrationStatus = (P.OrchestrationStatus)state.OrchestrationStatus, + CreatedTimestamp = Timestamp.FromDateTime(state.CreatedTime), + LastUpdatedTimestamp = Timestamp.FromDateTime(state.LastUpdatedTime), + Input = request.GetInputsAndOutputs ? state.Input : null, + Output = request.GetInputsAndOutputs ? state.Output : null, + CustomStatus = request.GetInputsAndOutputs ? state.Status : null, + FailureDetails = request.GetInputsAndOutputs ? GetFailureDetails(state.FailureDetails) : null, + Tags = { state.Tags } + } + }; + } + + public override async Task SuspendInstance(P.SuspendRequest request, ServerCallContext context) + { + TaskMessage taskMessage = new() + { + OrchestrationInstance = new OrchestrationInstance { InstanceId = request.InstanceId }, + Event = new ExecutionSuspendedEvent(-1, request.Reason), + }; + + await this.client.SendTaskOrchestrationMessageAsync(taskMessage); + return new P.SuspendResponse(); + } + + public override async Task ResumeInstance(P.ResumeRequest request, ServerCallContext context) + { + TaskMessage taskMessage = new() + { + OrchestrationInstance = new OrchestrationInstance { InstanceId = request.InstanceId }, + Event = new ExecutionResumedEvent(-1, request.Reason), + }; + + await this.client.SendTaskOrchestrationMessageAsync(taskMessage); + return new P.ResumeResponse(); + } + + static P.TaskFailureDetails? GetFailureDetails(FailureDetails? failureDetails) + { + if (failureDetails == null) + { + return null; + } + + return new P.TaskFailureDetails + { + ErrorType = failureDetails.ErrorType, + ErrorMessage = failureDetails.ErrorMessage, + StackTrace = failureDetails.StackTrace, + InnerFailure = GetFailureDetails(failureDetails.InnerFailure), + }; + } + + /// + /// Invoked by the remote SDK over gRPC when an orchestrator task (episode) is completed. + /// + /// Details about the orchestration execution, including the list of orchestrator actions. + /// Context for the server-side gRPC call. + /// Returns an empty ack back to the remote SDK that we've received the completion. + public override Task CompleteOrchestratorTask(P.OrchestratorResponse request, ServerCallContext context) + { + if (!this.pendingOrchestratorTasks.TryRemove( + request.InstanceId, + out TaskCompletionSource? tcs)) + { + // TODO: Log? + throw new RpcException(new Status(StatusCode.NotFound, $"Orchestration not found")); + } + + OrchestratorExecutionResult result = new() + { + Actions = request.Actions.Select(ProtobufUtils.ToOrchestratorAction), + CustomStatus = request.CustomStatus, + }; + + tcs.TrySetResult(result); + + return EmptyCompleteTaskResponse; + } + + /// + /// Invoked by the remote SDK over gRPC when an activity task (episode) is completed. + /// + /// Details about the completed activity task, including the output. + /// Context for the server-side gRPC call. + /// Returns an empty ack back to the remote SDK that we've received the completion. + public override Task CompleteActivityTask(P.ActivityResponse response, ServerCallContext context) + { + string taskIdKey = GetTaskIdKey(response.InstanceId, response.TaskId); + if (!this.pendingActivityTasks.TryRemove(taskIdKey, out TaskCompletionSource? tcs)) + { + // TODO: Log? + throw new RpcException(new Status(StatusCode.NotFound, $"Activity not found")); + } + + HistoryEvent resultEvent; + if (response.FailureDetails == null) + { + resultEvent = new TaskCompletedEvent(-1, response.TaskId, response.Result); + } + else + { + resultEvent = new TaskFailedEvent( + eventId: -1, + taskScheduledId: response.TaskId, + reason: null, + details: null, + failureDetails: ProtobufUtils.GetFailureDetails(response.FailureDetails)); + } + + tcs.TrySetResult(new ActivityExecutionResult { ResponseEvent = resultEvent }); + return EmptyCompleteTaskResponse; + } + + public override async Task GetWorkItems(P.GetWorkItemsRequest request, IServerStreamWriter responseStream, ServerCallContext context) + { + // Use a lock to mitigate the race condition where we signal the dispatch host to start but haven't + // yet saved a reference to the client response stream. + lock (this.isConnectedSignal) + { + int retryCount = 0; + while (!this.isConnectedSignal.Set()) + { + // Retries are needed when a client (like a test suite) connects and disconnects rapidly, causing a race + // condition where we don't reset the signal quickly enough to avoid ResourceExausted errors. + if (retryCount <= 5) + { + Thread.Sleep(10); // Can't use await inside the body of a lock statement so we have to block the thread + } + else + { + throw new RpcException(new Status(StatusCode.ResourceExhausted, "Another client is already connected")); + } + } + + this.log.ClientConnected(context.Peer, context.Deadline); + this.workerToClientStream = responseStream; + } + + try + { + await Task.Delay(Timeout.InfiniteTimeSpan, context.CancellationToken); + } + catch (OperationCanceledException) + { + this.log.ClientDisconnected(context.Peer); + } + finally + { + // Resetting this signal causes the dispatchers to stop fetching new work. + this.isConnectedSignal.Reset(); + + // Transition back to the "waiting for connection" state. + // This background task is just to log "waiting for connection" messages. + _ = Task.Run(this.WaitForWorkItemClientConnection); + } + } + + /// + /// Invoked by the when a work item is available, proxies the call to execute an orchestrator over a gRPC channel. + /// + /// + async Task ITaskExecutor.ExecuteOrchestrator( + OrchestrationInstance instance, + IEnumerable pastEvents, + IEnumerable newEvents) + { + // Create a task completion source that represents the async completion of the orchestrator execution. + // This must be done before we start the orchestrator execution. + TaskCompletionSource tcs = + this.CreateTaskCompletionSourceForOrchestrator(instance.InstanceId); + + try + { + await this.SendWorkItemToClientAsync(new P.WorkItem + { + OrchestratorRequest = new P.OrchestratorRequest + { + InstanceId = instance.InstanceId, + ExecutionId = instance.ExecutionId, + NewEvents = { newEvents.Select(ProtobufUtils.ToHistoryEventProto) }, + PastEvents = { pastEvents.Select(ProtobufUtils.ToHistoryEventProto) }, + } + }); + } + catch + { + // Remove the TaskCompletionSource that we just created + this.RemoveOrchestratorTaskCompletionSource(instance.InstanceId); + throw; + } + + // The TCS will be completed on the message stream handler when it gets a response back from the remote process + // TODO: How should we handle timeouts if the remote process never sends a response? + // Probably need to have a static timeout (e.g. 5 minutes). + return await tcs.Task; + } + + async Task ITaskExecutor.ExecuteActivity(OrchestrationInstance instance, TaskScheduledEvent activityEvent) + { + // Create a task completion source that represents the async completion of the activity. + // This must be done before we start the activity execution. + TaskCompletionSource tcs = this.CreateTaskCompletionSourceForActivity( + instance.InstanceId, + activityEvent.EventId); + + try + { + await this.SendWorkItemToClientAsync(new P.WorkItem + { + ActivityRequest = new P.ActivityRequest + { + Name = activityEvent.Name, + Version = activityEvent.Version, + Input = activityEvent.Input, + TaskId = activityEvent.EventId, + OrchestrationInstance = new P.OrchestrationInstance + { + InstanceId = instance.InstanceId, + ExecutionId = instance.ExecutionId, + }, + } + }); + } + catch + { + // Remove the TaskCompletionSource that we just created + this.RemoveActivityTaskCompletionSource(instance.InstanceId, activityEvent.EventId); + throw; + } + + // The TCS will be completed on the message stream handler when it gets a response back from the remote process. + // TODO: How should we handle timeouts if the remote process never sends a response? + // Probably need a timeout feature for activities and/or a heartbeat API that activities + // can use to signal that they're still running. + return await tcs.Task; + } + + async Task SendWorkItemToClientAsync(P.WorkItem workItem) + { + IServerStreamWriter outputStream; + + // Use a lock to mitigate the race condition where we signal the dispatch host to start but haven't + // yet saved a reference to the client response stream. + lock (this.isConnectedSignal) + { + outputStream = this.workerToClientStream ?? + throw new Exception("TODO: No client is connected! Need to wait until a client connects before executing!"); + } + + // The gRPC channel can only handle one message at a time, so we need to serialize access to it. + await this.sendWorkItemLock.WaitAsync(); + try + { + await outputStream.WriteAsync(workItem); + } + finally + { + this.sendWorkItemLock.Release(); + } + } + + TaskCompletionSource CreateTaskCompletionSourceForOrchestrator(string instanceId) + { + TaskCompletionSource tcs = new(); + this.pendingOrchestratorTasks.TryAdd(instanceId, tcs); + return tcs; + } + + void RemoveOrchestratorTaskCompletionSource(string instanceId) + { + this.pendingOrchestratorTasks.TryRemove(instanceId, out _); + } + + TaskCompletionSource CreateTaskCompletionSourceForActivity(string instanceId, int taskId) + { + string taskIdKey = GetTaskIdKey(instanceId, taskId); + TaskCompletionSource tcs = new(); + this.pendingActivityTasks.TryAdd(taskIdKey, tcs); + return tcs; + } + + void RemoveActivityTaskCompletionSource(string instanceId, int taskId) + { + string taskIdKey = GetTaskIdKey(instanceId, taskId); + this.pendingActivityTasks.TryRemove(taskIdKey, out _); + } + + static string GetTaskIdKey(string instanceId, int taskId) + { + return string.Concat(instanceId, "__", taskId.ToString()); + } + + /// + /// A implementation that is used to control whether the task hub + /// dispatcher can fetch new work-items, based on whether a client is currently connected. + /// + class IsConnectedSignal : ITrafficSignal + { + readonly AsyncManualResetEvent isConnectedEvent = new(isSignaled: false); + + /// + public string WaitReason => "Waiting for a client to connect"; + + /// + /// Blocks the caller until the method is called, which means a client is connected. + /// + /// + public Task WaitAsync(TimeSpan waitTime, CancellationToken cancellationToken) + { + return this.isConnectedEvent.WaitAsync(waitTime, cancellationToken); + } + + /// + /// Signals the dispatchers to start fetching new work-items. + /// + /// + /// Returns true if the current thread transitioned the event to the "signaled" state; + /// otherwise false, meaning some other thread already called on this signal. + /// + public bool Set() => this.isConnectedEvent.Set(); + + /// + /// Causes the dispatchers to stop fetching new work-items. + /// + public void Reset() => this.isConnectedEvent.Reset(); + } +} diff --git a/test/Grpc.IntegrationTests/GrpcSidecar/Grpc/TaskHubGrpcServerOptions.cs b/test/Grpc.IntegrationTests/GrpcSidecar/Grpc/TaskHubGrpcServerOptions.cs new file mode 100644 index 000000000..ffeeacd5e --- /dev/null +++ b/test/Grpc.IntegrationTests/GrpcSidecar/Grpc/TaskHubGrpcServerOptions.cs @@ -0,0 +1,31 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +namespace Microsoft.DurableTask.Sidecar.Grpc; + +/// +/// Options for configuring the task hub gRPC server. +/// +public class TaskHubGrpcServerOptions +{ + /// + /// The high-level mode of operation for the gRPC server. + /// + public TaskHubGrpcServerMode Mode { get; set; } +} + +/// +/// A set of options that determine what capabilities are enabled for the gRPC server. +/// +public enum TaskHubGrpcServerMode +{ + /// + /// The gRPC server handles both orchestration dispatching and management API requests. + /// + ApiServerAndDispatcher, + + /// + /// The gRPC server handles management API requests but not orchestration dispatching. + /// + ApiServerOnly, +} \ No newline at end of file diff --git a/test/Grpc.IntegrationTests/GrpcSidecar/InMemoryOrchestrationService.cs b/test/Grpc.IntegrationTests/GrpcSidecar/InMemoryOrchestrationService.cs new file mode 100644 index 000000000..481fdc3ff --- /dev/null +++ b/test/Grpc.IntegrationTests/GrpcSidecar/InMemoryOrchestrationService.cs @@ -0,0 +1,692 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using System.Collections.Concurrent; +using System.Diagnostics.CodeAnalysis; +using System.Text.Json.Nodes; +using System.Threading.Channels; +using DurableTask.Core; +using DurableTask.Core.Exceptions; +using DurableTask.Core.History; +using DurableTask.Core.Query; +using Microsoft.Extensions.Logging; +using Microsoft.Extensions.Logging.Abstractions; + +namespace Microsoft.DurableTask.Sidecar; + +public class InMemoryOrchestrationService : IOrchestrationService, IOrchestrationServiceClient, IOrchestrationServiceQueryClient, IOrchestrationServicePurgeClient +{ + readonly InMemoryQueue activityQueue = new(); + readonly InMemoryInstanceStore instanceStore; + readonly ILogger logger; + + public int TaskOrchestrationDispatcherCount => 1; + + public int TaskActivityDispatcherCount => 1; + + public int MaxConcurrentTaskOrchestrationWorkItems => Environment.ProcessorCount; + + public int MaxConcurrentTaskActivityWorkItems => Environment.ProcessorCount; + + public BehaviorOnContinueAsNew EventBehaviourForContinueAsNew => BehaviorOnContinueAsNew.Carryover; + + public InMemoryOrchestrationService(ILoggerFactory? loggerFactory = null) + { + this.logger = (loggerFactory ?? NullLoggerFactory.Instance).CreateLogger("Microsoft.DurableTask.Sidecar.InMemoryStorageProvider"); + this.instanceStore = new InMemoryInstanceStore(this.logger); + } + + public Task AbandonTaskActivityWorkItemAsync(TaskActivityWorkItem workItem) + { + this.logger.LogWarning("Abandoning task activity work item {id}", workItem.Id); + this.activityQueue.Enqueue(workItem.TaskMessage); + return Task.CompletedTask; + } + + public Task AbandonTaskOrchestrationWorkItemAsync(TaskOrchestrationWorkItem workItem) + { + this.instanceStore.AbandonInstance(workItem.NewMessages); + return Task.CompletedTask; + } + + public Task CompleteTaskActivityWorkItemAsync(TaskActivityWorkItem workItem, TaskMessage responseMessage) + { + this.instanceStore.AddMessage(responseMessage); + return Task.CompletedTask; + } + + public Task CompleteTaskOrchestrationWorkItemAsync( + TaskOrchestrationWorkItem workItem, + OrchestrationRuntimeState newOrchestrationRuntimeState, + IList outboundMessages, + IList orchestratorMessages, + IList timerMessages, + TaskMessage continuedAsNewMessage, + OrchestrationState orchestrationState) + { + this.instanceStore.SaveState( + runtimeState: newOrchestrationRuntimeState, + statusRecord: orchestrationState, + newMessages: orchestratorMessages.Union(timerMessages).Append(continuedAsNewMessage).Where(msg => msg != null)); + + this.activityQueue.Enqueue(outboundMessages); + return Task.CompletedTask; + } + + public Task CreateAsync() => Task.CompletedTask; + + public Task CreateAsync(bool recreateInstanceStore) + { + if (recreateInstanceStore) + { + this.instanceStore.Reset(); + } + return Task.CompletedTask; + } + + public Task CreateIfNotExistsAsync() => Task.CompletedTask; + + public Task CreateTaskOrchestrationAsync(TaskMessage creationMessage) + { + return this.CreateTaskOrchestrationAsync( + creationMessage, + new[] { OrchestrationStatus.Pending, OrchestrationStatus.Running }); + } + + public Task CreateTaskOrchestrationAsync(TaskMessage creationMessage, OrchestrationStatus[]? dedupeStatuses) + { + // Lock the instance store to prevent multiple "create" threads from racing with each other. + lock (this.instanceStore) + { + string instanceId = creationMessage.OrchestrationInstance.InstanceId; + if (this.instanceStore.TryGetState(instanceId, out OrchestrationState? statusRecord) && + dedupeStatuses != null && + dedupeStatuses.Contains(statusRecord.OrchestrationStatus)) + { + throw new OrchestrationAlreadyExistsException($"An orchestration with id '{instanceId}' already exists. It's in the {statusRecord.OrchestrationStatus} state."); + } + + this.instanceStore.AddMessage(creationMessage); + } + + return Task.CompletedTask; + } + + public Task DeleteAsync() => this.DeleteAsync(true); + + public Task DeleteAsync(bool deleteInstanceStore) + { + if (deleteInstanceStore) + { + this.instanceStore.Reset(); + } + return Task.CompletedTask; + } + + public Task ForceTerminateTaskOrchestrationAsync(string instanceId, string reason) + { + var taskMessage = new TaskMessage + { + OrchestrationInstance = new OrchestrationInstance { InstanceId = instanceId }, + Event = new ExecutionTerminatedEvent(-1, reason), + }; + + return this.SendTaskOrchestrationMessageAsync(taskMessage); + } + + public int GetDelayInSecondsAfterOnFetchException(Exception exception) + { + return exception is OperationCanceledException ? 0 : 1; + } + + public int GetDelayInSecondsAfterOnProcessException(Exception exception) + { + return exception is OperationCanceledException ? 0 : 1; + } + + public Task GetOrchestrationHistoryAsync(string instanceId, string executionId) + { + // Also not supported in the emulator + throw new NotImplementedException(); + } + + public async Task> GetOrchestrationStateAsync(string instanceId, bool allExecutions) + { + OrchestrationState state = await this.GetOrchestrationStateAsync(instanceId, executionId: null); + return new[] { state }; + } + + public Task GetOrchestrationStateAsync(string instanceId, string? executionId) + { + if (this.instanceStore.TryGetState(instanceId, out OrchestrationState? statusRecord)) + { + if (executionId == null || executionId == statusRecord.OrchestrationInstance.ExecutionId) + { + return Task.FromResult(statusRecord); + } + } + + return Task.FromResult(null!); + } + + public bool IsMaxMessageCountExceeded(int currentMessageCount, OrchestrationRuntimeState runtimeState) => false; + + public async Task LockNextTaskActivityWorkItem(TimeSpan receiveTimeout, CancellationToken cancellationToken) + { + TaskMessage message = await this.activityQueue.DequeueAsync(cancellationToken); + return new TaskActivityWorkItem + { + Id = message.SequenceNumber.ToString(), + LockedUntilUtc = DateTime.MaxValue, + TaskMessage = message, + }; + } + + public async Task LockNextTaskOrchestrationWorkItemAsync(TimeSpan receiveTimeout, CancellationToken cancellationToken) + { + var (instanceId, runtimeState, messages) = await this.instanceStore.GetNextReadyToRunInstanceAsync(cancellationToken); + + return new TaskOrchestrationWorkItem + { + InstanceId = instanceId, + OrchestrationRuntimeState = runtimeState, + NewMessages = messages, + LockedUntilUtc = DateTime.MaxValue, + }; + } + + public Task PurgeOrchestrationHistoryAsync(DateTime thresholdDateTimeUtc, OrchestrationStateTimeRangeFilterType timeRangeFilterType) + { + // Also not supported in the emulator + throw new NotImplementedException(); + } + + public Task ReleaseTaskOrchestrationWorkItemAsync(TaskOrchestrationWorkItem workItem) + { + this.instanceStore.ReleaseLock(workItem.InstanceId); + return Task.CompletedTask; + } + + public Task RenewTaskActivityWorkItemLockAsync(TaskActivityWorkItem workItem) + { + return Task.FromResult(workItem); // PeekLock isn't supported + } + + public Task RenewTaskOrchestrationWorkItemLockAsync(TaskOrchestrationWorkItem workItem) + { + return Task.CompletedTask; // PeekLock isn't supported + } + + public Task SendTaskOrchestrationMessageAsync(TaskMessage message) + { + this.instanceStore.AddMessage(message); + return Task.CompletedTask; + } + + public Task SendTaskOrchestrationMessageBatchAsync(params TaskMessage[] messages) + { + // NOTE: This is not transactionally consistent - some messages may get processed earlier than others. + foreach (TaskMessage message in messages) + { + this.instanceStore.AddMessage(message); + } + + return Task.CompletedTask; + } + + public Task StartAsync() => Task.CompletedTask; + + public Task StopAsync() => Task.CompletedTask; + + public Task StopAsync(bool isForced) => Task.CompletedTask; + + public async Task WaitForOrchestrationAsync(string instanceId, string executionId, TimeSpan timeout, CancellationToken cancellationToken) + { + if (timeout <= TimeSpan.Zero) + { + return await this.instanceStore.WaitForInstanceAsync(instanceId, cancellationToken); + } + else + { + using CancellationTokenSource timeoutCancellationSource = new(timeout); + using CancellationTokenSource linkedCancellationSource = CancellationTokenSource.CreateLinkedTokenSource( + cancellationToken, + timeoutCancellationSource.Token); + return await this.instanceStore.WaitForInstanceAsync(instanceId, linkedCancellationSource.Token); + } + } + + static bool TryGetScheduledTime(TaskMessage message, out TimeSpan delay) + { + DateTime scheduledTime = default; + if (message.Event is ExecutionStartedEvent startEvent) + { + scheduledTime = startEvent.ScheduledStartTime ?? default; + } + else if (message.Event is TimerFiredEvent timerEvent) + { + scheduledTime = timerEvent.FireAt; + } + + DateTime now = DateTime.UtcNow; + if (scheduledTime > now) + { + delay = scheduledTime - now; + return true; + } + else + { + delay = default; + return false; + } + } + + public Task GetOrchestrationWithQueryAsync(OrchestrationQuery query, CancellationToken cancellationToken) + { + return Task.FromResult(this.instanceStore.GetOrchestrationWithQuery(query)); + } + + public Task PurgeInstanceStateAsync(string instanceId) + { + return Task.FromResult(this.instanceStore.PurgeInstanceState(instanceId)); + } + + public Task PurgeInstanceStateAsync(PurgeInstanceFilter purgeInstanceFilter) + { + return Task.FromResult(this.instanceStore.PurgeInstanceState(purgeInstanceFilter)); + } + + class InMemoryQueue + { + readonly Channel innerQueue = Channel.CreateUnbounded(); + + public void Enqueue(TaskMessage taskMessage) + { + if (TryGetScheduledTime(taskMessage, out TimeSpan delay)) + { + _ = Task.Delay(delay).ContinueWith(t => this.innerQueue.Writer.TryWrite(taskMessage)); + } + else + { + this.innerQueue.Writer.TryWrite(taskMessage); + } + } + + public void Enqueue(IEnumerable messages) + { + foreach (TaskMessage msg in messages) + { + this.Enqueue(msg); + } + } + + public async Task DequeueAsync(CancellationToken cancellationToken) + { + return await this.innerQueue.Reader.ReadAsync(cancellationToken); + } + } + + class InMemoryInstanceStore + { + readonly ConcurrentDictionary store = new(StringComparer.OrdinalIgnoreCase); + readonly ConcurrentDictionary> waiters = new(StringComparer.OrdinalIgnoreCase); + readonly ReadyToRunQueue readyToRunQueue = new(); + + readonly ILogger logger; + + public InMemoryInstanceStore(ILogger logger) => this.logger = logger; + + public void Reset() + { + this.store.Clear(); + this.waiters.Clear(); + this.readyToRunQueue.Reset(); + } + + public async Task<(string, OrchestrationRuntimeState, List)> GetNextReadyToRunInstanceAsync(CancellationToken cancellationToken) + { + SerializedInstanceState state = await this.readyToRunQueue.TakeNextAsync(cancellationToken); + lock (state) + { + List history = state.HistoryEventsJson.Select(e => e!.GetValue()).ToList(); + OrchestrationRuntimeState runtimeState = new(history); + + List messages = state.MessagesJson.Select(node => node!.GetValue()).ToList(); + if (messages == null) + { + throw new InvalidOperationException("Should never load state with zero messages."); + } + + state.IsLoaded = true; + + // There is no "peek-lock" semantic. All dequeued messages are immediately deleted. + state.MessagesJson.Clear(); + + return (state.InstanceId, runtimeState, messages); + } + } + + public bool TryGetState(string instanceId, [NotNullWhen(true)] out OrchestrationState? statusRecord) + { + if (!this.store.TryGetValue(instanceId, out SerializedInstanceState? state)) + { + statusRecord = null; + return false; + } + + statusRecord = state.StatusRecordJson?.GetValue(); + return statusRecord != null; + } + + public void SaveState( + OrchestrationRuntimeState runtimeState, + OrchestrationState statusRecord, + IEnumerable newMessages) + { + static bool IsCompleted(OrchestrationRuntimeState runtimeState) => + runtimeState.OrchestrationStatus == OrchestrationStatus.Completed || + runtimeState.OrchestrationStatus == OrchestrationStatus.Failed || + runtimeState.OrchestrationStatus == OrchestrationStatus.Terminated || + runtimeState.OrchestrationStatus == OrchestrationStatus.Canceled; + + if (string.IsNullOrEmpty(runtimeState.OrchestrationInstance?.InstanceId)) + { + throw new ArgumentException("The provided runtime state doesn't contain instance ID information.", nameof(runtimeState)); + } + + string instanceId = runtimeState.OrchestrationInstance.InstanceId; + string executionId = runtimeState.OrchestrationInstance.ExecutionId; + SerializedInstanceState state = this.store.GetOrAdd( + instanceId, + _ => new SerializedInstanceState(instanceId, executionId)); + lock (state) + { + if (state.ExecutionId == null) + { + // This orchestration was started by a message without an execution ID. + state.ExecutionId = executionId; + } + else if (state.ExecutionId != executionId) + { + // This is a new generation (ContinueAsNew). Erase the old history. + state.HistoryEventsJson.Clear(); + state.ExecutionId = executionId; + } + + foreach (TaskMessage msg in newMessages) + { + this.AddMessage(msg); + } + + // Append to the orchestration history + foreach (HistoryEvent e in runtimeState.NewEvents) + { + state.HistoryEventsJson.Add(e); + } + + state.StatusRecordJson = JsonValue.Create(statusRecord); + state.IsCompleted = IsCompleted(runtimeState); + } + + // Notify any waiters of the orchestration completion + if (IsCompleted(runtimeState) && + this.waiters.TryRemove(statusRecord.OrchestrationInstance.InstanceId, out TaskCompletionSource? waiter)) + { + waiter.TrySetResult(statusRecord); + } + } + + public void AddMessage(TaskMessage message) + { + string instanceId = message.OrchestrationInstance.InstanceId; + string? executionId = message.OrchestrationInstance.ExecutionId; + + SerializedInstanceState state = this.store.GetOrAdd(instanceId, id => new SerializedInstanceState(id, executionId)); + lock (state) + { + if (message.Event is ExecutionStartedEvent startEvent) + { + OrchestrationState newStatusRecord = new() + { + OrchestrationInstance = startEvent.OrchestrationInstance, + CreatedTime = DateTime.UtcNow, + LastUpdatedTime = DateTime.UtcNow, + OrchestrationStatus = OrchestrationStatus.Pending, + Version = startEvent.Version, + Name = startEvent.Name, + Input = startEvent.Input, + ScheduledStartTime = startEvent.ScheduledStartTime, + Tags = startEvent.Tags, + }; + + state.StatusRecordJson = JsonValue.Create(newStatusRecord); + state.HistoryEventsJson.Clear(); + state.IsCompleted = false; + } + else if (state.IsCompleted) + { + // Drop the message since we're completed + // GOOD: The user-provided the instanceId + // logger.LogWarning( + // "Dropped {eventType} message for instance '{instanceId}' because the orchestration has already completed.", + // message.Event.EventType, + // instanceId); + return; + } + + if (TryGetScheduledTime(message, out TimeSpan delay)) + { + // Not ready for this message yet - delay the enqueue + _ = Task.Delay(delay).ContinueWith(t => this.AddMessage(message)); + return; + } + + state.MessagesJson.Add(message); + + if (!state.IsLoaded) + { + // The orchestration isn't running, so schedule it to run now. + // If it is running, it will be scheduled again automatically when it's released. + this.readyToRunQueue.Schedule(state); + } + } + } + + public void AbandonInstance(IEnumerable messagesToReturn) + { + foreach (var message in messagesToReturn) + { + this.AddMessage(message); + } + } + + public void ReleaseLock(string instanceId) + { + if (!this.store.TryGetValue(instanceId, out SerializedInstanceState? state) || !state.IsLoaded) + { + throw new InvalidOperationException($"Instance {instanceId} is not in the store or is not loaded!"); + } + + lock (state) + { + state.IsLoaded = false; + if (state.MessagesJson.Count > 0) + { + // More messages came in while we were running. Or, messages were abandoned. + // Put this back into the read-to-run queue! + this.readyToRunQueue.Schedule(state); + } + } + } + + public Task WaitForInstanceAsync(string instanceId, CancellationToken cancellationToken) + { + if (this.store.TryGetValue(instanceId, out SerializedInstanceState? state)) + { + lock (state) + { + OrchestrationState? statusRecord = state.StatusRecordJson?.GetValue(); + if (statusRecord != null) + { + if (statusRecord.OrchestrationStatus == OrchestrationStatus.Completed || + statusRecord.OrchestrationStatus == OrchestrationStatus.Failed || + statusRecord.OrchestrationStatus == OrchestrationStatus.Terminated) + { + // orchestration has already completed + return Task.FromResult(statusRecord); + } + } + + } + } + + // Caller will be notified when the instance completes. + // The ContinueWith is just to enable cancellation: https://stackoverflow.com/a/25652873/2069 + var tcs = this.waiters.GetOrAdd(instanceId, _ => new TaskCompletionSource()); + return tcs.Task.ContinueWith(t => t.GetAwaiter().GetResult(), cancellationToken); + } + + public OrchestrationQueryResult GetOrchestrationWithQuery(OrchestrationQuery query) + { + int startIndex = 0; + int counter = 0; + string? continuationToken = query.ContinuationToken; + if (continuationToken != null) + { + if (!Int32.TryParse(continuationToken, out startIndex)) + { + throw new InvalidOperationException($"{continuationToken} cannot be parsed to Int32"); + } + } + + counter = startIndex; + + List results = this.store + .Skip(startIndex) + .Where(item => + { + counter++; + OrchestrationState? statusRecord = item.Value.StatusRecordJson?.GetValue(); + if (statusRecord == null) return false; + if (query.CreatedTimeFrom != null && query.CreatedTimeFrom > statusRecord.CreatedTime) return false; + if (query.CreatedTimeTo != null && query.CreatedTimeTo < statusRecord.CreatedTime) return false; + if (query.RuntimeStatus != null && query.RuntimeStatus.Any() && !query.RuntimeStatus.Contains(statusRecord.OrchestrationStatus)) return false; + if (query.InstanceIdPrefix != null && !statusRecord.OrchestrationInstance.InstanceId.StartsWith(query.InstanceIdPrefix)) return false; + return true; + }) + .Take(query.PageSize) + .Select(item => item.Value.StatusRecordJson!.GetValue()) + .ToList(); + + string? token = null; + if (results.Count == query.PageSize) + { + token = counter.ToString(); + } + return new OrchestrationQueryResult(results, token); + } + + public PurgeResult PurgeInstanceState(string instanceId) + { + if (instanceId != null && this.store.TryGetValue(instanceId, out SerializedInstanceState? state) && state.IsCompleted) + { + this.store.TryRemove(instanceId, out SerializedInstanceState? removedState); + if (removedState != null) + { + return new PurgeResult(1); + } + } + return new PurgeResult(0); + } + + public PurgeResult PurgeInstanceState(PurgeInstanceFilter purgeInstanceFilter) + { + int counter = 0; + + List filteredInstanceIds = this.store + .Where(item => + { + OrchestrationState? statusRecord = item.Value.StatusRecordJson?.GetValue(); + if (statusRecord == null) return false; + if (purgeInstanceFilter.CreatedTimeFrom > statusRecord.CreatedTime) return false; + if (purgeInstanceFilter.CreatedTimeTo != null && purgeInstanceFilter.CreatedTimeTo < statusRecord.CreatedTime) return false; + if (purgeInstanceFilter.RuntimeStatus != null && purgeInstanceFilter.RuntimeStatus.Any() && !purgeInstanceFilter.RuntimeStatus.Contains(statusRecord.OrchestrationStatus)) return false; + return true; + }) + .Select(item => item.Key) + .ToList(); + + foreach (string instanceId in filteredInstanceIds) + { + this.store.TryRemove(instanceId, out SerializedInstanceState? removedState); + if (removedState != null) + { + counter++; + } + } + + return new PurgeResult(counter); + } + + class ReadyToRunQueue + { + readonly Channel readyToRunQueue = Channel.CreateUnbounded(); + readonly Dictionary readyInstances = new(StringComparer.OrdinalIgnoreCase); + + public void Reset() + { + this.readyInstances.Clear(); + } + + public async ValueTask TakeNextAsync(CancellationToken ct) + { + while (true) + { + SerializedInstanceState state = await this.readyToRunQueue.Reader.ReadAsync(ct); + lock (state) + { + if (this.readyInstances.Remove(state.InstanceId)) + { + if (state.IsLoaded) + { + throw new InvalidOperationException("Should never load state that is already loaded."); + } + + state.IsLoaded = true; + return state; + } + } + } + } + + public void Schedule(SerializedInstanceState state) + { + // TODO: There is a race condition here. If another thread is calling TakeNextAsync + // and removed the queue item before updating the dictionary, then we'll fail + // to update the readyToRunQueue and the orchestration will get stuck. + if (this.readyInstances.TryAdd(state.InstanceId, state)) + { + this.readyToRunQueue.Writer.TryWrite(state); + } + } + } + + class SerializedInstanceState + { + public SerializedInstanceState(string instanceId, string? executionId) + { + this.InstanceId = instanceId; + this.ExecutionId = executionId; + } + + public string InstanceId { get; } + public string? ExecutionId { get; internal set; } + public JsonValue? StatusRecordJson { get; set; } + public JsonArray HistoryEventsJson { get; } = new JsonArray(); + public JsonArray MessagesJson { get; } = new JsonArray(); + + internal bool IsLoaded { get; set; } + internal bool IsCompleted { get; set; } + } + } +} diff --git a/test/Grpc.IntegrationTests/GrpcSidecar/Logs.cs b/test/Grpc.IntegrationTests/GrpcSidecar/Logs.cs new file mode 100644 index 000000000..3a38bc7c6 --- /dev/null +++ b/test/Grpc.IntegrationTests/GrpcSidecar/Logs.cs @@ -0,0 +1,121 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using DurableTask.Core; +using DurableTask.Core.Command; +using Microsoft.Extensions.Logging; + +namespace Microsoft.DurableTask.Sidecar +{ + static partial class Logs + { + [LoggerMessage( + EventId = 5, + Level = LogLevel.Information, + Message = "Waiting for a remote client to connect to this server. Total wait time: {totalWaitTime:c}")] + public static partial void WaitingForClientConnection( + this ILogger logger, + TimeSpan totalWaitTime); + + [LoggerMessage( + EventId = 6, + Level = LogLevel.Information, + Message = "Received work-item connection from {address}. Client connection deadline = {deadline:s}.")] + public static partial void ClientConnected( + this ILogger logger, + string address, + DateTime deadline); + + [LoggerMessage( + EventId = 7, + Level = LogLevel.Information, + Message = "Client at {address} has disconnected. No further work-items will be processed until a new connection is established.")] + public static partial void ClientDisconnected( + this ILogger logger, + string address); + + [LoggerMessage( + EventId = 22, + Level = LogLevel.Information, + Message = "{dispatcher}: Shutting down, waiting for {currentWorkItemCount} active work-items to complete.")] + public static partial void DispatcherStopping( + this ILogger logger, + string dispatcher, + int currentWorkItemCount); + + [LoggerMessage( + EventId = 23, + Level = LogLevel.Trace, + Message = "{dispatcher}: Fetching next work item. Current active work-items: {currentWorkItemCount}/{maxWorkItemCount}.")] + public static partial void FetchWorkItemStarting( + this ILogger logger, + string dispatcher, + int currentWorkItemCount, + int maxWorkItemCount); + + [LoggerMessage( + EventId = 24, + Level = LogLevel.Trace, + Message = "{dispatcher}: Fetched next work item '{workItemId}' after {latencyMs}ms. Current active work-items: {currentWorkItemCount}/{maxWorkItemCount}.")] + public static partial void FetchWorkItemCompleted( + this ILogger logger, + string dispatcher, + string workItemId, + long latencyMs, + int currentWorkItemCount, + int maxWorkItemCount); + + [LoggerMessage( + EventId = 25, + Level = LogLevel.Error, + Message = "{dispatcher}: Unexpected {action} failure for work-item '{workItemId}': {details}")] + public static partial void DispatchWorkItemFailure( + this ILogger logger, + string dispatcher, + string action, + string workItemId, + string details); + + [LoggerMessage( + EventId = 26, + Level = LogLevel.Information, + Message = "{dispatcher}: Work-item fetching is paused: {details}. Current active work-item count: {currentWorkItemCount}/{maxWorkItemCount}.")] + public static partial void FetchingThrottled( + this ILogger logger, + string dispatcher, + string details, + int currentWorkItemCount, + int maxWorkItemCount); + + [LoggerMessage( + EventId = 49, + Level = LogLevel.Information, + Message = "{instanceId}: Orchestrator '{name}' completed with a {runtimeStatus} status and {sizeInBytes} bytes of output.")] + public static partial void OrchestratorCompleted( + this ILogger logger, + string instanceId, + string name, + OrchestrationStatus runtimeStatus, + int sizeInBytes); + + [LoggerMessage( + EventId = 51, + Level = LogLevel.Debug, + Message = "{instanceId}: Preparing to execute orchestrator '{name}' with {eventCount} new events: {newEvents}")] + public static partial void OrchestratorExecuting( + this ILogger logger, + string instanceId, + string name, + int eventCount, + string newEvents); + + [LoggerMessage( + EventId = 55, + Level = LogLevel.Warning, + Message = "{instanceId}: Ignoring unknown orchestrator action '{action}'.")] + public static partial void IgnoringUnknownOrchestratorAction( + this ILogger logger, + string instanceId, + OrchestratorActionType action); + } +} \ No newline at end of file diff --git a/test/Grpc.IntegrationTests/GrpcSidecar/Utils.cs b/test/Grpc.IntegrationTests/GrpcSidecar/Utils.cs new file mode 100644 index 000000000..4be9aa6a3 --- /dev/null +++ b/test/Grpc.IntegrationTests/GrpcSidecar/Utils.cs @@ -0,0 +1,47 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using DurableTask.Core.History; + +namespace Microsoft.DurableTask.Sidecar; + +internal static class Utils +{ + public static bool TryGetTaskScheduledId(HistoryEvent historyEvent, out int taskScheduledId) + { + switch (historyEvent.EventType) + { + case EventType.TaskCompleted: + taskScheduledId = ((TaskCompletedEvent)historyEvent).TaskScheduledId; + return true; + case EventType.TaskFailed: + taskScheduledId = ((TaskFailedEvent)historyEvent).TaskScheduledId; + return true; + case EventType.SubOrchestrationInstanceCompleted: + taskScheduledId = ((SubOrchestrationInstanceCompletedEvent)historyEvent).TaskScheduledId; + return true; + case EventType.SubOrchestrationInstanceFailed: + taskScheduledId = ((SubOrchestrationInstanceFailedEvent)historyEvent).TaskScheduledId; + return true; + case EventType.TimerFired: + taskScheduledId = ((TimerFiredEvent)historyEvent).TimerId; + return true; + case EventType.ExecutionStarted: + var parentInstance = ((ExecutionStartedEvent)historyEvent).ParentInstance; + if (parentInstance != null) + { + // taskId that scheduled a sub-orchestration + taskScheduledId = parentInstance.TaskScheduleId; + return true; + } + else + { + taskScheduledId = -1; + return false; + } + default: + taskScheduledId = -1; + return false; + } + } +} diff --git a/test/Grpc.IntegrationTests/OrchestrationPatterns.cs b/test/Grpc.IntegrationTests/OrchestrationPatterns.cs index 3cbc0d793..5c33f5ba4 100644 --- a/test/Grpc.IntegrationTests/OrchestrationPatterns.cs +++ b/test/Grpc.IntegrationTests/OrchestrationPatterns.cs @@ -1,597 +1,630 @@ -// Copyright (c) Microsoft Corporation. -// Licensed under the MIT License. - -using System.Text.Json; -using System.Text.Json.Nodes; -using Microsoft.DurableTask.Client; -using Microsoft.DurableTask.Tests.Logging; -using Microsoft.DurableTask.Worker; -using Microsoft.Extensions.DependencyInjection; -using Xunit.Abstractions; - -namespace Microsoft.DurableTask.Grpc.Tests; - -public class OrchestrationPatterns : IntegrationTestBase -{ - public OrchestrationPatterns(ITestOutputHelper output, GrpcSidecarFixture sidecarFixture) - : base(output, sidecarFixture) - { } - - [Fact] - public async Task EmptyOrchestration() - { - TaskName orchestratorName = nameof(EmptyOrchestration); - await using HostTestLifetime server = await this.StartWorkerAsync(b => - { - b.AddTasks(tasks => tasks.AddOrchestratorFunc(orchestratorName, ctx => Task.FromResult(null))); - }); - - string instanceId = await server.Client.ScheduleNewOrchestrationInstanceAsync(orchestratorName); - OrchestrationMetadata metadata = await server.Client.WaitForInstanceCompletionAsync( - instanceId, this.TimeoutToken); - - Assert.NotNull(metadata); - Assert.Equal(instanceId, metadata.InstanceId); - Assert.Equal(OrchestrationRuntimeStatus.Completed, metadata.RuntimeStatus); - } - - [Fact] - public async Task SingleTimer() - { - TaskName orchestratorName = nameof(SingleTimer); - TimeSpan delay = TimeSpan.FromSeconds(3); - - await using HostTestLifetime server = await this.StartWorkerAsync(b => - { - b.AddTasks(tasks => tasks.AddOrchestratorFunc( - orchestratorName, ctx => ctx.CreateTimer(delay, CancellationToken.None))); - }); - - string instanceId = await server.Client.ScheduleNewOrchestrationInstanceAsync(orchestratorName); - OrchestrationMetadata metadata = await server.Client.WaitForInstanceCompletionAsync( - instanceId, this.TimeoutToken); - - Assert.NotNull(metadata); - Assert.Equal(instanceId, metadata.InstanceId); - Assert.Equal(OrchestrationRuntimeStatus.Completed, metadata.RuntimeStatus); - - // Verify that the delay actually happened with a 1 second variation - Assert.True(metadata.CreatedAt.Add(delay) <= metadata.LastUpdatedAt.AddSeconds(1)); - } - - [Fact] - public async Task LongTimer() - { - TaskName orchestratorName = nameof(SingleTimer); - TimeSpan delay = TimeSpan.FromSeconds(7); - TimeSpan timerInterval = TimeSpan.FromSeconds(3); - const int ExpectedTimers = 3; // two for 3 seconds and one for 1 second - - await using HostTestLifetime server = await this.StartWorkerAsync(b => - { - b.Configure(opt => opt.MaximumTimerInterval = timerInterval); - b.AddTasks(tasks => tasks.AddOrchestratorFunc( - orchestratorName, ctx => ctx.CreateTimer(delay, CancellationToken.None))); - }); - - string instanceId = await server.Client.ScheduleNewOrchestrationInstanceAsync(orchestratorName); - OrchestrationMetadata metadata = await server.Client.WaitForInstanceCompletionAsync(instanceId, this.TimeoutToken); - - Assert.NotNull(metadata); - Assert.Equal(instanceId, metadata.InstanceId); - Assert.Equal(OrchestrationRuntimeStatus.Completed, metadata.RuntimeStatus); - - // Verify that the delay actually happened - Assert.True(metadata.CreatedAt.Add(delay) <= metadata.LastUpdatedAt.AddSeconds(1)); - - // Verify that the correct number of timers were created - IReadOnlyCollection logs = this.GetLogs(); - int timersCreated = logs.Count(log => log.Message.Contains("CreateTimer")); - Assert.Equal(ExpectedTimers, timersCreated); - } - - [Fact] - public async Task IsReplaying() - { - TaskName orchestratorName = nameof(IsReplaying); - - await using HostTestLifetime server = await this.StartWorkerAsync(b => - { - b.AddTasks(tasks => tasks.AddOrchestratorFunc(orchestratorName, async ctx => - { - var list = new List { ctx.IsReplaying }; - await ctx.CreateTimer(TimeSpan.Zero, CancellationToken.None); - list.Add(ctx.IsReplaying); - await ctx.CreateTimer(TimeSpan.Zero, CancellationToken.None); - list.Add(ctx.IsReplaying); - return list; - })); - }); - - string instanceId = await server.Client.ScheduleNewOrchestrationInstanceAsync(orchestratorName); - OrchestrationMetadata metadata = await server.Client.WaitForInstanceCompletionAsync( - instanceId, getInputsAndOutputs: true, this.TimeoutToken); - Assert.Equal(OrchestrationRuntimeStatus.Completed, metadata.RuntimeStatus); - List? results = metadata.ReadOutputAs>(); - Assert.NotNull(results); - Assert.Equal(3, results!.Count); - Assert.True(results[0]); - Assert.True(results[1]); - Assert.False(results[2]); - } - - [Fact] - public async Task CurrentDateTimeUtc() - { - TaskName orchestratorName = nameof(CurrentDateTimeUtc); - TaskName echoActivityName = "Echo"; - - await using HostTestLifetime server = await this.StartWorkerAsync(b => - { - b.AddTasks(tasks => tasks - .AddOrchestratorFunc(orchestratorName, async ctx => - { - DateTime currentDate1 = ctx.CurrentUtcDateTime; - DateTime originalDate1 = await ctx.CallActivityAsync(echoActivityName, currentDate1); - if (currentDate1 != originalDate1) - { - return false; - } - - DateTime currentDate2 = ctx.CurrentUtcDateTime; - DateTime originalDate2 = await ctx.CallActivityAsync(echoActivityName, currentDate2); - if (currentDate2 != originalDate2) - { - return false; - } - - return currentDate1 != currentDate2; - }) - .AddActivityFunc(echoActivityName, (ctx, input) => input)); - }); - - string instanceId = await server.Client.ScheduleNewOrchestrationInstanceAsync(orchestratorName); - OrchestrationMetadata metadata = await server.Client.WaitForInstanceCompletionAsync( - instanceId, getInputsAndOutputs: true, this.TimeoutToken); - Assert.NotNull(metadata); - Assert.Equal(OrchestrationRuntimeStatus.Completed, metadata.RuntimeStatus); - Assert.True(metadata.ReadOutputAs()); - } - - [Fact] - public async Task SingleActivity() - { - TaskName orchestratorName = nameof(SingleActivity); - TaskName sayHelloActivityName = "SayHello"; - - await using HostTestLifetime server = await this.StartWorkerAsync(b => - { - b.AddTasks(tasks => tasks - .AddOrchestratorFunc( - orchestratorName, (ctx, input) => ctx.CallActivityAsync(sayHelloActivityName, input)) - .AddActivityFunc(sayHelloActivityName, (ctx, name) => $"Hello, {name}!")); - }); - - string instanceId = await server.Client.ScheduleNewOrchestrationInstanceAsync(orchestratorName, input: "World"); - OrchestrationMetadata metadata = await server.Client.WaitForInstanceCompletionAsync( - instanceId, getInputsAndOutputs: true, this.TimeoutToken); - Assert.NotNull(metadata); - Assert.Equal(OrchestrationRuntimeStatus.Completed, metadata.RuntimeStatus); - Assert.Equal("Hello, World!", metadata.ReadOutputAs()); - } - - [Fact] - public async Task SingleActivity_Async() - { - TaskName orchestratorName = nameof(SingleActivity); - TaskName sayHelloActivityName = "SayHello"; - - await using HostTestLifetime server = await this.StartWorkerAsync(b => - { - b.AddTasks(tasks => tasks - .AddOrchestratorFunc( - orchestratorName, (ctx, input) => ctx.CallActivityAsync(sayHelloActivityName, input)) - .AddActivityFunc( - sayHelloActivityName, async (ctx, name) => await Task.FromResult($"Hello, {name}!"))); - }); - - string instanceId = await server.Client.ScheduleNewOrchestrationInstanceAsync(orchestratorName, input: "World"); - OrchestrationMetadata metadata = await server.Client.WaitForInstanceCompletionAsync( - instanceId, getInputsAndOutputs: true, this.TimeoutToken); - Assert.NotNull(metadata); - Assert.Equal(OrchestrationRuntimeStatus.Completed, metadata.RuntimeStatus); - Assert.Equal("Hello, World!", metadata.ReadOutputAs()); - } - - [Fact] - public async Task ActivityChain() - { - TaskName orchestratorName = nameof(ActivityChain); - TaskName plusOneActivityName = "PlusOne"; - - await using HostTestLifetime server = await this.StartWorkerAsync(b => - { - b.AddTasks(tasks => tasks - .AddOrchestratorFunc(orchestratorName, async ctx => - { - int value = 0; - for (int i = 0; i < 10; i++) - { - value = await ctx.CallActivityAsync(plusOneActivityName, value); - } - - return value; - }) - .AddActivityFunc(plusOneActivityName, (ctx, input) => input + 1)); - }); - - string instanceId = await server.Client.ScheduleNewOrchestrationInstanceAsync(orchestratorName, input: "World"); - OrchestrationMetadata metadata = await server.Client.WaitForInstanceCompletionAsync( - instanceId, getInputsAndOutputs: true, this.TimeoutToken); - Assert.NotNull(metadata); - Assert.Equal(OrchestrationRuntimeStatus.Completed, metadata.RuntimeStatus); - Assert.Equal(10, metadata.ReadOutputAs()); - } - - [Fact] - public async Task ActivityFanOut() - { - TaskName orchestratorName = nameof(ActivityFanOut); - TaskName toStringActivity = "ToString"; - - await using HostTestLifetime server = await this.StartWorkerAsync(b => - { - b.AddTasks(tasks => tasks - .AddOrchestratorFunc(orchestratorName, async ctx => - { - var tasks = new List>(); - for (int i = 0; i < 10; i++) - { - tasks.Add(ctx.CallActivityAsync(toStringActivity, i)); - } - - string[] results = await Task.WhenAll(tasks); - Array.Sort(results); - Array.Reverse(results); - return results; - }) - .AddActivityFunc(toStringActivity, (ctx, input) => input.ToString())); - }); - - string instanceId = await server.Client.ScheduleNewOrchestrationInstanceAsync(orchestratorName); - OrchestrationMetadata metadata = await server.Client.WaitForInstanceCompletionAsync( - instanceId, getInputsAndOutputs: true, this.TimeoutToken); - Assert.NotNull(metadata); - Assert.Equal(OrchestrationRuntimeStatus.Completed, metadata.RuntimeStatus); - - string[] expected = new[] { "9", "8", "7", "6", "5", "4", "3", "2", "1", "0" }; - Assert.Equal(expected, metadata.ReadOutputAs()); - } - - [Theory] - [InlineData(1)] - [InlineData(100)] - public async Task ExternalEvents(int eventCount) - { - TaskName orchestratorName = nameof(ExternalEvents); - await using HostTestLifetime server = await this.StartWorkerAsync(b => - { - b.AddTasks(tasks => tasks.AddOrchestratorFunc(orchestratorName, async ctx => - { - List events = new(); - for (int i = 0; i < eventCount; i++) - { - events.Add(await ctx.WaitForExternalEvent($"Event{i}")); - } - - return events; - })); - }); - - string instanceId = await server.Client.ScheduleNewOrchestrationInstanceAsync(orchestratorName); - - // To ensure consistency, wait for the instance to start before sending the events - OrchestrationMetadata metadata = await server.Client.WaitForInstanceStartAsync( - instanceId, - this.TimeoutToken); - - // Send events one-at-a-time to that we can better ensure ordered processing. - for (int i = 0; i < eventCount; i++) - { - await server.Client.RaiseEventAsync(metadata.InstanceId, $"Event{i}", eventPayload: i); - } - - // Once the orchestration receives all the events it is expecting, it should complete. - metadata = await server.Client.WaitForInstanceCompletionAsync( - instanceId, getInputsAndOutputs: true, this.TimeoutToken); - Assert.NotNull(metadata); - Assert.Equal(OrchestrationRuntimeStatus.Completed, metadata.RuntimeStatus); - - int[] expected = Enumerable.Range(0, eventCount).ToArray(); - Assert.Equal(expected, metadata.ReadOutputAs()); - } - - [Theory] - [InlineData(1)] - [InlineData(5)] - public async Task ExternalEventsInParallel(int eventCount) - { - TaskName orchestratorName = nameof(ExternalEvents); - await using HostTestLifetime server = await this.StartWorkerAsync(b => - { - b.AddTasks(tasks => tasks.AddOrchestratorFunc(orchestratorName, async ctx => - { - List> events = new(); - for (int i = 0; i < eventCount; i++) - { - events.Add(ctx.WaitForExternalEvent("Event")); - } - - return await Task.WhenAll(events); - })); - }); - - string instanceId = await server.Client.ScheduleNewOrchestrationInstanceAsync(orchestratorName); - - // To ensure consistency, wait for the instance to start before sending the events - OrchestrationMetadata metadata = await server.Client.WaitForInstanceStartAsync( - instanceId, - this.TimeoutToken); - - // Send events one-at-a-time to that we can better ensure ordered processing. - for (int i = 0; i < eventCount; i++) - { - await server.Client.RaiseEventAsync(metadata.InstanceId, "Event", eventPayload: i); - } - - // Once the orchestration receives all the events it is expecting, it should complete. - metadata = await server.Client.WaitForInstanceCompletionAsync( - instanceId, getInputsAndOutputs: true); - Assert.NotNull(metadata); - Assert.Equal(OrchestrationRuntimeStatus.Completed, metadata.RuntimeStatus); - - int[] expected = Enumerable.Range(0, eventCount).ToArray(); - Assert.Equal(expected, metadata.ReadOutputAs()); - } - - [Fact] - public async Task Termination() - { - TaskName orchestrationName = nameof(Termination); - await using HostTestLifetime server = await this.StartWorkerAsync(b => - { - b.AddTasks(tasks => tasks.AddOrchestratorFunc( - orchestrationName, ctx => ctx.CreateTimer(TimeSpan.FromSeconds(3), CancellationToken.None))); - }); - - string instanceId = await server.Client.ScheduleNewOrchestrationInstanceAsync(orchestrationName); - OrchestrationMetadata metadata = await server.Client.WaitForInstanceStartAsync(instanceId, this.TimeoutToken); - - var expectedOutput = new { quote = "I'll be back." }; - await server.Client.TerminateInstanceAsync(instanceId, expectedOutput); - - metadata = await server.Client.WaitForInstanceCompletionAsync( - instanceId, getInputsAndOutputs: true, this.TimeoutToken); - Assert.NotNull(metadata); - Assert.Equal(instanceId, metadata.InstanceId); - Assert.Equal(OrchestrationRuntimeStatus.Terminated, metadata.RuntimeStatus); - - JsonElement actualOutput = metadata.ReadOutputAs(); - string? actualQuote = actualOutput.GetProperty("quote").GetString(); - Assert.NotNull(actualQuote); - Assert.Equal(expectedOutput.quote, actualQuote); - } - - [Fact] - public async Task ContinueAsNew() - { - TaskName orchestratorName = nameof(ContinueAsNew); - - await using HostTestLifetime server = await this.StartWorkerAsync(b => - { - b.AddTasks(tasks => tasks.AddOrchestratorFunc(orchestratorName, async (ctx, input) => - { - if (input < 10) - { - await ctx.CreateTimer(TimeSpan.Zero, CancellationToken.None); - ctx.ContinueAsNew(input + 1); - } - - return input; - })); - }); - - string instanceId = await server.Client.ScheduleNewOrchestrationInstanceAsync(orchestratorName); - OrchestrationMetadata metadata = await server.Client.WaitForInstanceCompletionAsync( - instanceId, getInputsAndOutputs: true, this.TimeoutToken); - Assert.NotNull(metadata); - Assert.Equal(OrchestrationRuntimeStatus.Completed, metadata.RuntimeStatus); - Assert.Equal(10, metadata.ReadOutputAs()); - } - - [Fact] - public async Task SubOrchestration() - { - TaskName orchestratorName = nameof(SubOrchestration); - - await using HostTestLifetime server = await this.StartWorkerAsync(b => - { - b.AddTasks(tasks => tasks.AddOrchestratorFunc(orchestratorName, async (ctx, input) => - { - int result = 5; - if (input < 3) - { - // recursively call this same orchestrator - result += await ctx.CallSubOrchestratorAsync(orchestratorName, input: input + 1); - } - - return result; - })); - }); - - string instanceId = await server.Client.ScheduleNewOrchestrationInstanceAsync(orchestratorName, input: 1); - OrchestrationMetadata metadata = await server.Client.WaitForInstanceCompletionAsync( - instanceId, getInputsAndOutputs: true, this.TimeoutToken); - Assert.NotNull(metadata); - Assert.Equal(OrchestrationRuntimeStatus.Completed, metadata.RuntimeStatus); - Assert.Equal(15, metadata.ReadOutputAs()); - } - - [Fact] - public async Task SetCustomStatus() - { - TaskName orchestratorName = nameof(SetCustomStatus); - await using HostTestLifetime server = await this.StartWorkerAsync(b => - { - b.AddTasks(tasks => tasks.AddOrchestratorFunc(orchestratorName, async ctx => - { - ctx.SetCustomStatus("Started!"); - - object customStatus = await ctx.WaitForExternalEvent("StatusEvent"); - ctx.SetCustomStatus(customStatus); - })); - }); - - string instanceId = await server.Client.ScheduleNewOrchestrationInstanceAsync(orchestratorName); - - // To ensure consistency, wait for the instance to start before sending the events - OrchestrationMetadata metadata = await server.Client.WaitForInstanceStartAsync( - instanceId, getInputsAndOutputs: true, this.TimeoutToken); - Assert.NotNull(metadata); - Assert.Equal("Started!", metadata.ReadCustomStatusAs()); - - // Send a tuple payload, which will be used as the custom status - (string, int) eventPayload = ("Hello", 42); - await server.Client.RaiseEventAsync( - metadata.InstanceId, - eventName: "StatusEvent", - eventPayload); - - // Once the orchestration receives all the events it is expecting, it should complete. - metadata = await server.Client.WaitForInstanceCompletionAsync( - instanceId, getInputsAndOutputs: true, this.TimeoutToken); - Assert.NotNull(metadata); - Assert.Equal(OrchestrationRuntimeStatus.Completed, metadata.RuntimeStatus); - Assert.Equal(eventPayload, metadata.ReadCustomStatusAs<(string, int)>()); - } - - [Fact] - public async Task NewGuidTest() - { - TaskName orchestratorName = nameof(ContinueAsNew); - TaskName echoActivityName = "Echo"; - - await using HostTestLifetime server = await this.StartWorkerAsync(b => - { - b.AddTasks(tasks => tasks - .AddOrchestratorFunc(orchestratorName, async (ctx, input) => - { - // Test 1: Ensure two consecutively created GUIDs are unique - Guid currentGuid0 = ctx.NewGuid(); - Guid currentGuid1 = ctx.NewGuid(); - if (currentGuid0 == currentGuid1) - { - return false; - } - - // Test 2: Ensure that the same GUID values are created on each replay - Guid originalGuid1 = await ctx.CallActivityAsync(echoActivityName, currentGuid1); - if (currentGuid1 != originalGuid1) - { - return false; - } - - // Test 3: Ensure that the same GUID values are created on each replay even after an await - Guid currentGuid2 = ctx.NewGuid(); - Guid originalGuid2 = await ctx.CallActivityAsync(echoActivityName, currentGuid2); - if (currentGuid2 != originalGuid2) - { - return false; - } - - // Test 4: Finish confirming that every generated GUID is unique - return currentGuid1 != currentGuid2; - }) - .AddActivityFunc(echoActivityName, (ctx, input) => input)); - }); - - string instanceId = await server.Client.ScheduleNewOrchestrationInstanceAsync(orchestratorName); - OrchestrationMetadata metadata = await server.Client.WaitForInstanceCompletionAsync( - instanceId, getInputsAndOutputs: true, this.TimeoutToken); - Assert.NotNull(metadata); - Assert.Equal(OrchestrationRuntimeStatus.Completed, metadata.RuntimeStatus); - Assert.True(metadata.ReadOutputAs()); - } - - [Fact] - public async Task SpecialSerialization() - { - await using HostTestLifetime server = await this.StartWorkerAsync(b => - { - b.AddTasks(tasks => tasks - .AddOrchestratorFunc("SpecialSerialization_Orchestration", (ctx, input) => - { - if (input is null) - { - throw new ArgumentNullException(nameof(input)); - } - - return ctx.CallActivityAsync("SpecialSerialization_Activity", input); - }) - .AddActivityFunc("SpecialSerialization_Activity", (ctx, input) => - { - if (input is not null) - { - input["newProperty"] = "new value"; - } - - return Task.FromResult(input); - })); - }); - - JsonNode input = new JsonObject() { ["originalProperty"] = "original value" }; - string instanceId = await server.Client.ScheduleNewOrchestrationInstanceAsync( - "SpecialSerialization_Orchestration", input: input); - OrchestrationMetadata result = await server.Client.WaitForInstanceCompletionAsync( - instanceId, getInputsAndOutputs: true, this.TimeoutToken); - JsonNode? output = result.ReadOutputAs(); - - Assert.NotNull(output); - Assert.Equal("original value", output?["originalProperty"]?.ToString()); - Assert.Equal("new value", output?["newProperty"]?.ToString()); - } - - // TODO: Additional versioning tests - [Fact] - public async Task OrchestrationVersionPassedThroughContext() - { - var version = "0.1"; - await using HostTestLifetime server = await this.StartWorkerAsync(b => - { - b.AddTasks(tasks => tasks - .AddOrchestratorFunc("Versioned_Orchestration", (ctx, input) => - { - return ctx.CallActivityAsync("Versioned_Activity", ctx.Version); - }) - .AddActivityFunc("Versioned_Activity", (ctx, input) => - { - return $"Orchestration version: {input}"; - })); - }, c => - { - c.UseDefaultVersion(version); - }); - - var instanceId = await server.Client.ScheduleNewOrchestrationInstanceAsync("Versioned_Orchestration", input: string.Empty); - var result = await server.Client.WaitForInstanceCompletionAsync(instanceId, getInputsAndOutputs: true, this.TimeoutToken); - var output = result.ReadOutputAs(); - - Assert.NotNull(output); - Assert.Equal(output, $"Orchestration version: {version}"); - - } - - // TODO: Test for multiple external events with the same name - // TODO: Test for ContinueAsNew with external events that carry over - // TODO: Test for catching activity exceptions of specific types -} +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using System.Text.Json; +using System.Text.Json.Nodes; +using Microsoft.DurableTask.Client; +using Microsoft.DurableTask.Tests.Logging; +using Microsoft.DurableTask.Worker; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Options; +using Xunit.Abstractions; + +namespace Microsoft.DurableTask.Grpc.Tests; + +public class OrchestrationPatterns : IntegrationTestBase +{ + public OrchestrationPatterns(ITestOutputHelper output, GrpcSidecarFixture sidecarFixture) + : base(output, sidecarFixture) + { } + + [Fact] + public async Task EmptyOrchestration() + { + TaskName orchestratorName = nameof(EmptyOrchestration); + await using HostTestLifetime server = await this.StartWorkerAsync(b => + { + b.AddTasks(tasks => tasks.AddOrchestratorFunc(orchestratorName, ctx => Task.FromResult(null))); + }); + + string instanceId = await server.Client.ScheduleNewOrchestrationInstanceAsync(orchestratorName); + OrchestrationMetadata metadata = await server.Client.WaitForInstanceCompletionAsync( + instanceId, this.TimeoutToken); + + Assert.NotNull(metadata); + Assert.Equal(instanceId, metadata.InstanceId); + Assert.Equal(OrchestrationRuntimeStatus.Completed, metadata.RuntimeStatus); + } + + [Fact] + public async Task ScheduleOrchesrationWithTags() + { + TaskName orchestratorName = nameof(EmptyOrchestration); + await using HostTestLifetime server = await this.StartWorkerAsync(b => + { + b.AddTasks(tasks => tasks.AddOrchestratorFunc(orchestratorName, ctx => Task.FromResult(null))); + }); + + // Schedule a new orchestration instance with tags + StartOrchestrationOptions options = new() + { + Tags = new Dictionary + { + { "tag1", "value1" }, + { "tag2", "value2" } + } + }; + string instanceId = await server.Client.ScheduleNewOrchestrationInstanceAsync(orchestratorName, options); + + OrchestrationMetadata metadata = await server.Client.WaitForInstanceCompletionAsync( + instanceId, this.TimeoutToken); + + Assert.NotNull(metadata); + Assert.Equal(instanceId, metadata.InstanceId); + Assert.Equal(OrchestrationRuntimeStatus.Completed, metadata.RuntimeStatus); + Assert.NotNull(metadata.Tags); + Assert.Equal(2, metadata.Tags.Count); + Assert.Equal("value1", metadata.Tags["tag1"]); + Assert.Equal("value2", metadata.Tags["tag2"]); + } + + [Fact] + public async Task SingleTimer() + { + TaskName orchestratorName = nameof(SingleTimer); + TimeSpan delay = TimeSpan.FromSeconds(3); + + await using HostTestLifetime server = await this.StartWorkerAsync(b => + { + b.AddTasks(tasks => tasks.AddOrchestratorFunc( + orchestratorName, ctx => ctx.CreateTimer(delay, CancellationToken.None))); + }); + + string instanceId = await server.Client.ScheduleNewOrchestrationInstanceAsync(orchestratorName); + OrchestrationMetadata metadata = await server.Client.WaitForInstanceCompletionAsync( + instanceId, this.TimeoutToken); + + Assert.NotNull(metadata); + Assert.Equal(instanceId, metadata.InstanceId); + Assert.Equal(OrchestrationRuntimeStatus.Completed, metadata.RuntimeStatus); + + // Verify that the delay actually happened with a 1 second variation + Assert.True(metadata.CreatedAt.Add(delay) <= metadata.LastUpdatedAt.AddSeconds(1)); + } + + [Fact] + public async Task LongTimer() + { + TaskName orchestratorName = nameof(SingleTimer); + TimeSpan delay = TimeSpan.FromSeconds(7); + TimeSpan timerInterval = TimeSpan.FromSeconds(3); + const int ExpectedTimers = 3; // two for 3 seconds and one for 1 second + + await using HostTestLifetime server = await this.StartWorkerAsync(b => + { + b.Configure(opt => opt.MaximumTimerInterval = timerInterval); + b.AddTasks(tasks => tasks.AddOrchestratorFunc( + orchestratorName, ctx => ctx.CreateTimer(delay, CancellationToken.None))); + }); + + string instanceId = await server.Client.ScheduleNewOrchestrationInstanceAsync(orchestratorName); + OrchestrationMetadata metadata = await server.Client.WaitForInstanceCompletionAsync(instanceId, this.TimeoutToken); + + Assert.NotNull(metadata); + Assert.Equal(instanceId, metadata.InstanceId); + Assert.Equal(OrchestrationRuntimeStatus.Completed, metadata.RuntimeStatus); + + // Verify that the delay actually happened + Assert.True(metadata.CreatedAt.Add(delay) <= metadata.LastUpdatedAt.AddSeconds(1)); + + // Verify that the correct number of timers were created + IReadOnlyCollection logs = this.GetLogs(); + int timersCreated = logs.Count(log => log.Message.Contains("CreateTimer")); + Assert.Equal(ExpectedTimers, timersCreated); + } + + [Fact] + public async Task IsReplaying() + { + TaskName orchestratorName = nameof(IsReplaying); + + await using HostTestLifetime server = await this.StartWorkerAsync(b => + { + b.AddTasks(tasks => tasks.AddOrchestratorFunc(orchestratorName, async ctx => + { + var list = new List { ctx.IsReplaying }; + await ctx.CreateTimer(TimeSpan.Zero, CancellationToken.None); + list.Add(ctx.IsReplaying); + await ctx.CreateTimer(TimeSpan.Zero, CancellationToken.None); + list.Add(ctx.IsReplaying); + return list; + })); + }); + + string instanceId = await server.Client.ScheduleNewOrchestrationInstanceAsync(orchestratorName); + OrchestrationMetadata metadata = await server.Client.WaitForInstanceCompletionAsync( + instanceId, getInputsAndOutputs: true, this.TimeoutToken); + Assert.Equal(OrchestrationRuntimeStatus.Completed, metadata.RuntimeStatus); + List? results = metadata.ReadOutputAs>(); + Assert.NotNull(results); + Assert.Equal(3, results!.Count); + Assert.True(results[0]); + Assert.True(results[1]); + Assert.False(results[2]); + } + + [Fact] + public async Task CurrentDateTimeUtc() + { + TaskName orchestratorName = nameof(CurrentDateTimeUtc); + TaskName echoActivityName = "Echo"; + + await using HostTestLifetime server = await this.StartWorkerAsync(b => + { + b.AddTasks(tasks => tasks + .AddOrchestratorFunc(orchestratorName, async ctx => + { + DateTime currentDate1 = ctx.CurrentUtcDateTime; + DateTime originalDate1 = await ctx.CallActivityAsync(echoActivityName, currentDate1); + if (currentDate1 != originalDate1) + { + return false; + } + + DateTime currentDate2 = ctx.CurrentUtcDateTime; + DateTime originalDate2 = await ctx.CallActivityAsync(echoActivityName, currentDate2); + if (currentDate2 != originalDate2) + { + return false; + } + + return currentDate1 != currentDate2; + }) + .AddActivityFunc(echoActivityName, (ctx, input) => input)); + }); + + string instanceId = await server.Client.ScheduleNewOrchestrationInstanceAsync(orchestratorName); + OrchestrationMetadata metadata = await server.Client.WaitForInstanceCompletionAsync( + instanceId, getInputsAndOutputs: true, this.TimeoutToken); + Assert.NotNull(metadata); + Assert.Equal(OrchestrationRuntimeStatus.Completed, metadata.RuntimeStatus); + Assert.True(metadata.ReadOutputAs()); + } + + [Fact] + public async Task SingleActivity() + { + TaskName orchestratorName = nameof(SingleActivity); + TaskName sayHelloActivityName = "SayHello"; + + await using HostTestLifetime server = await this.StartWorkerAsync(b => + { + b.AddTasks(tasks => tasks + .AddOrchestratorFunc( + orchestratorName, (ctx, input) => ctx.CallActivityAsync(sayHelloActivityName, input)) + .AddActivityFunc(sayHelloActivityName, (ctx, name) => $"Hello, {name}!")); + }); + + string instanceId = await server.Client.ScheduleNewOrchestrationInstanceAsync(orchestratorName, input: "World"); + OrchestrationMetadata metadata = await server.Client.WaitForInstanceCompletionAsync( + instanceId, getInputsAndOutputs: true, this.TimeoutToken); + Assert.NotNull(metadata); + Assert.Equal(OrchestrationRuntimeStatus.Completed, metadata.RuntimeStatus); + Assert.Equal("Hello, World!", metadata.ReadOutputAs()); + } + + [Fact] + public async Task SingleActivity_Async() + { + TaskName orchestratorName = nameof(SingleActivity); + TaskName sayHelloActivityName = "SayHello"; + + await using HostTestLifetime server = await this.StartWorkerAsync(b => + { + b.AddTasks(tasks => tasks + .AddOrchestratorFunc( + orchestratorName, (ctx, input) => ctx.CallActivityAsync(sayHelloActivityName, input)) + .AddActivityFunc( + sayHelloActivityName, async (ctx, name) => await Task.FromResult($"Hello, {name}!"))); + }); + + string instanceId = await server.Client.ScheduleNewOrchestrationInstanceAsync(orchestratorName, input: "World"); + OrchestrationMetadata metadata = await server.Client.WaitForInstanceCompletionAsync( + instanceId, getInputsAndOutputs: true, this.TimeoutToken); + Assert.NotNull(metadata); + Assert.Equal(OrchestrationRuntimeStatus.Completed, metadata.RuntimeStatus); + Assert.Equal("Hello, World!", metadata.ReadOutputAs()); + } + + [Fact] + public async Task ActivityChain() + { + TaskName orchestratorName = nameof(ActivityChain); + TaskName plusOneActivityName = "PlusOne"; + + await using HostTestLifetime server = await this.StartWorkerAsync(b => + { + b.AddTasks(tasks => tasks + .AddOrchestratorFunc(orchestratorName, async ctx => + { + int value = 0; + for (int i = 0; i < 10; i++) + { + value = await ctx.CallActivityAsync(plusOneActivityName, value); + } + + return value; + }) + .AddActivityFunc(plusOneActivityName, (ctx, input) => input + 1)); + }); + + string instanceId = await server.Client.ScheduleNewOrchestrationInstanceAsync(orchestratorName, input: "World"); + OrchestrationMetadata metadata = await server.Client.WaitForInstanceCompletionAsync( + instanceId, getInputsAndOutputs: true, this.TimeoutToken); + Assert.NotNull(metadata); + Assert.Equal(OrchestrationRuntimeStatus.Completed, metadata.RuntimeStatus); + Assert.Equal(10, metadata.ReadOutputAs()); + } + + [Fact] + public async Task ActivityFanOut() + { + TaskName orchestratorName = nameof(ActivityFanOut); + TaskName toStringActivity = "ToString"; + + await using HostTestLifetime server = await this.StartWorkerAsync(b => + { + b.AddTasks(tasks => tasks + .AddOrchestratorFunc(orchestratorName, async ctx => + { + var tasks = new List>(); + for (int i = 0; i < 10; i++) + { + tasks.Add(ctx.CallActivityAsync(toStringActivity, i)); + } + + string[] results = await Task.WhenAll(tasks); + Array.Sort(results); + Array.Reverse(results); + return results; + }) + .AddActivityFunc(toStringActivity, (ctx, input) => input.ToString())); + }); + + string instanceId = await server.Client.ScheduleNewOrchestrationInstanceAsync(orchestratorName); + OrchestrationMetadata metadata = await server.Client.WaitForInstanceCompletionAsync( + instanceId, getInputsAndOutputs: true, this.TimeoutToken); + Assert.NotNull(metadata); + Assert.Equal(OrchestrationRuntimeStatus.Completed, metadata.RuntimeStatus); + + string[] expected = new[] { "9", "8", "7", "6", "5", "4", "3", "2", "1", "0" }; + Assert.Equal(expected, metadata.ReadOutputAs()); + } + + [Theory] + [InlineData(1)] + [InlineData(100)] + public async Task ExternalEvents(int eventCount) + { + TaskName orchestratorName = nameof(ExternalEvents); + await using HostTestLifetime server = await this.StartWorkerAsync(b => + { + b.AddTasks(tasks => tasks.AddOrchestratorFunc(orchestratorName, async ctx => + { + List events = new(); + for (int i = 0; i < eventCount; i++) + { + events.Add(await ctx.WaitForExternalEvent($"Event{i}")); + } + + return events; + })); + }); + + string instanceId = await server.Client.ScheduleNewOrchestrationInstanceAsync(orchestratorName); + + // To ensure consistency, wait for the instance to start before sending the events + OrchestrationMetadata metadata = await server.Client.WaitForInstanceStartAsync( + instanceId, + this.TimeoutToken); + + // Send events one-at-a-time to that we can better ensure ordered processing. + for (int i = 0; i < eventCount; i++) + { + await server.Client.RaiseEventAsync(metadata.InstanceId, $"Event{i}", eventPayload: i); + } + + // Once the orchestration receives all the events it is expecting, it should complete. + metadata = await server.Client.WaitForInstanceCompletionAsync( + instanceId, getInputsAndOutputs: true, this.TimeoutToken); + Assert.NotNull(metadata); + Assert.Equal(OrchestrationRuntimeStatus.Completed, metadata.RuntimeStatus); + + int[] expected = Enumerable.Range(0, eventCount).ToArray(); + Assert.Equal(expected, metadata.ReadOutputAs()); + } + + [Theory] + [InlineData(1)] + [InlineData(5)] + public async Task ExternalEventsInParallel(int eventCount) + { + TaskName orchestratorName = nameof(ExternalEvents); + await using HostTestLifetime server = await this.StartWorkerAsync(b => + { + b.AddTasks(tasks => tasks.AddOrchestratorFunc(orchestratorName, async ctx => + { + List> events = new(); + for (int i = 0; i < eventCount; i++) + { + events.Add(ctx.WaitForExternalEvent("Event")); + } + + return await Task.WhenAll(events); + })); + }); + + string instanceId = await server.Client.ScheduleNewOrchestrationInstanceAsync(orchestratorName); + + // To ensure consistency, wait for the instance to start before sending the events + OrchestrationMetadata metadata = await server.Client.WaitForInstanceStartAsync( + instanceId, + this.TimeoutToken); + + // Send events one-at-a-time to that we can better ensure ordered processing. + for (int i = 0; i < eventCount; i++) + { + await server.Client.RaiseEventAsync(metadata.InstanceId, "Event", eventPayload: i); + } + + // Once the orchestration receives all the events it is expecting, it should complete. + metadata = await server.Client.WaitForInstanceCompletionAsync( + instanceId, getInputsAndOutputs: true); + Assert.NotNull(metadata); + Assert.Equal(OrchestrationRuntimeStatus.Completed, metadata.RuntimeStatus); + + int[] expected = Enumerable.Range(0, eventCount).ToArray(); + Assert.Equal(expected, metadata.ReadOutputAs()); + } + + [Fact] + public async Task Termination() + { + TaskName orchestrationName = nameof(Termination); + await using HostTestLifetime server = await this.StartWorkerAsync(b => + { + b.AddTasks(tasks => tasks.AddOrchestratorFunc( + orchestrationName, ctx => ctx.CreateTimer(TimeSpan.FromSeconds(3), CancellationToken.None))); + }); + + string instanceId = await server.Client.ScheduleNewOrchestrationInstanceAsync(orchestrationName); + OrchestrationMetadata metadata = await server.Client.WaitForInstanceStartAsync(instanceId, this.TimeoutToken); + + var expectedOutput = new { quote = "I'll be back." }; + await server.Client.TerminateInstanceAsync(instanceId, expectedOutput); + + metadata = await server.Client.WaitForInstanceCompletionAsync( + instanceId, getInputsAndOutputs: true, this.TimeoutToken); + Assert.NotNull(metadata); + Assert.Equal(instanceId, metadata.InstanceId); + Assert.Equal(OrchestrationRuntimeStatus.Terminated, metadata.RuntimeStatus); + + JsonElement actualOutput = metadata.ReadOutputAs(); + string? actualQuote = actualOutput.GetProperty("quote").GetString(); + Assert.NotNull(actualQuote); + Assert.Equal(expectedOutput.quote, actualQuote); + } + + [Fact] + public async Task ContinueAsNew() + { + TaskName orchestratorName = nameof(ContinueAsNew); + + await using HostTestLifetime server = await this.StartWorkerAsync(b => + { + b.AddTasks(tasks => tasks.AddOrchestratorFunc(orchestratorName, async (ctx, input) => + { + if (input < 10) + { + await ctx.CreateTimer(TimeSpan.Zero, CancellationToken.None); + ctx.ContinueAsNew(input + 1); + } + + return input; + })); + }); + + string instanceId = await server.Client.ScheduleNewOrchestrationInstanceAsync(orchestratorName); + OrchestrationMetadata metadata = await server.Client.WaitForInstanceCompletionAsync( + instanceId, getInputsAndOutputs: true, this.TimeoutToken); + Assert.NotNull(metadata); + Assert.Equal(OrchestrationRuntimeStatus.Completed, metadata.RuntimeStatus); + Assert.Equal(10, metadata.ReadOutputAs()); + } + + [Fact] + public async Task SubOrchestration() + { + TaskName orchestratorName = nameof(SubOrchestration); + + await using HostTestLifetime server = await this.StartWorkerAsync(b => + { + b.AddTasks(tasks => tasks.AddOrchestratorFunc(orchestratorName, async (ctx, input) => + { + int result = 5; + if (input < 3) + { + // recursively call this same orchestrator + result += await ctx.CallSubOrchestratorAsync(orchestratorName, input: input + 1); + } + + return result; + })); + }); + + string instanceId = await server.Client.ScheduleNewOrchestrationInstanceAsync(orchestratorName, input: 1); + OrchestrationMetadata metadata = await server.Client.WaitForInstanceCompletionAsync( + instanceId, getInputsAndOutputs: true, this.TimeoutToken); + Assert.NotNull(metadata); + Assert.Equal(OrchestrationRuntimeStatus.Completed, metadata.RuntimeStatus); + Assert.Equal(15, metadata.ReadOutputAs()); + } + + [Fact] + public async Task SetCustomStatus() + { + TaskName orchestratorName = nameof(SetCustomStatus); + await using HostTestLifetime server = await this.StartWorkerAsync(b => + { + b.AddTasks(tasks => tasks.AddOrchestratorFunc(orchestratorName, async ctx => + { + ctx.SetCustomStatus("Started!"); + + object customStatus = await ctx.WaitForExternalEvent("StatusEvent"); + ctx.SetCustomStatus(customStatus); + })); + }); + + string instanceId = await server.Client.ScheduleNewOrchestrationInstanceAsync(orchestratorName); + + // To ensure consistency, wait for the instance to start before sending the events + OrchestrationMetadata metadata = await server.Client.WaitForInstanceStartAsync( + instanceId, getInputsAndOutputs: true, this.TimeoutToken); + Assert.NotNull(metadata); + Assert.Equal("Started!", metadata.ReadCustomStatusAs()); + + // Send a tuple payload, which will be used as the custom status + (string, int) eventPayload = ("Hello", 42); + await server.Client.RaiseEventAsync( + metadata.InstanceId, + eventName: "StatusEvent", + eventPayload); + + // Once the orchestration receives all the events it is expecting, it should complete. + metadata = await server.Client.WaitForInstanceCompletionAsync( + instanceId, getInputsAndOutputs: true, this.TimeoutToken); + Assert.NotNull(metadata); + Assert.Equal(OrchestrationRuntimeStatus.Completed, metadata.RuntimeStatus); + Assert.Equal(eventPayload, metadata.ReadCustomStatusAs<(string, int)>()); + } + + [Fact] + public async Task NewGuidTest() + { + TaskName orchestratorName = nameof(ContinueAsNew); + TaskName echoActivityName = "Echo"; + + await using HostTestLifetime server = await this.StartWorkerAsync(b => + { + b.AddTasks(tasks => tasks + .AddOrchestratorFunc(orchestratorName, async (ctx, input) => + { + // Test 1: Ensure two consecutively created GUIDs are unique + Guid currentGuid0 = ctx.NewGuid(); + Guid currentGuid1 = ctx.NewGuid(); + if (currentGuid0 == currentGuid1) + { + return false; + } + + // Test 2: Ensure that the same GUID values are created on each replay + Guid originalGuid1 = await ctx.CallActivityAsync(echoActivityName, currentGuid1); + if (currentGuid1 != originalGuid1) + { + return false; + } + + // Test 3: Ensure that the same GUID values are created on each replay even after an await + Guid currentGuid2 = ctx.NewGuid(); + Guid originalGuid2 = await ctx.CallActivityAsync(echoActivityName, currentGuid2); + if (currentGuid2 != originalGuid2) + { + return false; + } + + // Test 4: Finish confirming that every generated GUID is unique + return currentGuid1 != currentGuid2; + }) + .AddActivityFunc(echoActivityName, (ctx, input) => input)); + }); + + string instanceId = await server.Client.ScheduleNewOrchestrationInstanceAsync(orchestratorName); + OrchestrationMetadata metadata = await server.Client.WaitForInstanceCompletionAsync( + instanceId, getInputsAndOutputs: true, this.TimeoutToken); + Assert.NotNull(metadata); + Assert.Equal(OrchestrationRuntimeStatus.Completed, metadata.RuntimeStatus); + Assert.True(metadata.ReadOutputAs()); + } + + [Fact] + public async Task SpecialSerialization() + { + await using HostTestLifetime server = await this.StartWorkerAsync(b => + { + b.AddTasks(tasks => tasks + .AddOrchestratorFunc("SpecialSerialization_Orchestration", (ctx, input) => + { + if (input is null) + { + throw new ArgumentNullException(nameof(input)); + } + + return ctx.CallActivityAsync("SpecialSerialization_Activity", input); + }) + .AddActivityFunc("SpecialSerialization_Activity", (ctx, input) => + { + if (input is not null) + { + input["newProperty"] = "new value"; + } + + return Task.FromResult(input); + })); + }); + + JsonNode input = new JsonObject() { ["originalProperty"] = "original value" }; + string instanceId = await server.Client.ScheduleNewOrchestrationInstanceAsync( + "SpecialSerialization_Orchestration", input: input); + OrchestrationMetadata result = await server.Client.WaitForInstanceCompletionAsync( + instanceId, getInputsAndOutputs: true, this.TimeoutToken); + JsonNode? output = result.ReadOutputAs(); + + Assert.NotNull(output); + Assert.Equal("original value", output?["originalProperty"]?.ToString()); + Assert.Equal("new value", output?["newProperty"]?.ToString()); + } + + // TODO: Additional versioning tests + [Fact] + public async Task OrchestrationVersionPassedThroughContext() + { + var version = "0.1"; + await using HostTestLifetime server = await this.StartWorkerAsync(b => + { + b.AddTasks(tasks => tasks + .AddOrchestratorFunc("Versioned_Orchestration", (ctx, input) => + { + return ctx.CallActivityAsync("Versioned_Activity", ctx.Version); + }) + .AddActivityFunc("Versioned_Activity", (ctx, input) => + { + return $"Orchestration version: {input}"; + })); + }, c => + { + c.UseDefaultVersion(version); + }); + + var instanceId = await server.Client.ScheduleNewOrchestrationInstanceAsync("Versioned_Orchestration", input: string.Empty); + var result = await server.Client.WaitForInstanceCompletionAsync(instanceId, getInputsAndOutputs: true, this.TimeoutToken); + var output = result.ReadOutputAs(); + + Assert.NotNull(output); + Assert.Equal(output, $"Orchestration version: {version}"); + + } + + // TODO: Test for multiple external events with the same name + // TODO: Test for ContinueAsNew with external events that carry over + // TODO: Test for catching activity exceptions of specific types +}