diff --git a/Test/DurableTask.Core.Tests/TaskOrchestrationGetStatusAsyncTests.cs b/Test/DurableTask.Core.Tests/TaskOrchestrationGetStatusAsyncTests.cs new file mode 100644 index 000000000..8cf4ca65c --- /dev/null +++ b/Test/DurableTask.Core.Tests/TaskOrchestrationGetStatusAsyncTests.cs @@ -0,0 +1,139 @@ +// ---------------------------------------------------------------------------------- +// Copyright Microsoft Corporation +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ---------------------------------------------------------------------------------- + +namespace DurableTask.Core.Tests +{ + using System; + using System.Threading.Tasks; + using DurableTask.Core.Serializing; + using Microsoft.VisualStudio.TestTools.UnitTesting; + + [TestClass] + public class TaskOrchestrationGetStatusAsyncTests + { + [TestMethod] + public async Task GetStatusAsync_ReturnsSerializedStatus_FromTypedOrchestration() + { + var orchestration = new SampleTypedOrchestration(); + + string syncStatus = orchestration.GetStatus(); + string asyncStatus = await orchestration.GetStatusAsync(); + + Assert.AreEqual(syncStatus, asyncStatus, "Async status should equal sync status"); + + var deserialized = (Status)orchestration.DataConverter.Deserialize(asyncStatus, typeof(Status)); + Assert.AreEqual("running", deserialized.State); + Assert.AreEqual(42, deserialized.Progress); + } + + [TestMethod] + public async Task GetStatusAsync_ReturnsNull_WhenTypedOnGetStatusReturnsNull() + { + var orchestration = new NullStatusTypedOrchestration(); + + string syncStatus = orchestration.GetStatus(); + string asyncStatus = await orchestration.GetStatusAsync(); + + Assert.IsNull(syncStatus, "Sync status should be null when OnGetStatus returns null"); + Assert.IsNull(asyncStatus, "Async status should be null when OnGetStatus returns null"); + } + + [TestMethod] + public async Task GetStatusAsync_ReturnsSameAsGetStatus_ForNonGenericOrchestration() + { + var orchestration = new NonGenericOrchestration(); + + string syncStatus = orchestration.GetStatus(); + string asyncStatus = await orchestration.GetStatusAsync(); + + Assert.AreEqual("OK", syncStatus); + Assert.AreEqual(syncStatus, asyncStatus); + } + + [TestMethod] + public async Task GetStatusAsync_PropagatesException_WhenGetStatusThrows() + { + var orchestration = new ThrowingStatusOrchestration(); + + await Assert.ThrowsExceptionAsync(async () => await orchestration.GetStatusAsync()); + } + + class SampleTypedOrchestration : TaskOrchestration + { + public override Task RunTask(OrchestrationContext context, string input) + { + return Task.FromResult("done"); + } + + public override Status OnGetStatus() + { + return new Status { State = "running", Progress = 42 }; + } + } + + class NullStatusTypedOrchestration : TaskOrchestration + { + public override Task RunTask(OrchestrationContext context, string input) + { + return Task.FromResult("done"); + } + + public override Status OnGetStatus() + { + return null; + } + } + + class NonGenericOrchestration : TaskOrchestration + { + public override Task Execute(OrchestrationContext context, string input) + { + return Task.FromResult("done"); + } + + public override void RaiseEvent(OrchestrationContext context, string name, string input) + { + } + + public override string GetStatus() + { + return "OK"; + } + } + + class ThrowingStatusOrchestration : TaskOrchestration + { + public override Task Execute(OrchestrationContext context, string input) + { + return Task.FromResult("done"); + } + + public override void RaiseEvent(OrchestrationContext context, string name, string input) + { + } + + public override string GetStatus() + { + throw new InvalidOperationException("boom"); + } + } + + class Status + { + public string State { get; set; } + + public int Progress { get; set; } + } + } +} + diff --git a/src/DurableTask.Core/TaskOrchestration.cs b/src/DurableTask.Core/TaskOrchestration.cs index c198c0855..8e6849556 100644 --- a/src/DurableTask.Core/TaskOrchestration.cs +++ b/src/DurableTask.Core/TaskOrchestration.cs @@ -11,6 +11,7 @@ // limitations under the License. // ---------------------------------------------------------------------------------- +#nullable enable namespace DurableTask.Core { using System; @@ -48,6 +49,28 @@ public abstract class TaskOrchestration /// /// The status public abstract string GetStatus(); + + /// + /// Gets the current status of the orchestration + /// + /// The status + public virtual async Task GetStatusAsync() + { + return await Task.FromResult(GetStatus()); + } + + /// + /// Raises an event in the orchestration asynchronously + /// + /// The orchestration context + /// Name for this event to be passed to the OnEvent handler + /// The serialized input + /// A task representing the asynchronous operation. + public virtual Task RaiseEventAsync(OrchestrationContext context, string name, string input) + { + RaiseEvent(context, name, input); + return Task.CompletedTask; + } } /// @@ -97,8 +120,8 @@ public override async Task Execute(OrchestrationContext context, string } catch (Exception e) when (!Utils.IsFatal(e) && !Utils.IsExecutionAborting(e)) { - string details = null; - FailureDetails failureDetails = null; + string? details = null; + FailureDetails? failureDetails = null; if (context.ErrorPropagationMode == ErrorPropagationMode.SerializeExceptions) { details = Utils.SerializeCause(e, DataConverter); @@ -164,7 +187,7 @@ public virtual void OnEvent(OrchestrationContext context, string name, TEvent in /// The typed status public virtual TStatus OnGetStatus() { - return default(TStatus); + return default!; } } } \ No newline at end of file