diff --git a/Microsoft.DurableTask.sln b/Microsoft.DurableTask.sln index 53f5f4bcf..8fc16da1f 100644 --- a/Microsoft.DurableTask.sln +++ b/Microsoft.DurableTask.sln @@ -44,6 +44,9 @@ EndProject Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Grpc.IntegrationTests", "test\Grpc.IntegrationTests\Grpc.IntegrationTests.csproj", "{7825CFEA-2923-4C44-BA36-8E16259B9777}" EndProject Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "samples", "samples", "{EFF7632B-821E-4CFC-B4A0-ED4B24296B17}" + ProjectSection(SolutionItems) = preProject + Directory.Packages.props = Directory.Packages.props + EndProjectSection EndProject Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "AzureFunctionsApp", "samples\AzureFunctionsApp\AzureFunctionsApp.csproj", "{848FC5BD-4A99-4A0D-9099-9597700AA7BC}" EndProject @@ -101,6 +104,8 @@ Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "InProcessTestHost", "src\In EndProject Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "InProcessTestHost.Tests", "test\InProcessTestHost.Tests\InProcessTestHost.Tests.csproj", "{B894780C-338F-475E-8E84-56AFA8197A06}" EndProject +Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "DtsPortableSdkEntityTests", "samples\DtsPortableSdkEntityTests\DtsPortableSdkEntityTests.csproj", "{B2BAE32E-C558-BB99-DC82-282613525497}" +EndProject Global GlobalSection(SolutionConfigurationPlatforms) = preSolution Debug|Any CPU = Debug|Any CPU @@ -271,6 +276,10 @@ Global {FE1DA748-D6DB-E168-BC42-6DBBCEAF229C}.Debug|Any CPU.Build.0 = Debug|Any CPU {FE1DA748-D6DB-E168-BC42-6DBBCEAF229C}.Release|Any CPU.ActiveCfg = Release|Any CPU {FE1DA748-D6DB-E168-BC42-6DBBCEAF229C}.Release|Any CPU.Build.0 = Release|Any CPU + {B2BAE32E-C558-BB99-DC82-282613525497}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {B2BAE32E-C558-BB99-DC82-282613525497}.Debug|Any CPU.Build.0 = Debug|Any CPU + {B2BAE32E-C558-BB99-DC82-282613525497}.Release|Any CPU.ActiveCfg = Release|Any CPU + {B2BAE32E-C558-BB99-DC82-282613525497}.Release|Any CPU.Build.0 = Release|Any CPU EndGlobalSection GlobalSection(SolutionProperties) = preSolution HideSolutionNode = FALSE @@ -321,6 +330,7 @@ Global {B894780C-338F-475E-8E84-56AFA8197A06} = {E5637F81-2FB9-4CD7-900D-455363B142A7} {6EB9D002-62C8-D6C1-62A8-14C54CA6DBBC} = {EFF7632B-821E-4CFC-B4A0-ED4B24296B17} {FE1DA748-D6DB-E168-BC42-6DBBCEAF229C} = {8AFC9781-F6F1-4696-BB4A-9ED7CA9D612B} + {B2BAE32E-C558-BB99-DC82-282613525497} = {EFF7632B-821E-4CFC-B4A0-ED4B24296B17} EndGlobalSection GlobalSection(ExtensibilityGlobals) = postSolution SolutionGuid = {AB41CB55-35EA-4986-A522-387AB3402E71} diff --git a/samples/DtsPortableSdkEntityTests/DtsPortableSdkEntityTests.csproj b/samples/DtsPortableSdkEntityTests/DtsPortableSdkEntityTests.csproj new file mode 100644 index 000000000..2efc614a2 --- /dev/null +++ b/samples/DtsPortableSdkEntityTests/DtsPortableSdkEntityTests.csproj @@ -0,0 +1,23 @@ + + + + net8.0 + enable + enable + + + + + + + + + false + + + + + + + + diff --git a/samples/DtsPortableSdkEntityTests/Program.cs b/samples/DtsPortableSdkEntityTests/Program.cs new file mode 100644 index 000000000..375bed00f --- /dev/null +++ b/samples/DtsPortableSdkEntityTests/Program.cs @@ -0,0 +1,71 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using System.Text.Json.Serialization; +using Azure.Core; +using Azure.Identity; +using DtsPortableSdkEntityTests; +using DurableTask.Core.Entities; +using Microsoft.DurableTask; +using Microsoft.DurableTask.Client; +using Microsoft.DurableTask.Client.AzureManaged; +using Microsoft.DurableTask.Worker; +using Microsoft.DurableTask.Worker.AzureManaged; + +WebApplicationBuilder builder = WebApplication.CreateBuilder(args); + +string connectionString = builder.Configuration["DTS_CONNECTION_STRING"] ?? + // By default, use the connection string for the local development emulator + "Endpoint=http://localhost:8080;TaskHub=default;Authentication=None"; + +// Add all the generated orchestrations and activities automatically +builder.Services.AddDurableTaskWorker(builder => +{ + builder.AddTasks(r => + { + // TODO consider using source generator + + // register all orchestrations and activities used in the tests + HashSet registeredTestTypes = []; + foreach(var test in All.GetAllTests()) + { + if (!registeredTestTypes.Contains(test.GetType())) + { + test.Register(r, builder.Services); + registeredTestTypes.Add(test.GetType()); + } + } + + // register all entities + BatchEntity.Register(r); + Counter.Register(r); + FaultyEntity.Register(r); + Launcher.Register(r); + Relay.Register(r); + SchedulerEntity.Register(r); + SelfSchedulingEntity.Register(r); + StringStore.Register(r); + StringStore2.Register(r); + StringStore3.Register(r); + }); + + builder.UseDurableTaskScheduler(connectionString); +}); + +// Register the client, which can be used to start orchestrations +builder.Services.AddDurableTaskClient(builder => +{ + builder.UseDurableTaskScheduler(connectionString); +}); + +// Configure the HTTP request pipeline +builder.Services.AddControllers().AddJsonOptions(options => +{ + options.JsonSerializerOptions.Converters.Add(new JsonStringEnumConverter()); + options.JsonSerializerOptions.DefaultIgnoreCondition = JsonIgnoreCondition.WhenWritingNull; +}); + +// The actual listen URL can be configured in environment variables named "ASPNETCORE_URLS" or "ASPNETCORE_URLS_HTTPS" +WebApplication app = builder.Build(); +app.MapControllers(); +app.Run(); diff --git a/samples/DtsPortableSdkEntityTests/Properties/launchSettings.json b/samples/DtsPortableSdkEntityTests/Properties/launchSettings.json new file mode 100644 index 000000000..5acdcbf63 --- /dev/null +++ b/samples/DtsPortableSdkEntityTests/Properties/launchSettings.json @@ -0,0 +1,23 @@ +{ + "$schema": "https://json.schemastore.org/launchsettings.json", + "profiles": { + "http": { + "commandName": "Project", + "dotnetRunMessages": true, + "launchBrowser": false, + "applicationUrl": "http://localhost:5203", + "environmentVariables": { + "ASPNETCORE_ENVIRONMENT": "Development" + } + }, + "https": { + "commandName": "Project", + "dotnetRunMessages": true, + "launchBrowser": false, + "applicationUrl": "https://localhost:7225;http://localhost:5203", + "environmentVariables": { + "ASPNETCORE_ENVIRONMENT": "Development" + } + } + } +} \ No newline at end of file diff --git a/samples/DtsPortableSdkEntityTests/appsettings.Development.json b/samples/DtsPortableSdkEntityTests/appsettings.Development.json new file mode 100644 index 000000000..b6f634e59 --- /dev/null +++ b/samples/DtsPortableSdkEntityTests/appsettings.Development.json @@ -0,0 +1,9 @@ +{ + "DetailedErrors": true, + "Logging": { + "LogLevel": { + "Default": "Information", + "Microsoft.AspNetCore": "Warning" + } + } +} \ No newline at end of file diff --git a/samples/DtsPortableSdkEntityTests/appsettings.json b/samples/DtsPortableSdkEntityTests/appsettings.json new file mode 100644 index 000000000..ec04bc120 --- /dev/null +++ b/samples/DtsPortableSdkEntityTests/appsettings.json @@ -0,0 +1,9 @@ +{ + "Logging": { + "LogLevel": { + "Default": "Information", + "Microsoft.AspNetCore": "Warning" + } + }, + "AllowedHosts": "*" +} \ No newline at end of file diff --git a/samples/DtsPortableSdkEntityTests/common/ProblematicObject.cs b/samples/DtsPortableSdkEntityTests/common/ProblematicObject.cs new file mode 100644 index 000000000..be8fa030e --- /dev/null +++ b/samples/DtsPortableSdkEntityTests/common/ProblematicObject.cs @@ -0,0 +1,73 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Runtime.Serialization; +using System.Text; +using System.Text.Json; +using System.Text.Json.Serialization; +using System.Threading.Tasks; +using Azure.Core.Serialization; + +namespace DtsPortableSdkEntityTests +{ + internal static class CustomSerialization + { + public static ProblematicObject CreateUnserializableObject() + { + return new ProblematicObject(serializable: false, deserializable: false); + } + + public static ProblematicObject CreateUndeserializableObject() + { + return new ProblematicObject(serializable: true, deserializable: false); + } + + /// + /// An object for which we can inject errors on serialization or deserialization, to test + // how those are handled by the framework. + /// + public class ProblematicObject + { + public ProblematicObject(bool serializable = true, bool deserializable = true) + { + this.Serializable = serializable; + this.Deserializable = deserializable; + } + + public bool Serializable { get; set; } + + public bool Deserializable { get; set; } + } + + public class ProblematicObjectJsonConverter : JsonConverter + { + public override ProblematicObject Read( + ref Utf8JsonReader reader, + Type typeToConvert, + JsonSerializerOptions options) + { + bool deserializable = reader.GetBoolean(); + if (!deserializable) + { + throw new JsonException("problematic object: is not deserializable"); + } + return new ProblematicObject(serializable: true, deserializable: true); + } + + public override void Write( + Utf8JsonWriter writer, + ProblematicObject value, + JsonSerializerOptions options) + { + if (!value.Serializable) + { + throw new JsonException("problematic object: is not serializable"); + } + writer.WriteBooleanValue(value.Deserializable); + } + } + } +} diff --git a/samples/DtsPortableSdkEntityTests/common/Test.cs b/samples/DtsPortableSdkEntityTests/common/Test.cs new file mode 100644 index 000000000..a6cd83da6 --- /dev/null +++ b/samples/DtsPortableSdkEntityTests/common/Test.cs @@ -0,0 +1,24 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Threading.Tasks; +using Microsoft.DurableTask; + +namespace DtsPortableSdkEntityTests; + +internal abstract class Test +{ + public virtual string Name => this.GetType().Name; + + public abstract Task RunAsync(TestContext context); + + public virtual TimeSpan Timeout => TimeSpan.FromSeconds(30); + + public virtual void Register(DurableTaskRegistry registry, IServiceCollection services) + { + } +} diff --git a/samples/DtsPortableSdkEntityTests/common/TestContext.cs b/samples/DtsPortableSdkEntityTests/common/TestContext.cs new file mode 100644 index 000000000..86751b25b --- /dev/null +++ b/samples/DtsPortableSdkEntityTests/common/TestContext.cs @@ -0,0 +1,33 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.Linq; +using System.Text; +using System.Threading.Tasks; +using Microsoft.DurableTask.Client; +using Microsoft.DurableTask.Client.Entities; +using Microsoft.DurableTask.Entities; +using Microsoft.Extensions.Logging; + +namespace DtsPortableSdkEntityTests; + +internal class TestContext +{ + public TestContext(DurableTaskClient client, ILogger logger, CancellationToken cancellationToken) + { + this.Client = client; + this.Logger = logger; + this.CancellationToken = cancellationToken; + } + + public DurableTaskClient Client { get; } + + public ILogger Logger { get; } + + public CancellationToken CancellationToken { get; set; } + + public bool BackendSupportsImplicitEntityDeletion { get; set; } = true; // false for Azure Storage, true for Netherite, MSSQL, and DTS +} diff --git a/samples/DtsPortableSdkEntityTests/common/TestContextExtensions.cs b/samples/DtsPortableSdkEntityTests/common/TestContextExtensions.cs new file mode 100644 index 000000000..0ebc97a9f --- /dev/null +++ b/samples/DtsPortableSdkEntityTests/common/TestContextExtensions.cs @@ -0,0 +1,77 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.Linq; +using System.Text; +using System.Threading.Tasks; +using Microsoft.DurableTask.Client.Entities; +using Microsoft.DurableTask.Entities; +using Microsoft.Extensions.Logging; + +namespace DtsPortableSdkEntityTests; + +internal static class TestContextExtensions +{ + public static async Task WaitForEntityStateAsync( + this TestContext context, + EntityInstanceId entityInstanceId, + TimeSpan? timeout = null, + Func? describeWhatWeAreWaitingFor = null) + { + if (timeout == null) + { + timeout = Debugger.IsAttached ? TimeSpan.FromMinutes(5) : TimeSpan.FromSeconds(30); + } + + Stopwatch sw = Stopwatch.StartNew(); + + EntityMetadata? response; + + do + { + response = await context.Client.Entities.GetEntityAsync(entityInstanceId, includeState: true); + + if (response != null) + { + if (describeWhatWeAreWaitingFor == null) + { + break; + } + else + { + var waitForResult = describeWhatWeAreWaitingFor(response.State.ReadAs()); + + if (string.IsNullOrEmpty(waitForResult)) + { + break; + } + else + { + context.Logger.LogInformation($"Waiting for {entityInstanceId} : {waitForResult}"); + } + } + } + else + { + context.Logger.LogInformation($"Waiting for {entityInstanceId} to have state."); + } + + await Task.Delay(TimeSpan.FromMilliseconds(100)); + } + while (sw.Elapsed < timeout); + + if (response != null) + { + string serializedState = response.State.Value; + context.Logger.LogInformation($"Found state: {serializedState}"); + return response.State.ReadAs(); + } + else + { + throw new TimeoutException($"Durable entity '{entityInstanceId}' still doesn't have any state!"); + } + } +} diff --git a/samples/DtsPortableSdkEntityTests/common/TestRunner.cs b/samples/DtsPortableSdkEntityTests/common/TestRunner.cs new file mode 100644 index 000000000..428f0004e --- /dev/null +++ b/samples/DtsPortableSdkEntityTests/common/TestRunner.cs @@ -0,0 +1,60 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using System; +using System.Collections.Generic; +using System.Diagnostics; +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the MIT License. See License.txt in the project root for license information. + +using System.Linq; +using System.Text; +using System.Text.RegularExpressions; +using System.Threading.Tasks; +using Microsoft.Extensions.Logging; + +namespace DtsPortableSdkEntityTests; + +internal static class TestRunner +{ + public static async Task RunAsync(TestContext context, string? prefix = null, bool listOnly = false) + { + var output = new StringBuilder(); + + foreach (var test in All.GetAllTests()) + { + if (prefix == null || test.Name.ToLowerInvariant().StartsWith(prefix.ToLowerInvariant())) + { + if (listOnly) + { + output.AppendLine(test.Name); + } + else + { + context.Logger.LogWarning("------------ starting {testName}", test.Name); + + // if debugging, time out after 60m + // otherwise, time out either when the http request times out or when the individual test time limit is exceeded + using CancellationTokenSource cancellationTokenSource + = Debugger.IsAttached ? new() : CancellationTokenSource.CreateLinkedTokenSource(context.CancellationToken); + cancellationTokenSource.CancelAfter(Debugger.IsAttached ? TimeSpan.FromMinutes(60) : test.Timeout); + context.CancellationToken = cancellationTokenSource.Token; + + try + { + await test.RunAsync(context); + output.AppendLine($"PASSED {test.Name}"); + } + catch (Exception ex) + { + context.Logger.LogError(ex, "test {testName} failed", test.Name); + output.AppendLine($"FAILED {test.Name} {ex.ToString()}"); + break; + } + } + } + } + + return output.ToString(); + } +} diff --git a/samples/DtsPortableSdkEntityTests/common/Utils.cs b/samples/DtsPortableSdkEntityTests/common/Utils.cs new file mode 100644 index 000000000..4da818306 --- /dev/null +++ b/samples/DtsPortableSdkEntityTests/common/Utils.cs @@ -0,0 +1,41 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +namespace DtsPortableSdkEntityTests; + +static class Utils +{ + public static async Task ParallelForEachAsync(this IEnumerable items, int maxConcurrency, Func action) + { + List tasks; + if (items is ICollection itemCollection) + { + tasks = new List(itemCollection.Count); + } + else + { + tasks = []; + } + + using SemaphoreSlim semaphore = new(maxConcurrency); + foreach (T item in items) + { + tasks.Add(InvokeThrottledAction(item, action, semaphore)); + } + + await Task.WhenAll(tasks); + } + + static async Task InvokeThrottledAction(T item, Func action, SemaphoreSlim semaphore) + { + await semaphore.WaitAsync(); + try + { + await action(item); + } + finally + { + semaphore.Release(); + } + } +} diff --git a/samples/DtsPortableSdkEntityTests/controllers/BenchmarksController.cs b/samples/DtsPortableSdkEntityTests/controllers/BenchmarksController.cs new file mode 100644 index 000000000..47075fa57 --- /dev/null +++ b/samples/DtsPortableSdkEntityTests/controllers/BenchmarksController.cs @@ -0,0 +1,110 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using System.Diagnostics; +using System.Security.Cryptography; +using DurableTask.Core.Entities; +using Microsoft.AspNetCore.Mvc; +using Microsoft.DurableTask; +using Microsoft.DurableTask.Client; +using Microsoft.DurableTask.Client.Entities; +using Microsoft.DurableTask.Entities; + +namespace DtsPortableSdkEntityTests; + + +[Route("benchmarks")] +[ApiController] +public class BenchmarksController( + DurableTaskClient durableTaskClient, + ILogger logger) : ControllerBase +{ + readonly DurableTaskClient durableTaskClient = durableTaskClient; + readonly ILogger logger = logger; + + // we are planning to create some benchmarks here at some point but for now these are just very basic entity tests + // that allow us to read/update/delete a counter entity via a simple REST-like api + + // POST http://localhost:5008/benchmarks/counter/xyz/increment + [HttpPost("counter/{key}/increment")] + public async Task CounterIncrement([FromRoute] string key) + { + if (string.IsNullOrEmpty(key)) + { + return BadRequest(new { error = "The 'key' route parameter must not be empty." }); + } + + EntityInstanceId entityId = new(nameof(Counter), key); + + logger.LogInformation("Sending signal 'Increment' to {entityId}.", entityId); + + Stopwatch sw = Stopwatch.StartNew(); + + await durableTaskClient.Entities.SignalEntityAsync(entityId, nameof(Counter.Increment)); + + sw.Stop(); + + logger.LogInformation( + "Sent signal 'Increment' to {entityId} in {time}ms!", + entityId, + sw.Elapsed.TotalMilliseconds); + + return Ok(new + { + message = $"Sent signal 'Increment' to {entityId} in {sw.Elapsed.TotalMilliseconds:F3}ms." + }); + } + + // GET http://localhost:5008/benchmarks/counter/xyz + [HttpGet("counter/{key}")] + public async Task CounterGet([FromRoute] string key) + { + if (string.IsNullOrEmpty(key)) + { + return BadRequest(new { error = "The 'key' route parameter must not be empty." }); + } + + EntityInstanceId entityId = new(nameof(Counter), key); + + logger.LogInformation("Reading state of {entityId}.", entityId); + + EntityMetadata? entityMetadata = + await durableTaskClient.Entities.GetEntityAsync(entityId); + + if (entityMetadata == null) + { + return NotFound(new + { + message = $"Entity {entityId} does not exist." + }); + } + else + { + return Ok(new + { + message = $"Entity {entityId} has state {entityMetadata.State}." + }); + } + } + + // DELETE http://localhost:5008/benchmarks/counter/xyz + [HttpDelete("counter/{key}")] + public async Task CounterDelete([FromRoute] string key) + { + if (string.IsNullOrEmpty(key)) + { + return BadRequest(new { error = "The 'key' route parameter must not be empty." }); + } + + EntityInstanceId entityId = new(nameof(Counter), key); + + logger.LogInformation("Deleting state of {entityId}.", entityId); + + await durableTaskClient.Entities.SignalEntityAsync(entityId, "delete"); + + return Ok(new + { + message = $"Sent deletion signal to {entityId}." + }); + } +} \ No newline at end of file diff --git a/samples/DtsPortableSdkEntityTests/controllers/TestsController.cs b/samples/DtsPortableSdkEntityTests/controllers/TestsController.cs new file mode 100644 index 000000000..880e913bf --- /dev/null +++ b/samples/DtsPortableSdkEntityTests/controllers/TestsController.cs @@ -0,0 +1,45 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using System.Runtime.Serialization; +using System.Text; +using System.Text.RegularExpressions; +using Azure.Core; +using Microsoft.AspNetCore.Mvc; +using Microsoft.DurableTask; +using Microsoft.DurableTask.Client; +using Microsoft.DurableTask.Entities; +using Microsoft.Extensions.Logging; + +namespace DtsPortableSdkEntityTests; + + +[Route("tests")] +[ApiController] +public class TestsController( + DurableTaskClient durableTaskClient, + ILogger logger) : ControllerBase +{ + readonly DurableTaskClient durableTaskClient = durableTaskClient; + readonly ILogger logger = logger; + + // HTTPie command: + // http POST http://localhost:5008/tests?prefix=xyz + [HttpPost()] + public async Task RunTests([FromQuery] string? prefix) + { + var context = new TestContext(this.durableTaskClient, this.logger, CancellationToken.None); + string result = await TestRunner.RunAsync(context, prefix); + return this.Ok(result); + } + + // HTTPie command: + // http GET http://localhost:5008/tests?prefix=xyz + [HttpGet()] + public async Task ListTests([FromQuery] string? prefix) + { + var context = new TestContext(this.durableTaskClient, this.logger, CancellationToken.None); + string result = await TestRunner.RunAsync(context, prefix, listOnly: true); + return this.Ok(result); + } +} diff --git a/samples/DtsPortableSdkEntityTests/demo.http b/samples/DtsPortableSdkEntityTests/demo.http new file mode 100644 index 000000000..7f65b932e --- /dev/null +++ b/samples/DtsPortableSdkEntityTests/demo.http @@ -0,0 +1,16 @@ +# For more info on HTTP files go to https://aka.ms/vs/httpfile + +### List all tests +GET http://localhost:5203/tests + + +### List all tests with prefix 'EntityQueries' +GET http://localhost:5203/tests?prefix=EntityQueries + + +### Run all tests +POST http://localhost:5203/tests + + +### Run all tests with prefix 'EntityQueries' +POST http://localhost:5203/tests?prefix=EntityQueries diff --git a/samples/DtsPortableSdkEntityTests/entities/BatchEntity.cs b/samples/DtsPortableSdkEntityTests/entities/BatchEntity.cs new file mode 100644 index 000000000..694176f09 --- /dev/null +++ b/samples/DtsPortableSdkEntityTests/entities/BatchEntity.cs @@ -0,0 +1,51 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Threading.Tasks; +using System.Threading.Tasks.Dataflow; +using Microsoft.DurableTask; +using Microsoft.DurableTask.Entities; + +namespace DtsPortableSdkEntityTests; + +/// +/// An entity that records all batch positions and batch sizes +/// +class BatchEntity : ITaskEntity +{ + int operationCounter; + + public ValueTask RunAsync(TaskEntityOperation operation) + { + List? state = (List?) operation.State.GetState(typeof(List)); + int batchNo; + if (state == null) + { + batchNo = 0; + state = new List(); + } + else if (operationCounter == 0) + { + batchNo = state.Last().batch + 1; + } + else + { + batchNo = state.Last().batch; + } + + state.Add(new Entry(batchNo, operationCounter++)); + operation.State.SetState(state); + return default; + } + + public record struct Entry(int batch, int operation); + + public static void Register(DurableTaskRegistry r) + { + r.AddEntity(nameof(BatchEntity), _ => new BatchEntity()); + } +} diff --git a/samples/DtsPortableSdkEntityTests/entities/Counter.cs b/samples/DtsPortableSdkEntityTests/entities/Counter.cs new file mode 100644 index 000000000..ede343ccb --- /dev/null +++ b/samples/DtsPortableSdkEntityTests/entities/Counter.cs @@ -0,0 +1,40 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Threading.Tasks; +using Microsoft.DurableTask; +using Microsoft.DurableTask.Entities; + +namespace DtsPortableSdkEntityTests; + +class Counter : TaskEntity +{ + public void Increment() + { + this.State++; + } + + public void Add(int amount) + { + this.State += amount; + } + + public int Get() + { + return this.State; + } + + public void Set(int value) + { + this.State = value; + } + + public static void Register(DurableTaskRegistry r) + { + r.AddEntity(nameof(Counter), _ => new Counter()); + } +} diff --git a/samples/DtsPortableSdkEntityTests/entities/FaultyEntity.cs b/samples/DtsPortableSdkEntityTests/entities/FaultyEntity.cs new file mode 100644 index 000000000..4e3b99916 --- /dev/null +++ b/samples/DtsPortableSdkEntityTests/entities/FaultyEntity.cs @@ -0,0 +1,199 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Text.Json; +using System.Text.Json.Serialization; +using System.Threading.Tasks; +using DurableTask.Core.Entities; +using Microsoft.DurableTask; +using Microsoft.DurableTask.Entities; +using Microsoft.Extensions.Logging; + +namespace DtsPortableSdkEntityTests; + +// we use a low-level ITaskEntity so we can intercept some of the operations without going through +// the default sequence of serialization and deserialization of state. This is needed to construct +// this type of test, it does not reflect typical useage. +public class FaultyEntity : ITaskEntity +{ + class State + { + [JsonInclude] + public int Value { get; set; } + + [JsonInclude] + [JsonConverter(typeof(CustomSerialization.ProblematicObjectJsonConverter))] + public CustomSerialization.ProblematicObject? ProblematicObject { get; set; } + + [JsonInclude] + public int NumberIncrementsSent { get; set; } + + public Task Send(EntityInstanceId target, TaskEntityContext context) + { + var desc = $"{++this.NumberIncrementsSent}:{this.Value}"; + context.SignalEntity(target, desc); + return Task.CompletedTask; + } + } + + public static void Register(DurableTaskRegistry r) + { + r.AddEntity(nameof(FaultyEntity), _ => new FaultyEntity()); + } + + public static void ThrowTestException() + { + throw new TestException("KABOOM"); + } + + [Serializable] + public class TestException : Exception + { + public TestException() : base() { } + public TestException(string message) : base(message) { } + public TestException(string message, Exception inner) : base(message, inner) { } + } + + public async ValueTask RunAsync(TaskEntityOperation operation) + { + State? Get() + { + return (State?)operation.State.GetState(typeof(State)); + } + State GetOrCreate() + { + State? s = Get(); + if (s is null) + { + operation.State.SetState(s = new State()); + } + return s; + } + + switch (operation.Name) + { + case "Exists": + { + try + { + return Get() != null; + } + catch (Exception) // the entity has state, even if that state is corrupted + { + return true; + } + } + case "Delay": + { + double delayInSeconds = (double)operation.GetInput(typeof(double))!; + await Task.Delay(TimeSpan.FromSeconds(delayInSeconds)); + return default; + } + case "Delete": + { + operation.State.SetState(null); + return default; + } + case "DeleteWithoutReading": + { + // do not read the state first otherwise the deserialization may throw before we can delete it + operation.State.SetState(null); + return default; + } + case "DeleteThenThrow": + { + operation.State.SetState(null); + ThrowTestException(); + return default; + } + case "Throw": + { + ThrowTestException(); + return default; + } + case "ThrowNested": + { + try + { + ThrowTestException(); + } + catch (Exception e) + { + throw new Exception("KABOOOOOM", e); + } + return default; + } + case "Get": + { + return GetOrCreate().Value; + } + case "GetNumberIncrementsSent": + { + return GetOrCreate().NumberIncrementsSent; + } + case "Set": + { + State state = GetOrCreate(); + state.Value = (int)operation.GetInput(typeof(int))!; + operation.State.SetState(state); + return default; + } + case "SetToUnserializable": + { + State state = GetOrCreate(); + state.ProblematicObject = CustomSerialization.CreateUnserializableObject(); + operation.State.SetState(state); + return default; + } + case "SetToUndeserializable": + { + State state = GetOrCreate(); + state.ProblematicObject = CustomSerialization.CreateUndeserializableObject(); + operation.State.SetState(state); + return default; + } + case "SetThenThrow": + { + State state = GetOrCreate(); + state.Value = (int)operation.GetInput(typeof(int))!; + operation.State.SetState(state); + ThrowTestException(); + return default; + } + case "Send": + { + State state = GetOrCreate(); + EntityInstanceId entityId = (EntityInstanceId)operation.GetInput(typeof(EntityId))!; + await state.Send(entityId, operation.Context); + operation.State.SetState(state); + return default; + } + case "SendThenThrow": + { + State state = GetOrCreate(); + EntityInstanceId entityId = (EntityInstanceId)operation.GetInput(typeof(EntityId))!; + await state.Send(entityId, operation.Context); + operation.State.SetState(state); + ThrowTestException(); + return default; + } + case "SendThenMakeUnserializable": + { + State state = GetOrCreate(); + EntityInstanceId entityId = (EntityInstanceId)operation.GetInput(typeof(EntityId))!; + await state.Send(entityId, operation.Context); + state.ProblematicObject = CustomSerialization.CreateUnserializableObject(); + operation.State.SetState(state); + return default; + } + default: + { + throw new InvalidOperationException($"undefined entity operation: {operation.Name}"); + } + } + } +} diff --git a/samples/DtsPortableSdkEntityTests/entities/Launcher.cs b/samples/DtsPortableSdkEntityTests/entities/Launcher.cs new file mode 100644 index 000000000..e4741ff0b --- /dev/null +++ b/samples/DtsPortableSdkEntityTests/entities/Launcher.cs @@ -0,0 +1,65 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Threading.Tasks; +using Castle.Core.Logging; +using Microsoft.DurableTask; +using Microsoft.DurableTask.Entities; + +namespace DtsPortableSdkEntityTests; + +class Launcher +{ + public string? OrchestrationInstanceId { get; set; } + + public DateTime? ScheduledTime { get; set; } + + public bool IsDone { get; set; } + + public string? ErrorMessage { get; set; } + + public void Launch(TaskEntityContext context, DateTime? scheduledTime = null) + { + this.OrchestrationInstanceId = context.ScheduleNewOrchestration( + nameof(FireAndForget.SignallingOrchestration), + context.Id, + new StartOrchestrationOptions(StartAt: scheduledTime)); + } + + public string? Get() + { + if (this.ErrorMessage != null) + { + throw new Exception(this.ErrorMessage); + } + return this.IsDone ? this.OrchestrationInstanceId : null; + } + + public void Done() + { + this.IsDone = true; + + if (this.ScheduledTime != null) + { + DateTime now = DateTime.UtcNow; + if (now < this.ScheduledTime) + { + this.ErrorMessage = $"delay was too short, expected >= {this.ScheduledTime}, actual = {now}"; + } + } + } + + public static void Register(DurableTaskRegistry r) + { + r.AddEntity(nameof(Launcher), _ => new Wrapper()); + } + + class Wrapper : TaskEntity + { + protected override bool AllowStateDispatch => true; + } +} diff --git a/samples/DtsPortableSdkEntityTests/entities/Relay.cs b/samples/DtsPortableSdkEntityTests/entities/Relay.cs new file mode 100644 index 000000000..2b1eead25 --- /dev/null +++ b/samples/DtsPortableSdkEntityTests/entities/Relay.cs @@ -0,0 +1,41 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Threading.Tasks; +using Microsoft.DurableTask; +using Microsoft.DurableTask.Entities; +using Microsoft.Extensions.Logging; + +namespace DtsPortableSdkEntityTests; + +/// +/// A stateless entity that forwards signals +/// +class Relay : ITaskEntity +{ + public record Input(EntityInstanceId entityInstanceId, string operationName, object? input, DateTimeOffset? scheduledTime); + + public ValueTask RunAsync(TaskEntityOperation operation) + { + T GetInput() => (T)operation.GetInput(typeof(T))!; + + Input input = GetInput(); + + operation.Context.SignalEntity( + input.entityInstanceId, + input.operationName, + input.input, + new SignalEntityOptions() { SignalTime = input.scheduledTime }); + + return default; + } + + public static void Register(DurableTaskRegistry r) + { + r.AddEntity(nameof(Relay), _ => new Relay()); + } +} diff --git a/samples/DtsPortableSdkEntityTests/entities/SchedulerEntity.cs b/samples/DtsPortableSdkEntityTests/entities/SchedulerEntity.cs new file mode 100644 index 000000000..b7331e34f --- /dev/null +++ b/samples/DtsPortableSdkEntityTests/entities/SchedulerEntity.cs @@ -0,0 +1,49 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Threading.Tasks; +using Microsoft.DurableTask; +using Microsoft.DurableTask.Entities; +using Microsoft.Extensions.Logging; + +namespace DtsPortableSdkEntityTests; + +class SchedulerEntity : ITaskEntity +{ + private readonly ILogger logger; + + public SchedulerEntity(ILogger logger) + { + this.logger = logger; + } + + public ValueTask RunAsync(TaskEntityOperation operation) + { + this.logger.LogInformation("{entityId} received {operationName} signal", operation.Context.Id, operation.Name); + + List state = (List?)operation.State.GetState(typeof(List)) ?? new List(); + + if (state.Contains(operation.Name)) + { + this.logger.LogError($"duplicate: {operation.Name}"); + } + else + { + state.Add(operation.Name); + } + + return default; + } + + public static void Register(DurableTaskRegistry r) + { + r.AddEntity( + nameof(SchedulerEntity), + (IServiceProvider serviceProvider) => + (ITaskEntity)ActivatorUtilities.GetServiceOrCreateInstance(serviceProvider)!); + } +} diff --git a/samples/DtsPortableSdkEntityTests/entities/SelfSchedulingEntity.cs b/samples/DtsPortableSdkEntityTests/entities/SelfSchedulingEntity.cs new file mode 100644 index 000000000..fb3fcd637 --- /dev/null +++ b/samples/DtsPortableSdkEntityTests/entities/SelfSchedulingEntity.cs @@ -0,0 +1,64 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Threading.Tasks; +using Microsoft.DurableTask; +using Microsoft.DurableTask.Entities; + +namespace DtsPortableSdkEntityTests; + +public class SelfSchedulingEntity +{ + public string Value { get; set; } = ""; + + public void Start(TaskEntityContext context) + { + var now = DateTime.UtcNow; + + var timeA = DateTimeOffset.UtcNow + TimeSpan.FromSeconds(1); + var timeB = DateTimeOffset.UtcNow + TimeSpan.FromSeconds(2); + var timeC = DateTimeOffset.UtcNow + TimeSpan.FromSeconds(3); + var timeD = DateTimeOffset.UtcNow + TimeSpan.FromSeconds(4); + + context.SignalEntity(context.Id, nameof(D), options: timeD); + context.SignalEntity(context.Id, nameof(C), options: timeC); + context.SignalEntity(context.Id, nameof(B), options: timeB); + context.SignalEntity(context.Id, nameof(A), options: timeA); + } + + public void A() + { + this.Value += "A"; + } + + public Task B() + { + this.Value += "B"; + return Task.Delay(100); + } + + public void C() + { + this.Value += "C"; + } + + public Task D() + { + this.Value += "D"; + return Task.FromResult(111); + } + + public static void Register(DurableTaskRegistry r) + { + r.AddEntity(nameof(SelfSchedulingEntity), _ => new Wrapper()); + } + + class Wrapper : TaskEntity + { + protected override bool AllowStateDispatch => true; + } +} diff --git a/samples/DtsPortableSdkEntityTests/entities/StringStore.cs b/samples/DtsPortableSdkEntityTests/entities/StringStore.cs new file mode 100644 index 000000000..bdd3d7ef6 --- /dev/null +++ b/samples/DtsPortableSdkEntityTests/entities/StringStore.cs @@ -0,0 +1,118 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Text.Json.Serialization; +using System.Threading.Tasks; +using Microsoft.DurableTask; +using Microsoft.DurableTask.Entities; +using Microsoft.Extensions.Logging; + +namespace DtsPortableSdkEntityTests; + +// three variations of the same simple entity: an entity that stores a string +// supporting get, set, and delete operations. There are slight semantic differences. + +//-------------- a class-based implementation ----------------- + +public class StringStore +{ + [JsonInclude] + public string Value { get; set; } + + public StringStore() + { + this.Value = string.Empty; + } + + public string Get() + { + return this.Value; + } + + public void Set(string value) + { + this.Value = value; + } + + // Delete is implicitly defined + + public static void Register(DurableTaskRegistry r) + { + r.AddEntity(nameof(StringStore), _ => new Wrapper()); + } + + class Wrapper : TaskEntity + { + protected override bool AllowStateDispatch => true; + } +} + +//-------------- a TaskEntity-based implementation ----------------- + +public class StringStore2 : TaskEntity +{ + public string Get() + { + return this.State; + } + + public void Set(string value) + { + this.State = value; + } + + protected override string InitializeState(TaskEntityOperation operation) + { + return string.Empty; + } + + // Delete is implicitly defined + + public static void Register(DurableTaskRegistry r) + { + r.AddEntity(nameof(StringStore2), _ => new StringStore2()); + } +} + +//-------------- a direct ITaskEntity-based implementation ----------------- + +class StringStore3 : ITaskEntity +{ + public ValueTask RunAsync(TaskEntityOperation operation) + { + switch (operation.Name) + { + case "set": + operation.State.SetState((string?)operation.GetInput(typeof(string))); + return default; + + case "get": + // note: this does not assign a state to the entity if it does not already exist + return new ValueTask((string?)operation.State.GetState(typeof(string))); + + case "delete": + if (operation.State.GetState(typeof(string)) == null) + { + return new ValueTask(false); + } + else + { + operation.State.SetState(null); + return new ValueTask(true); + } + + default: + throw new NotImplementedException("no such operation"); + } + } + + public static void Register(DurableTaskRegistry r) + { + r.AddEntity(nameof(StringStore3), _ => new StringStore3()); + } +} + diff --git a/samples/DtsPortableSdkEntityTests/tests/All.cs b/samples/DtsPortableSdkEntityTests/tests/All.cs new file mode 100644 index 000000000..27b2b6bf9 --- /dev/null +++ b/samples/DtsPortableSdkEntityTests/tests/All.cs @@ -0,0 +1,66 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using System.Diagnostics; +using System.Runtime.Serialization; +using System.Text; +using System.Text.RegularExpressions; +using Azure.Core; +using Microsoft.DurableTask; +using Microsoft.DurableTask.Client; +using Microsoft.DurableTask.Entities; +using Microsoft.Extensions.Logging; +using Xunit; + +namespace DtsPortableSdkEntityTests; + +/// +/// A collection containing all the unit tests. +/// +static class All +{ + public static IEnumerable GetAllTests() + { + yield return new SetAndGet(); + yield return new CallCounter(); + yield return new BatchedEntitySignals(100); + yield return new SignalAndCall(typeof(StringStore)); + yield return new SignalAndCall(typeof(StringStore2)); + yield return new SignalAndCall(typeof(StringStore3)); + yield return new CallAndDelete(typeof(StringStore)); + yield return new CallAndDelete(typeof(StringStore2)); + yield return new CallAndDelete(typeof(StringStore3)); + yield return new DeleteAfterLock(); + yield return new SignalThenPoll(direct: true, delayed: false); + yield return new SignalThenPoll(direct: true, delayed: true); + yield return new SignalThenPoll(direct: false, delayed: false); + yield return new SignalThenPoll(direct: false, delayed: true); + yield return new SelfScheduling(); + yield return new FireAndForget(null); + yield return new FireAndForget(0); + yield return new FireAndForget(5); + yield return new SingleLockedTransfer(); + yield return new MultipleLockedTransfers(2); + yield return new MultipleLockedTransfers(5); + yield return new MultipleLockedTransfers(100); + yield return new TwoCriticalSections(sameEntity: true); + yield return new TwoCriticalSections(sameEntity: false); + yield return new FaultyCriticalSection(); + yield return new LargeEntity(); + yield return new CallSingleFaultyEntity(); + yield return new CallMultipleFaultyEntities(); + yield return new CallFaultyActivity(nested: false); + yield return new CallFaultyActivity(nested: true); + yield return new CallFaultySuborchestration(nested: false); + yield return new CallFaultySuborchestration(nested: true); + yield return new InvalidEntityId(InvalidEntityId.Location.ClientGet); + yield return new InvalidEntityId(InvalidEntityId.Location.ClientSignal); + yield return new InvalidEntityId(InvalidEntityId.Location.OrchestrationCall); + yield return new InvalidEntityId(InvalidEntityId.Location.OrchestrationSignal); + yield return new EntityQueries1(); + yield return new EntityQueries2(); + yield return new NoOrphanedLockAfterTermination(); + yield return new NoOrphanedLockAfterNondeterminism(); + } + +} diff --git a/samples/DtsPortableSdkEntityTests/tests/BatchedEntitySignals.cs b/samples/DtsPortableSdkEntityTests/tests/BatchedEntitySignals.cs new file mode 100644 index 000000000..4f61aa04c --- /dev/null +++ b/samples/DtsPortableSdkEntityTests/tests/BatchedEntitySignals.cs @@ -0,0 +1,66 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Net.Mime; +using System.Text; +using System.Threading.Tasks; +using Microsoft.DurableTask; +using Microsoft.DurableTask.Client; +using Microsoft.DurableTask.Entities; +using Microsoft.Extensions.Logging; +using Xunit; + +namespace DtsPortableSdkEntityTests; + +class BatchedEntitySignals : Test +{ + readonly int numIterations; + + public BatchedEntitySignals(int numIterations) + { + this.numIterations = numIterations; + } + + public override async Task RunAsync(TestContext context) + { + var entityId = new EntityInstanceId(nameof(BatchEntity), Guid.NewGuid().ToString().Substring(0,8)); + + // send a number of signals immediately after each other + List tasks = new List(); + for (int i = 0; i < numIterations; i++) + { + tasks.Add(context.Client.Entities.SignalEntityAsync(entityId, string.Empty, i)); + } + + await Task.WhenAll(tasks); + + var result = await context.WaitForEntityStateAsync>( + entityId, + timeout: default, + list => list.Count == this.numIterations ? null : $"waiting for {this.numIterations - list.Count} signals"); + + Assert.Equal(new BatchEntity.Entry(0, 0), result[0]); + Assert.Equal(this.numIterations, result.Count); + + for (int i = 0; i < numIterations - 1; i++) + { + if (result[i].batch == result[i + 1].batch) + { + Assert.Equal(result[i].operation + 1, result[i + 1].operation); + } + else + { + Assert.Equal(result[i].batch + 1, result[i + 1].batch); + Assert.Equal(0, result[i + 1].operation); + } + } + + // there should always be some batching going on + int numBatches = result.Last().batch + 1; + Assert.True(numBatches < numIterations); + context.Logger.LogInformation($"completed {numIterations} operations in {numBatches} batches"); + } +} diff --git a/samples/DtsPortableSdkEntityTests/tests/CallAndDelete.cs b/samples/DtsPortableSdkEntityTests/tests/CallAndDelete.cs new file mode 100644 index 000000000..46d9cf302 --- /dev/null +++ b/samples/DtsPortableSdkEntityTests/tests/CallAndDelete.cs @@ -0,0 +1,106 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Runtime.CompilerServices; +using System.Text; +using System.Threading.Tasks; +using DurableTask.Core.Entities; +using Microsoft.DurableTask; +using Microsoft.DurableTask.Client; +using Microsoft.DurableTask.Entities; +using Xunit; + +namespace DtsPortableSdkEntityTests; + +class CallAndDelete : Test +{ + private readonly Type stringStoreType; + + public CallAndDelete(Type stringStoreType) + { + this.stringStoreType = stringStoreType; + } + + public override string Name => $"{base.Name}.{this.stringStoreType.Name}"; + + public override async Task RunAsync(TestContext context) + { + var entityId = new EntityInstanceId(this.stringStoreType.Name, Guid.NewGuid().ToString()); + string instanceId = await context.Client.ScheduleNewOrchestrationInstanceAsync(nameof(CallAndDeleteOrchestration), entityId); + var metadata = await context.Client.WaitForInstanceCompletionAsync(instanceId, true); + + Assert.Equal(OrchestrationRuntimeStatus.Completed, metadata.RuntimeStatus); + Assert.Equal("ok", metadata.ReadOutputAs()); + + // check that entity was deleted + var entityMetadata = await context.Client.Entities.GetEntityAsync(entityId); + Assert.Null(entityMetadata); + } + + static bool GetOperationInitializesEntity(EntityInstanceId entityInstanceId) + => !string.Equals(entityInstanceId.Name, nameof(StringStore3).ToLowerInvariant(), StringComparison.InvariantCulture); + + static bool DeleteReturnsBoolean(EntityInstanceId entityInstanceId) + => string.Equals(entityInstanceId.Name, nameof(StringStore3).ToLowerInvariant(), StringComparison.InvariantCulture); + + public override void Register(DurableTaskRegistry registry, IServiceCollection services) + { + registry.AddOrchestrator(); + } + + public class CallAndDeleteOrchestration : TaskOrchestrator + { + public override async Task RunAsync(TaskOrchestrationContext context, EntityInstanceId entityId) + { + await context.Entities.CallEntityAsync(entityId, "set", "333"); + + string value = await context.Entities.CallEntityAsync(entityId, "get"); + Assert.Equal("333", value); + + if (DeleteReturnsBoolean(entityId)) + { + bool deleted = await context.Entities.CallEntityAsync(entityId, "delete"); + Assert.True(deleted); + + bool deletedAgain = await context.Entities.CallEntityAsync(entityId, "delete"); + Assert.False(deletedAgain); + } + else + { + await context.Entities.CallEntityAsync(entityId, "delete"); + } + + string getValue = await context.Entities.CallEntityAsync(entityId, "get"); + if (GetOperationInitializesEntity(entityId)) + { + Assert.Equal("", getValue); + } + else + { + Assert.Null(getValue); + } + + if (DeleteReturnsBoolean(entityId)) + { + bool deletedAgain = await context.Entities.CallEntityAsync(entityId, "delete"); + if (GetOperationInitializesEntity(entityId)) + { + Assert.True(deletedAgain); + } + else + { + Assert.False(deletedAgain); + } + } + else + { + await context.Entities.CallEntityAsync(entityId, "delete"); + } + + return "ok"; + } + } +} \ No newline at end of file diff --git a/samples/DtsPortableSdkEntityTests/tests/CallCounter.cs b/samples/DtsPortableSdkEntityTests/tests/CallCounter.cs new file mode 100644 index 000000000..0f7ac7c6c --- /dev/null +++ b/samples/DtsPortableSdkEntityTests/tests/CallCounter.cs @@ -0,0 +1,64 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Threading.Tasks; +using Microsoft.DurableTask; +using Microsoft.DurableTask.Client; +using Microsoft.DurableTask.Client.Entities; +using Microsoft.DurableTask.Entities; +using Xunit; + +namespace DtsPortableSdkEntityTests; + +class CallCounter : Test +{ + public override async Task RunAsync(TestContext context) + { + var entityId = new EntityInstanceId(nameof(Counter), Guid.NewGuid().ToString()); + string instanceId = await context.Client.ScheduleNewOrchestrationInstanceAsync(nameof(CallCounterOrchestration), entityId); + var metadata = await context.Client.WaitForInstanceCompletionAsync(instanceId, true); + + Assert.Equal(OrchestrationRuntimeStatus.Completed, metadata.RuntimeStatus); + Assert.Equal("OK", metadata.ReadOutputAs()); + + // entity ids cannot be used for orchestration instance queries + await Assert.ThrowsAsync(() => context.Client.GetInstanceAsync(entityId.ToString())); + + // and are not returned by them + List results = context.Client.GetAllInstancesAsync().ToBlockingEnumerable().ToList(); + Assert.DoesNotContain(results, metadata => metadata.InstanceId.StartsWith("@")); + + // check that entity state is correct + EntityMetadata? entityMetadata = await context.Client.Entities.GetEntityAsync(entityId, includeState:true); + Assert.NotNull(entityMetadata); + Assert.Equal(33, entityMetadata!.State); + } + + public override void Register(DurableTaskRegistry registry, IServiceCollection services) + { + registry.AddOrchestrator(); + } + + public class CallCounterOrchestration : TaskOrchestrator + { + public override async Task RunAsync(TaskOrchestrationContext context, EntityInstanceId entityId) + { + + await context.Entities.CallEntityAsync(entityId, "set", 33); + int result = await context.Entities.CallEntityAsync(entityId, "get"); + + if (result == 33) + { + return "OK"; + } + else + { + return $"wrong result: {result} instead of 33"; + } + } + } +} diff --git a/samples/DtsPortableSdkEntityTests/tests/CallFaultyActivity.cs b/samples/DtsPortableSdkEntityTests/tests/CallFaultyActivity.cs new file mode 100644 index 000000000..bdb977b56 --- /dev/null +++ b/samples/DtsPortableSdkEntityTests/tests/CallFaultyActivity.cs @@ -0,0 +1,118 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Runtime.CompilerServices; +using System.Text; +using System.Text.Json; +using System.Threading.Tasks; +using DurableTask.Core.Entities; +using Microsoft.DurableTask; +using Microsoft.DurableTask.Client; +using Microsoft.DurableTask.Entities; +using Microsoft.Extensions.Logging; +using Xunit; + +namespace DtsPortableSdkEntityTests; + +class CallFaultyActivity : Test +{ + // this is not an entity test... but it's a good place to put this test + + private readonly bool nested; + + public CallFaultyActivity(bool nested) + { + this.nested = nested; + } + public override string Name => $"{base.Name}.{(this.nested ? "Nested" : "NotNested")}"; + + public override async Task RunAsync(TestContext context) + { + string orchestrationName = nameof(CallFaultyActivityOrchestration); + string instanceId = await context.Client.ScheduleNewOrchestrationInstanceAsync(orchestrationName, this.nested); + var metadata = await context.Client.WaitForInstanceCompletionAsync(instanceId, true); + + Assert.Equal(OrchestrationRuntimeStatus.Completed, metadata.RuntimeStatus); + Assert.Equal("ok", metadata.ReadOutputAs()); + } + + public override void Register(DurableTaskRegistry registry, IServiceCollection services) + { + registry.AddActivity(); + registry.AddOrchestrator(); + } +} + +class FaultyActivity : TaskActivity +{ + public override Task RunAsync(TaskActivityContext context, bool nested) + { + if (!nested) + { + this.MethodThatThrowsException(); + } + else + { + this.MethodThatThrowsNestedException(); + } + + return Task.FromResult("unreachable"); + } + + public void MethodThatThrowsNestedException() + { + try + { + this.MethodThatThrowsException(); + } + catch (Exception e) + { + throw new Exception("KABOOOOOM", e); + } + } + + public void MethodThatThrowsException() + { + throw new Exception("KABOOM"); + } +} + +class CallFaultyActivityOrchestration : TaskOrchestrator +{ + public override async Task RunAsync(TaskOrchestrationContext context, bool nested) + { + try + { + await context.CallActivityAsync(nameof(FaultyActivity), nested); + throw new Exception("expected activity to throw exception, but none was thrown"); + } + catch (TaskFailedException taskFailedException) + { + Assert.NotNull(taskFailedException.FailureDetails); + + if (!nested) + { + Assert.Equal("KABOOM", taskFailedException.FailureDetails.ErrorMessage); + Assert.Contains(nameof(FaultyActivity.MethodThatThrowsException), taskFailedException.FailureDetails.StackTrace); + } + else + { + Assert.Equal("KABOOOOOM", taskFailedException.FailureDetails.ErrorMessage); + Assert.Contains(nameof(FaultyActivity.MethodThatThrowsNestedException), taskFailedException.FailureDetails.StackTrace); + + Assert.NotNull(taskFailedException.FailureDetails.InnerFailure); + Assert.Equal("KABOOM", taskFailedException.FailureDetails.InnerFailure!.ErrorMessage); + Assert.Contains(nameof(FaultyActivity.MethodThatThrowsException), taskFailedException.FailureDetails.InnerFailure.StackTrace); + } + } + catch (Exception e) + { + throw new Exception($"wrong exception thrown", e); + } + + return "ok"; + } +} diff --git a/samples/DtsPortableSdkEntityTests/tests/CallFaultySuborchestration.cs b/samples/DtsPortableSdkEntityTests/tests/CallFaultySuborchestration.cs new file mode 100644 index 000000000..cbde591d2 --- /dev/null +++ b/samples/DtsPortableSdkEntityTests/tests/CallFaultySuborchestration.cs @@ -0,0 +1,119 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Runtime.CompilerServices; +using System.Text; +using System.Text.Json; +using System.Threading.Tasks; +using DurableTask.Core.Entities; +using Microsoft.DurableTask; +using Microsoft.DurableTask.Client; +using Microsoft.DurableTask.Entities; +using Microsoft.Extensions.Logging; +using Xunit; + +namespace DtsPortableSdkEntityTests; + +class CallFaultySuborchestration : Test +{ + // this is not an entity test... but it's a good place to put this test + + private readonly bool nested; + + public CallFaultySuborchestration(bool nested) + { + this.nested = nested; + } + + public override string Name => $"{base.Name}.{(this.nested ? "Nested" : "NotNested")}"; + + public override async Task RunAsync(TestContext context) + { + string orchestrationName = nameof(CallFaultySuborchestrationOrchestration); + string instanceId = await context.Client.ScheduleNewOrchestrationInstanceAsync(orchestrationName, this.nested); + var metadata = await context.Client.WaitForInstanceCompletionAsync(instanceId, true); + + Assert.Equal(OrchestrationRuntimeStatus.Completed, metadata.RuntimeStatus); + Assert.Equal("ok", metadata.ReadOutputAs()); + } + + public override void Register(DurableTaskRegistry registry, IServiceCollection services) + { + registry.AddOrchestrator(); + registry.AddOrchestrator(); + } + + class FaultySuborchestration : TaskOrchestrator + { + public override Task RunAsync(TaskOrchestrationContext context, bool nested) + { + if (!nested) + { + this.MethodThatThrowsException(); + } + else + { + this.MethodThatThrowsNestedException(); + } + + return Task.FromResult("unreachable"); + } + + public void MethodThatThrowsNestedException() + { + try + { + this.MethodThatThrowsException(); + } + catch (Exception e) + { + throw new Exception("KABOOOOOM", e); + } + } + + public void MethodThatThrowsException() + { + throw new Exception("KABOOM"); + } + } + + class CallFaultySuborchestrationOrchestration : TaskOrchestrator + { + public override async Task RunAsync(TaskOrchestrationContext context, bool nested) + { + try + { + await context.CallSubOrchestratorAsync(nameof(FaultySuborchestration), nested); + throw new Exception("expected suborchestrator to throw exception, but none was thrown"); + } + catch (TaskFailedException taskFailedException) + { + Assert.NotNull(taskFailedException.FailureDetails); + + if (!nested) + { + Assert.Equal("KABOOM", taskFailedException.FailureDetails.ErrorMessage); + Assert.Contains(nameof(FaultySuborchestration.MethodThatThrowsException), taskFailedException.FailureDetails.StackTrace); + } + else + { + Assert.Equal("KABOOOOOM", taskFailedException.FailureDetails.ErrorMessage); + Assert.Contains(nameof(FaultySuborchestration.MethodThatThrowsNestedException), taskFailedException.FailureDetails.StackTrace); + + Assert.NotNull(taskFailedException.FailureDetails.InnerFailure); + Assert.Equal("KABOOM", taskFailedException.FailureDetails.InnerFailure!.ErrorMessage); + Assert.Contains(nameof(FaultySuborchestration.MethodThatThrowsException), taskFailedException.FailureDetails.InnerFailure.StackTrace); + } + } + catch (Exception e) + { + throw new Exception($"wrong exception thrown", e); + } + + return "ok"; + } + } +} \ No newline at end of file diff --git a/samples/DtsPortableSdkEntityTests/tests/CallMultipleFaultyEntities.cs b/samples/DtsPortableSdkEntityTests/tests/CallMultipleFaultyEntities.cs new file mode 100644 index 000000000..772866075 --- /dev/null +++ b/samples/DtsPortableSdkEntityTests/tests/CallMultipleFaultyEntities.cs @@ -0,0 +1,139 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Runtime.CompilerServices; +using System.Text; +using System.Text.Json; +using System.Threading.Tasks; +using DurableTask.Core.Entities; +using Microsoft.DurableTask; +using Microsoft.DurableTask.Client; +using Microsoft.DurableTask.Entities; +using Microsoft.Extensions.Logging; +using Xunit; + +namespace DtsPortableSdkEntityTests; + +class CallMultipleFaultyEntities : Test +{ + public override async Task RunAsync(TestContext context) + { + var entityId = new EntityInstanceId(nameof(FaultyEntity), Guid.NewGuid().ToString()); + string orchestrationName = nameof(CallFaultyEntityBatchesOrchestration); + string instanceId = await context.Client.ScheduleNewOrchestrationInstanceAsync(orchestrationName, entityId); + var metadata = await context.Client.WaitForInstanceCompletionAsync(instanceId, true); + + Assert.Equal(OrchestrationRuntimeStatus.Completed, metadata.RuntimeStatus); + Assert.Equal("ok", metadata.ReadOutputAs()); + } + + public override void Register(DurableTaskRegistry registry, IServiceCollection services) + { + registry.AddOrchestrator(); + } +} + +class CallFaultyEntityBatchesOrchestration : TaskOrchestrator +{ + public override async Task RunAsync(TaskOrchestrationContext context, EntityInstanceId entityId) + { + + // we use this utility function to try to enforce that a bunch of signals is delivered as a single batch. + // This is required for some of the tests here to work, since the batching affects the entity state management. + // The "enforcement" mechanism we use is not 100% failsafe (it still makes timing assumptions about the provider) + // but it should be more reliable than the original version of this test which failed quite frequently, as it was + // simply assuming that signals that are sent at the same time are always processed as a batch. + async Task ProcessSignalBatch(IEnumerable<(string, int?)> signals) + { + // first issue a signal that, when delivered, keeps the entity busy for a split second + await context.Entities.SignalEntityAsync(entityId, "Delay", 0.5); + + // we now need to yield briefly so that the delay signal is sent before the others + await context.CreateTimer(context.CurrentUtcDateTime + TimeSpan.FromMilliseconds(1), CancellationToken.None); + + // now send the signals one by one. These should all arrive and get queued (inside the storage provider) + // while the entity is executing the delay operation. Therefore, after the delay operation finishes, + // all of the signals are processed in a single batch. + foreach ((string operation, int? arg) in signals) + { + await context.Entities.SignalEntityAsync(entityId, operation, arg); + } + } + + try + { + await ProcessSignalBatch(new (string, int?)[] + { + new("Set", 42), // state that survives + new("SetThenThrow", 333), + new("DeleteThenThrow", null), + }); + + Assert.Equal(42, await context.Entities.CallEntityAsync(entityId, "Get")); + + await ProcessSignalBatch(new (string, int?)[] + { + new("Get", null), + new("Set", 42), + new("Delete", null), + new("Set", 43), // state that survives + new("DeleteThenThrow", null), + }); + + Assert.Equal(43, await context.Entities.CallEntityAsync(entityId, "Get")); + + await ProcessSignalBatch(new (string, int?)[] + { + new("Set", 55), // state that survives + new("SetToUnserializable", null), + }); + + + Assert.Equal(55, await context.Entities.CallEntityAsync(entityId, "Get")); + + await ProcessSignalBatch(new (string, int?)[] + { + new("Set", 1), + new("Delete", null), + new("Set", 2), + new("Delete", null), // state that survives + new("SetThenThrow", 333), + }); + + Assert.False(await context.Entities.CallEntityAsync(entityId, "Exists")); + + await ProcessSignalBatch(new (string, int?)[] + { + new("Set", 1), + new("Delete", null), + new("Set", 2), + new("Delete", null), // state that survives + new("SetThenThrow", 333), + }); + + // must have rolled back to non-existing state + Assert.False(await context.Entities.CallEntityAsync(entityId, "Exists")); + + await ProcessSignalBatch(new (string, int?)[] + { + new("Set", 1), + new("SetThenThrow", 333), + new("Set", 2), + new("DeleteThenThrow", null), + new("Delete", null), + new("Set", 3), // state that survives + }); + + Assert.Equal(3, await context.Entities.CallEntityAsync(entityId, "Get")); + + return "ok"; + } + catch (Exception e) + { + return e.ToString(); + } + } +} \ No newline at end of file diff --git a/samples/DtsPortableSdkEntityTests/tests/CallSingleFaultyEntity.cs b/samples/DtsPortableSdkEntityTests/tests/CallSingleFaultyEntity.cs new file mode 100644 index 000000000..fa15b2f41 --- /dev/null +++ b/samples/DtsPortableSdkEntityTests/tests/CallSingleFaultyEntity.cs @@ -0,0 +1,178 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Runtime.CompilerServices; +using System.Text; +using System.Text.Json; +using System.Threading.Tasks; +using DurableTask.Core.Entities; +using Microsoft.DurableTask; +using Microsoft.DurableTask.Client; +using Microsoft.DurableTask.Entities; +using Microsoft.Extensions.Logging; +using Xunit; + +namespace DtsPortableSdkEntityTests; + +class CallSingleFaultyEntity : Test +{ + public override async Task RunAsync(TestContext context) + { + var entityId = new EntityInstanceId(nameof(FaultyEntity), Guid.NewGuid().ToString()); + string orchestrationName = nameof(CallFaultyEntityOrchestration); + string instanceId = await context.Client.ScheduleNewOrchestrationInstanceAsync(orchestrationName, entityId); + var metadata = await context.Client.WaitForInstanceCompletionAsync(instanceId, true); + + Assert.Equal(OrchestrationRuntimeStatus.Completed, metadata.RuntimeStatus); + Assert.Equal("ok", metadata.ReadOutputAs()); + } + + public override void Register(DurableTaskRegistry registry, IServiceCollection services) + { + registry.AddOrchestrator(); + } +} + +class CallFaultyEntityOrchestration : TaskOrchestrator +{ + public override async Task RunAsync(TaskOrchestrationContext context, EntityInstanceId entityId) + { + async Task ExpectOperationExceptionAsync(Task t, EntityInstanceId entityId, string operationName, + string errorMessage, string? errorMethod = null, string? innerErrorMessage = null, string innerErrorMethod = "") + { + try + { + await t; + throw new Exception("expected operation exception, but none was thrown"); + } + catch (EntityOperationFailedException entityException) + { + Assert.Equal(operationName, entityException.OperationName); + Assert.Equal(entityId, entityException.EntityId); + Assert.Contains(errorMessage, entityException.Message); + + Assert.NotNull(entityException.FailureDetails); + Assert.Equal(errorMessage, entityException.FailureDetails.ErrorMessage); + + if (errorMethod != null) + { + Assert.Contains(errorMethod, entityException.FailureDetails.StackTrace); + } + + if (innerErrorMessage != null) + { + Assert.NotNull(entityException.FailureDetails.InnerFailure); + Assert.Equal(innerErrorMessage, entityException.FailureDetails.InnerFailure!.ErrorMessage); + + if (innerErrorMethod != null) + { + Assert.Contains(innerErrorMethod, entityException.FailureDetails.InnerFailure.StackTrace); + } + } + else + { + Assert.Null(entityException.FailureDetails.InnerFailure); + } + } + catch (Exception e) + { + throw new Exception($"wrong exception thrown", e); + } + } + + try + { + Assert.False(await context.Entities.CallEntityAsync(entityId, "Exists")); + + await ExpectOperationExceptionAsync( + context.Entities.CallEntityAsync(entityId, "Throw"), + entityId, + "Throw", + "KABOOM", + "ThrowTestException"); + + await ExpectOperationExceptionAsync( + context.Entities.CallEntityAsync(entityId, "ThrowNested"), + entityId, + "ThrowNested", + "KABOOOOOM", + "FaultyEntity.RunAsync", + "KABOOM", + "ThrowTestException"); + + await ExpectOperationExceptionAsync( + context.Entities.CallEntityAsync(entityId, "SetToUnserializable"), + entityId, + "SetToUnserializable", + "problematic object: is not serializable", + "ProblematicObjectJsonConverter.Write"); + + // since the operations failed, the entity state is unchanged, meaning the entity still does not exist + Assert.False(await context.Entities.CallEntityAsync(entityId, "Exists")); + + await context.Entities.CallEntityAsync(entityId, "SetToUndeserializable"); + + Assert.True(await context.Entities.CallEntityAsync(entityId, "Exists")); + + await ExpectOperationExceptionAsync( + context.Entities.CallEntityAsync(entityId, "Get"), + entityId, + "Get", + "problematic object: is not deserializable", + "ProblematicObjectJsonConverter.Read"); + + await context.Entities.CallEntityAsync(entityId, "DeleteWithoutReading"); + + Assert.False(await context.Entities.CallEntityAsync(entityId, "Exists")); + + await context.Entities.CallEntityAsync(entityId, "Set", 3); + + Assert.Equal(3, await context.Entities.CallEntityAsync(entityId, "Get")); + + await ExpectOperationExceptionAsync( + context.Entities.CallEntityAsync(entityId, "SetThenThrow", 333), + entityId, + "SetThenThrow", + "KABOOM", + "FaultyEntity.RunAsync"); + + + // value should be unchanged + Assert.Equal(3, await context.Entities.CallEntityAsync(entityId, "Get")); + + await ExpectOperationExceptionAsync( + context.Entities.CallEntityAsync(entityId, "DeleteThenThrow"), + entityId, + "DeleteThenThrow", + "KABOOM", + "FaultyEntity.RunAsync"); + + // value should be unchanged + Assert.Equal(3, await context.Entities.CallEntityAsync(entityId, "Get")); + + await context.Entities.CallEntityAsync(entityId, "Delete"); + + // entity was deleted + Assert.False(await context.Entities.CallEntityAsync(entityId, "Exists")); + + await ExpectOperationExceptionAsync( + context.Entities.CallEntityAsync(entityId, "SetThenThrow", 333), + entityId, + "SetThenThrow", + "KABOOM", + "FaultyEntity.RunAsync"); + + // must have rolled back to non-existing state + Assert.False(await context.Entities.CallEntityAsync(entityId, "Exists")); + + return "ok"; + } + catch (Exception e) + { + return e.ToString(); + } + } +} \ No newline at end of file diff --git a/samples/DtsPortableSdkEntityTests/tests/DeleteAfterLock.cs b/samples/DtsPortableSdkEntityTests/tests/DeleteAfterLock.cs new file mode 100644 index 000000000..3ab00a071 --- /dev/null +++ b/samples/DtsPortableSdkEntityTests/tests/DeleteAfterLock.cs @@ -0,0 +1,87 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Threading.Tasks; +using Microsoft.DurableTask; +using Microsoft.DurableTask.Client; +using Microsoft.DurableTask.Client.Entities; +using Microsoft.DurableTask.Entities; +using Xunit; + +namespace DtsPortableSdkEntityTests; + +class DeleteAfterLock : Test +{ + public override async Task RunAsync(TestContext context) + { + // ----- first, delete all already-existing entities in storage to ensure queries have predictable results + context.Logger.LogInformation("deleting existing entities"); + + // run a purge to force a flush, otherwise our query may miss some results + await context.Client.PurgeAllInstancesAsync(new PurgeInstancesFilter() { CreatedFrom = DateTime.MinValue }, context.CancellationToken); + + List tasks = []; + await foreach (var entity in context.Client.Entities.GetAllEntitiesAsync(new EntityQuery())) + { + tasks.Add(context.Client.PurgeInstanceAsync(entity.Id.ToString(), context.CancellationToken)); + } + await Task.WhenAll(tasks); + + // check that a blank entity query returns no elements now + var e = context.Client.Entities.GetAllEntitiesAsync(new EntityQuery()).GetAsyncEnumerator(); + Assert.False(await e.MoveNextAsync()); + + // -------------- then, lock an entity without ever creating state ... so it should disappear afterwards + + var entityId = new EntityInstanceId(nameof(Counter), $"delete-after-lock-{Guid.NewGuid()}"); + string instanceId = await context.Client.ScheduleNewOrchestrationInstanceAsync(nameof(LockEntityWithoutCallOrchestration), entityId); + var metadata = await context.Client.WaitForInstanceCompletionAsync(instanceId, true); + + Assert.Equal(OrchestrationRuntimeStatus.Completed, metadata.RuntimeStatus); + Assert.Equal("ok", metadata.ReadOutputAs()); + + // check that entity state is correctly reported as non-existing + EntityMetadata? entityMetadata = await context.Client.Entities.GetEntityAsync(entityId, includeState: true); + Assert.Null(entityMetadata); + + // check that the entity shows up as a transient entity if it has not been automatically deleted + var list = context.Client.Entities!.GetAllEntitiesAsync(new EntityQuery + { + InstanceIdStartsWith = entityId.ToString(), + IncludeTransient = true, + }).ToBlockingEnumerable().ToList(); + + if (!context.BackendSupportsImplicitEntityDeletion) + { + Assert.Single(list); + var cleaningResponse = await context.Client.Entities.CleanEntityStorageAsync(); + Assert.Equal(1, cleaningResponse.EmptyEntitiesRemoved); + } + else + { + Assert.Empty(list); + } + } + + public override void Register(DurableTaskRegistry registry, IServiceCollection services) + { + registry.AddOrchestrator(); + } + + public class LockEntityWithoutCallOrchestration : TaskOrchestrator + { + public override async Task RunAsync(TaskOrchestrationContext context, EntityInstanceId entityId) + { + await using (var lockContext = await context.Entities.LockEntitiesAsync(entityId)) + { + // don't do anything with the lock, we only lock the entity but don't create state + }; + + return "ok"; + } + } +} diff --git a/samples/DtsPortableSdkEntityTests/tests/EntityQueries1.cs b/samples/DtsPortableSdkEntityTests/tests/EntityQueries1.cs new file mode 100644 index 000000000..741a49dfc --- /dev/null +++ b/samples/DtsPortableSdkEntityTests/tests/EntityQueries1.cs @@ -0,0 +1,242 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Net.Mime; +using System.Text; +using System.Threading.Tasks; +using Microsoft.DurableTask; +using Microsoft.DurableTask.Client; +using Microsoft.DurableTask.Client.Entities; +using Microsoft.DurableTask.Entities; +using Microsoft.Extensions.Logging; +using Xunit; + +namespace DtsPortableSdkEntityTests; + +class EntityQueries1 : Test +{ + public override async Task RunAsync(TestContext context) + { + // ----- first, delete all already-existing entities in storage to ensure queries have predictable results + context.Logger.LogInformation("deleting existing entities"); + + // run a purge to force a flush, otherwise our query may miss some results + await context.Client.PurgeAllInstancesAsync(new PurgeInstancesFilter() { CreatedFrom = DateTime.MinValue }, context.CancellationToken); + + List tasks = []; + await foreach (var entity in context.Client.Entities.GetAllEntitiesAsync(new EntityQuery())) + { + tasks.Add(context.Client.PurgeInstanceAsync(entity.Id.ToString(), context.CancellationToken)); + } + await Task.WhenAll(tasks); + + // check that a blank entity query returns no elements now + var e = context.Client.Entities.GetAllEntitiesAsync(new EntityQuery()).GetAsyncEnumerator(); + Assert.False(await e.MoveNextAsync()); + + var yesterday = DateTime.UtcNow.Subtract(TimeSpan.FromDays(1)); + var tomorrow = DateTime.UtcNow.Add(TimeSpan.FromDays(1)); + + // ----- next, run a number of orchestrations in order to create specific instances + context.Logger.LogInformation("creating entities"); + + List entityIds = new List() + { + new EntityInstanceId("StringStore", "foo"), + new EntityInstanceId("StringStore", "bar"), + new EntityInstanceId("StringStore", "baz"), + new EntityInstanceId("StringStore2", "foo"), + }; + + await Parallel.ForEachAsync( + Enumerable.Range(0, entityIds.Count), + context.CancellationToken, + async (int i, CancellationToken cancellation) => + { + string instanceId = await context.Client.ScheduleNewOrchestrationInstanceAsync(nameof(SignalAndCall.SignalAndCallOrchestration), entityIds[i]); + await context.Client.WaitForInstanceCompletionAsync(instanceId, cancellation); + }); + + // ----- to more easily read this, we first create a collection of (query, validation function) pairs + context.Logger.LogInformation("starting query tests"); + + var tests = new (EntityQuery query, Action> test)[] + { + (new EntityQuery + { + InstanceIdStartsWith = "StringStore", + }, + result => + { + Assert.Equal(4, result.Count()); + }), + + (new EntityQuery + { + InstanceIdStartsWith = "@StringStore", + }, + result => + { + Assert.Equal(4, result.Count()); + }), + + (new EntityQuery + { + InstanceIdStartsWith = "@stringstore", + }, + result => + { + Assert.Equal(4, result.Count()); + }), + + (new EntityQuery + { + InstanceIdStartsWith = "@StringStore@", + }, + result => + { + Assert.Equal(3, result.Count()); + }), + + (new EntityQuery + { + InstanceIdStartsWith = "StringStore@", + }, + result => + { + Assert.Equal(3, result.Count()); + }), + + + (new EntityQuery + { + InstanceIdStartsWith = "@StringStore@foo", + }, + result => + { + Assert.Single(result); + Assert.True(result[0].IncludesState); + }), + + (new EntityQuery + { + InstanceIdStartsWith = "@StringStore@foo", + IncludeState = false, + }, + result => + { + Assert.Single(result); + Assert.False(result[0].IncludesState); + }), + + (new EntityQuery + { + InstanceIdStartsWith = "@StringStore@", + LastModifiedFrom = yesterday, + LastModifiedTo = tomorrow, + }, + result => + { + Assert.Equal(3, result.Count); + }), + + (new EntityQuery + { + InstanceIdStartsWith = "StringStore@ba", + LastModifiedFrom = yesterday, + LastModifiedTo = tomorrow, + }, + result => + { + Assert.Equal(2, result.Count); + }), + + (new EntityQuery + { + InstanceIdStartsWith = "stringstore@BA", + LastModifiedFrom = yesterday, + LastModifiedTo = tomorrow, + }, + result => + { + Assert.Empty(result); + }), + + (new EntityQuery + { + InstanceIdStartsWith = "@StringStore@ba", + LastModifiedFrom = yesterday, + LastModifiedTo = tomorrow, + }, + result => + { + Assert.Equal(2, result.Count); + }), + + (new EntityQuery + { + InstanceIdStartsWith = "@stringstore@BA", + LastModifiedFrom = yesterday, + LastModifiedTo = tomorrow, + }, + result => + { + Assert.Empty(result); + }), + + (new EntityQuery + { + InstanceIdStartsWith = "@StringStore@", + PageSize = 2, + }, + result => + { + Assert.Equal(3, result.Count()); + }), + + (new EntityQuery + { + InstanceIdStartsWith = "@noResult", + LastModifiedFrom = yesterday, + LastModifiedTo = tomorrow, + }, + result => + { + Assert.Empty(result); + }), + + (new EntityQuery + { + LastModifiedFrom = tomorrow, + }, + result => + { + Assert.Empty(result); + }), + + (new EntityQuery + { + LastModifiedTo = yesterday, + }, + result => + { + Assert.Empty(result); + }), + + }; + + foreach (var item in tests) + { + List results = new List(); + await foreach (var element in context.Client.Entities.GetAllEntitiesAsync(item.query)) + { + results.Add(element); + } + + item.test(results); + } + } +} \ No newline at end of file diff --git a/samples/DtsPortableSdkEntityTests/tests/EntityQueries2.cs b/samples/DtsPortableSdkEntityTests/tests/EntityQueries2.cs new file mode 100644 index 000000000..d59ceb79a --- /dev/null +++ b/samples/DtsPortableSdkEntityTests/tests/EntityQueries2.cs @@ -0,0 +1,140 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Net.Mime; +using System.Text; +using System.Threading.Tasks; +using Microsoft.DurableTask; +using Microsoft.DurableTask.Client; +using Microsoft.DurableTask.Client.Entities; +using Microsoft.DurableTask.Entities; +using Microsoft.Extensions.Logging; +using Xunit; + +namespace DtsPortableSdkEntityTests; + +class EntityQueries2 : Test +{ + public override async Task RunAsync(TestContext context) + { + // ----- first, delete all already-existing entities in storage to ensure queries have predictable results + context.Logger.LogInformation("deleting existing entities"); + + // run a purge to force a flush, otherwise our query may miss some results + await context.Client.PurgeAllInstancesAsync(new PurgeInstancesFilter() { CreatedFrom = DateTime.MinValue }, context.CancellationToken); + + List tasks = []; + await foreach (var entity in context.Client.Entities.GetAllEntitiesAsync(new EntityQuery())) + { + tasks.Add(context.Client.PurgeInstanceAsync(entity.Id.ToString(), context.CancellationToken)); + } + await Task.WhenAll(tasks); + + // check that a blank entity query returns no elements now + var e = context.Client.Entities.GetAllEntitiesAsync(new EntityQuery()).GetAsyncEnumerator(); + Assert.False(await e.MoveNextAsync()); + + // ----- next, run a number of orchestrations in order to create and/or delete specific instances + context.Logger.LogInformation("creating and deleting entities"); + + List orchestrations = new List() + { + nameof(SignalAndCall.SignalAndCallOrchestration), + nameof(CallAndDelete.CallAndDeleteOrchestration), + nameof(SignalAndCall.SignalAndCallOrchestration), + nameof(CallAndDelete.CallAndDeleteOrchestration), + nameof(SignalAndCall.SignalAndCallOrchestration), + nameof(CallAndDelete.CallAndDeleteOrchestration), + nameof(SignalAndCall.SignalAndCallOrchestration), + nameof(CallAndDelete.CallAndDeleteOrchestration), + }; + + List entityIds = new List() + { + new EntityInstanceId("StringStore", "foo"), + new EntityInstanceId("StringStore2", "bar"), + new EntityInstanceId("StringStore2", "baz"), + new EntityInstanceId("StringStore2", "foo"), + new EntityInstanceId("StringStore2", "ffo"), + new EntityInstanceId("StringStore2", "zzz"), + new EntityInstanceId("StringStore2", "aaa"), + new EntityInstanceId("StringStore2", "bbb"), + }; + + await Parallel.ForEachAsync( + Enumerable.Range(0, entityIds.Count), + context.CancellationToken, + async (int i, CancellationToken cancellation) => + { + string instanceId = await context.Client.ScheduleNewOrchestrationInstanceAsync(orchestrations[i], entityIds[i]); + await context.Client.WaitForInstanceCompletionAsync(instanceId, cancellation); + }); + + await Task.Delay(TimeSpan.FromSeconds(3)); // accounts for delay in updating instance tables + + // ----- use a collection of (query, validation function) pairs + context.Logger.LogInformation("starting query tests"); + + var tests = new (EntityQuery query, Action> test)[] + { + (new EntityQuery + { + }, + result => + { + Assert.Equal(4, result.Count()); + }), + + (new EntityQuery + { + IncludeTransient = true, + }, + result => + { + Assert.Equal(context.BackendSupportsImplicitEntityDeletion ? 4 : 8, result.Count()); + }), + + (new EntityQuery + { + PageSize = 3, + }, + result => + { + Assert.Equal(4, result.Count()); + }), + + (new EntityQuery + { + IncludeTransient = true, + PageSize = 3, + }, + result => + { + Assert.Equal(context.BackendSupportsImplicitEntityDeletion ? 4 : 8, result.Count()); // TODO this is provider-specific + }), + }; + + foreach (var item in tests) + { + List results = new List(); + await foreach (var element in context.Client.Entities.GetAllEntitiesAsync(item.query)) + { + results.Add(element); + } + + item.test(results); + } + + // ----- remove the 4 deleted entities whose metadata still lingers in Azure Storage provider + + context.Logger.LogInformation("starting storage cleaning"); + + var cleaningResponse = await context.Client.Entities.CleanEntityStorageAsync(); + + Assert.Equal(context.BackendSupportsImplicitEntityDeletion ? 0 : 4, cleaningResponse.EmptyEntitiesRemoved); + Assert.Equal(0, cleaningResponse.OrphanedLocksReleased); + } +} \ No newline at end of file diff --git a/samples/DtsPortableSdkEntityTests/tests/FaultyCriticalSection.cs b/samples/DtsPortableSdkEntityTests/tests/FaultyCriticalSection.cs new file mode 100644 index 000000000..ea858d59f --- /dev/null +++ b/samples/DtsPortableSdkEntityTests/tests/FaultyCriticalSection.cs @@ -0,0 +1,68 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Runtime.CompilerServices; +using System.Text; +using System.Text.Json; +using System.Threading.Tasks; +using DurableTask.Core.Entities; +using Microsoft.AspNetCore.Components.Forms; +using Microsoft.DurableTask; +using Microsoft.DurableTask.Client; +using Microsoft.DurableTask.Entities; +using Microsoft.Extensions.Logging; +using Xunit; + +namespace DtsPortableSdkEntityTests; + +class FaultyCriticalSection : Test +{ + public override async Task RunAsync(TestContext context) + { + var entityId = new EntityInstanceId(nameof(Counter), Guid.NewGuid().ToString()); + string orchestrationName = nameof(FaultyCriticalSectionOrchestration); + + // run the critical section but fail in the middle + { + string instanceId = await context.Client.ScheduleNewOrchestrationInstanceAsync(orchestrationName, new FaultyCriticalSectionOrchestration.Input(entityId, true)); + var metadata = await context.Client.WaitForInstanceCompletionAsync(instanceId, getInputsAndOutputs:true); + Assert.Equal(OrchestrationRuntimeStatus.Failed, metadata.RuntimeStatus); + Assert.NotNull(metadata.FailureDetails); + Assert.Equal("KABOOM", metadata.FailureDetails.ErrorMessage); + } + + // run the critical section again without failing this time - this will time out if lock was not released properly. + { + string instanceId = await context.Client.ScheduleNewOrchestrationInstanceAsync(orchestrationName, new FaultyCriticalSectionOrchestration.Input(entityId, false)); + var metadata = await context.Client.WaitForInstanceCompletionAsync(instanceId, true); + Assert.Equal(OrchestrationRuntimeStatus.Completed, metadata.RuntimeStatus); + Assert.Equal("ok", metadata.ReadOutputAs()); + } + } + + public override void Register(DurableTaskRegistry registry, IServiceCollection services) + { + registry.AddOrchestrator(); + } +} + +class FaultyCriticalSectionOrchestration : TaskOrchestrator +{ + public record Input(EntityInstanceId EntityInstanceId, bool Fail); + + public override async Task RunAsync(TaskOrchestrationContext context, Input input) + { + await using (await context.Entities.LockEntitiesAsync(input.EntityInstanceId)) + { + if (input.Fail) + { + throw new Exception("KABOOM"); + } + } + + return "ok"; + } +} \ No newline at end of file diff --git a/samples/DtsPortableSdkEntityTests/tests/FireAndForget.cs b/samples/DtsPortableSdkEntityTests/tests/FireAndForget.cs new file mode 100644 index 000000000..38b274118 --- /dev/null +++ b/samples/DtsPortableSdkEntityTests/tests/FireAndForget.cs @@ -0,0 +1,98 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Net.Mime; +using System.Runtime.CompilerServices; +using System.Text; +using System.Threading.Tasks; +using Microsoft.DurableTask; +using Microsoft.DurableTask.Client; +using Microsoft.DurableTask.Client.Entities; +using Microsoft.DurableTask.Entities; +using Xunit; + +namespace DtsPortableSdkEntityTests; + +/// +/// Scenario that starts a new orchestration from an entity. +/// +class FireAndForget : Test +{ + private readonly int? delay; + + public FireAndForget(int? delay) + { + this.delay = delay; + } + + public override string Name => $"{base.Name}.{(this.delay.HasValue ? "Delay" + this.delay.Value.ToString() : "NoDelay")}"; + + public override async Task RunAsync(TestContext context) + { + string instanceId = await context.Client.ScheduleNewOrchestrationInstanceAsync(nameof(LaunchOrchestrationFromEntity), this.delay, context.CancellationToken); + OrchestrationMetadata metadata = await context.Client.WaitForInstanceCompletionAsync(instanceId, getInputsAndOutputs: true, context.CancellationToken); + + Assert.Equal(OrchestrationRuntimeStatus.Completed, metadata.RuntimeStatus); + + string? signallingOrchestrationInstanceId = metadata.ReadOutputAs(); + Assert.NotNull(signallingOrchestrationInstanceId); + var launchedMetadata = await context.Client.GetInstanceAsync(signallingOrchestrationInstanceId!, getInputsAndOutputs: true, context.CancellationToken); + Assert.NotNull(launchedMetadata); + Assert.Equal(OrchestrationRuntimeStatus.Completed, launchedMetadata!.RuntimeStatus); + Assert.Equal("ok", launchedMetadata!.ReadOutputAs()); + } + + public override void Register(DurableTaskRegistry registry, IServiceCollection services) + { + registry.AddOrchestrator(); + registry.AddOrchestrator(); + + } + + public class LaunchOrchestrationFromEntity : TaskOrchestrator + { + public override async Task RunAsync(TaskOrchestrationContext context, int? delay) + { + var entityId = new EntityInstanceId("Launcher", context.NewGuid().ToString().Substring(0, 8)); + + if (delay.HasValue) + { + await context.Entities.CallEntityAsync(entityId, "launch", context.CurrentUtcDateTime + TimeSpan.FromSeconds(delay.Value)); + } + else + { + await context.Entities.CallEntityAsync(entityId, "launch"); + } + + while (true) + { + string? signallingOrchestrationId = await context.Entities.CallEntityAsync(entityId, "get"); + + if (signallingOrchestrationId != null) + { + return signallingOrchestrationId; + } + + await context.CreateTimer(DateTime.UtcNow + TimeSpan.FromSeconds(1), CancellationToken.None); + } + } + } + + public class SignallingOrchestration : TaskOrchestrator + { + public override async Task RunAsync(TaskOrchestrationContext context, EntityInstanceId entityId) + { + await context.CreateTimer(DateTime.UtcNow + TimeSpan.FromSeconds(.2), CancellationToken.None); + + await context.Entities.SignalEntityAsync(entityId, "done"); + + // to test replay, we add a little timer + await context.CreateTimer(DateTime.UtcNow + TimeSpan.FromMilliseconds(1), CancellationToken.None); + + return "ok"; + } + } +} \ No newline at end of file diff --git a/samples/DtsPortableSdkEntityTests/tests/InvalidEntityId.cs b/samples/DtsPortableSdkEntityTests/tests/InvalidEntityId.cs new file mode 100644 index 000000000..de5d2592c --- /dev/null +++ b/samples/DtsPortableSdkEntityTests/tests/InvalidEntityId.cs @@ -0,0 +1,77 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Threading.Tasks; +using Microsoft.DurableTask; +using Microsoft.DurableTask.Client; +using Microsoft.DurableTask.Entities; +using Xunit; + +namespace DtsPortableSdkEntityTests; + +/// +/// This test is not entity related, but discovered an issue with how failures in orchestrators are captured. +/// +class InvalidEntityId : Test +{ + public enum Location + { + ClientSignal, + ClientGet, + OrchestrationSignal, + OrchestrationCall, + } + + readonly Location location; + + public InvalidEntityId(Location location) + { + this.location = location; + } + + public override string Name => $"{base.Name}.{this.location}"; + + public override async Task RunAsync(TestContext context) + { + switch (this.location) + { + case Location.ClientSignal: + await Assert.ThrowsAsync( + async () => + { + await context.Client.Entities.SignalEntityAsync(default, "add", 1); + }); + return; + + case Location.ClientGet: + await Assert.ThrowsAsync( + async () => + { + await context.Client.Entities.GetEntityAsync(default); + }); + return; + + case Location.OrchestrationSignal: + { + string instanceId = await context.Client.ScheduleNewOrchestrationInstanceAsync(nameof(SignalAndCall.SignalAndCallOrchestration) /* missing input */); + var metadata = await context.Client.WaitForInstanceCompletionAsync(instanceId, true); + Assert.Equal(OrchestrationRuntimeStatus.Failed, metadata.RuntimeStatus); + //Assert.NotNull(metadata.FailureDetails); // TODO currently failing because FailureDetails are not propagated for some reason + } + break; + + case Location.OrchestrationCall: + { + string instanceId = await context.Client.ScheduleNewOrchestrationInstanceAsync(nameof(CallCounter.CallCounterOrchestration) /* missing input */); + var metadata = await context.Client.WaitForInstanceCompletionAsync(instanceId, true); + Assert.Equal(OrchestrationRuntimeStatus.Failed, metadata.RuntimeStatus); + //Assert.NotNull(metadata.FailureDetails); // TODO currently failing because FailureDetails are not propagated for some reason + } + break; + } + } +} diff --git a/samples/DtsPortableSdkEntityTests/tests/LargeEntity.cs b/samples/DtsPortableSdkEntityTests/tests/LargeEntity.cs new file mode 100644 index 000000000..44eb54d0e --- /dev/null +++ b/samples/DtsPortableSdkEntityTests/tests/LargeEntity.cs @@ -0,0 +1,89 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Threading.Tasks; +using Microsoft.DurableTask; +using Microsoft.DurableTask.Client; +using Microsoft.DurableTask.Client.Entities; +using Microsoft.DurableTask.Entities; +using Xunit; + +namespace DtsPortableSdkEntityTests; + +/// +/// validates a simple entity scenario where an entity's state is +/// larger than what fits into Azure table rows. +/// +internal class LargeEntity : Test +{ + public override async Task RunAsync(TestContext context) + { + var entityId = new EntityInstanceId(nameof(StringStore2), Guid.NewGuid().ToString().Substring(0, 8)); + string instanceId = await context.Client.ScheduleNewOrchestrationInstanceAsync(nameof(LargeEntityOrchestration), entityId); + + // wait for completion of the orchestration + { + var metadata = await context.Client.WaitForInstanceCompletionAsync(instanceId, true); + Assert.Equal(OrchestrationRuntimeStatus.Completed, metadata.RuntimeStatus); + Assert.Equal("ok", metadata.ReadOutputAs()); + } + + // read untyped without including state + { + EntityMetadata? metadata = await context.Client.Entities.GetEntityAsync(entityId, includeState: false, context.CancellationToken); + Assert.NotNull(metadata); + Assert.Throws(() => metadata!.State); + } + + // read untyped including state + { + EntityMetadata? metadata = await context.Client.Entities.GetEntityAsync(entityId, includeState: true, context.CancellationToken); + Assert.NotNull(metadata); + Assert.NotNull(metadata!.State); + Assert.Equal(100000, metadata!.State.ReadAs().Length); + } + + // read typed without including state + { + EntityMetadata? metadata = await context.Client.Entities.GetEntityAsync(entityId, includeState: false, context.CancellationToken); + Assert.NotNull(metadata); + Assert.Throws(() => metadata!.State); + } + + // read typed including state + { + EntityMetadata? metadata = await context.Client.Entities.GetEntityAsync(entityId, includeState: true, context.CancellationToken); + Assert.NotNull(metadata); + Assert.NotNull(metadata!.State); + Assert.Equal(100000, metadata!.State.Length); + } + } + + public override void Register(DurableTaskRegistry registry, IServiceCollection services) + { + registry.AddOrchestrator(); + } + + public class LargeEntityOrchestration : TaskOrchestrator + { + public override async Task RunAsync(TaskOrchestrationContext context, EntityInstanceId entityId) + { + string content = new string('.', 100000); + + await context.Entities.CallEntityAsync(entityId, "set", content); + + var result = await context.Entities.CallEntityAsync(entityId, "get"); + + if (result != content) + { + return $"fail: wrong entity state"; + } + + return "ok"; + } + } +} diff --git a/samples/DtsPortableSdkEntityTests/tests/MultipleLockedTransfers.cs b/samples/DtsPortableSdkEntityTests/tests/MultipleLockedTransfers.cs new file mode 100644 index 000000000..57bdd3531 --- /dev/null +++ b/samples/DtsPortableSdkEntityTests/tests/MultipleLockedTransfers.cs @@ -0,0 +1,83 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Net.Mime; +using System.Text; +using System.Threading.Tasks; +using Microsoft.DurableTask; +using Microsoft.DurableTask.Client; +using Microsoft.DurableTask.Client.Entities; +using Microsoft.DurableTask.Entities; +using Xunit; + +namespace DtsPortableSdkEntityTests; + +class MultipleLockedTransfers : Test +{ + readonly int numberEntities; + + public MultipleLockedTransfers(int numberEntities) + { + this.numberEntities = numberEntities; + } + + public override string Name => $"{base.Name}.{this.numberEntities}"; + + public override async Task RunAsync(TestContext context) + { + // create specified number of counters + var counters = new EntityInstanceId[this.numberEntities]; + for (int i = 0; i < this.numberEntities; i++) + { + counters[i] = new EntityInstanceId(nameof(Counter), Guid.NewGuid().ToString().Substring(0, 8)); + } + + // in parallel, start one transfer per counter, each decrementing a counter and incrementing + // its successor (where the last one wraps around to the first) + // This is a pattern that would deadlock if we didn't order the lock acquisition. + var instances = new Task[this.numberEntities]; + for (int i = 0; i < this.numberEntities; i++) + { + instances[i] = context.Client.ScheduleNewOrchestrationInstanceAsync( + nameof(SingleLockedTransfer.LockedTransferOrchestration), + new[] { counters[i], counters[(i + 1) % this.numberEntities] }, + context.CancellationToken); + } + await Task.WhenAll(instances); + + + // in parallel, wait for all transfers to complete + var metadata = new Task[this.numberEntities]; + for (int i = 0; i < this.numberEntities; i++) + { + metadata[i] = context.Client.WaitForInstanceCompletionAsync(instances[i].Result, getInputsAndOutputs: true, context.CancellationToken); + } + await Task.WhenAll(metadata); + + // check that they all completed + for (int i = 0; i < this.numberEntities; i++) + { + Assert.Equal(OrchestrationRuntimeStatus.Completed, metadata[i].Result.RuntimeStatus); + } + + // in parallel, read all the entity states + var entityMetadata = new Task?>[this.numberEntities]; + for (int i = 0; i < this.numberEntities; i++) + { + entityMetadata[i] = context.Client.Entities.GetEntityAsync(counters[i], includeState: true, context.CancellationToken); + } + await Task.WhenAll(entityMetadata); + + // check that the counter states are all back to 0 + // (since each participated in 2 transfers, one incrementing and one decrementing) + for (int i = 0; i < numberEntities; i++) + { + EntityMetadata? response = entityMetadata[i].Result; + Assert.NotNull(response); + Assert.Equal(0, response!.State); + } + } +} diff --git a/samples/DtsPortableSdkEntityTests/tests/NoOrphanedLockAfterNondeterminism.cs b/samples/DtsPortableSdkEntityTests/tests/NoOrphanedLockAfterNondeterminism.cs new file mode 100644 index 000000000..7b9933374 --- /dev/null +++ b/samples/DtsPortableSdkEntityTests/tests/NoOrphanedLockAfterNondeterminism.cs @@ -0,0 +1,117 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Net.Mime; +using System.Text; +using System.Threading.Tasks; +using Microsoft.DurableTask; +using Microsoft.DurableTask.Client; +using Microsoft.DurableTask.Client.Entities; +using Microsoft.DurableTask.Entities; +using Xunit; + + +namespace DtsPortableSdkEntityTests; + +class NoOrphanedLockAfterNondeterminism : Test +{ + + public override async Task RunAsync(TestContext context) + { + DateTime startTime = DateTime.UtcNow; + + // construct unique names for this test + string prefix = Guid.NewGuid().ToString("N").Substring(0, 6); + var orphanedEntityId = new EntityInstanceId(nameof(Counter), $"{prefix}-orphaned"); + var orchestrationA = $"{prefix}-A"; + var orchestrationB = $"{prefix}-B"; + + // start an orchestration A that acquires the lock and then throws a nondeterminism error + await context.Client.ScheduleNewOrchestrationInstanceAsync( + nameof(NondeterministicLocker), + orphanedEntityId, + new StartOrchestrationOptions() { InstanceId = orchestrationA }, + context.CancellationToken); + await context.Client.WaitForInstanceStartAsync(orchestrationA, context.CancellationToken); + + // start an orchestration B that queues behind A for the lock + await context.Client.ScheduleNewOrchestrationInstanceAsync( + nameof(LockingIncrementor2), + orphanedEntityId, + new StartOrchestrationOptions() { InstanceId = orchestrationB }, + context.CancellationToken); + await context.Client.WaitForInstanceStartAsync(orchestrationB, context.CancellationToken); + + // wait for orchestration B to finish + OrchestrationMetadata metadata = await context.Client.WaitForInstanceCompletionAsync(orchestrationB, getInputsAndOutputs: true, context.CancellationToken); + Assert.Equal(OrchestrationRuntimeStatus.Completed, metadata.RuntimeStatus); + Assert.Equal("ok", metadata.ReadOutputAs()); + + // check that orchestration A reported nondeterminism + metadata = await context.Client.WaitForInstanceCompletionAsync(orchestrationA, getInputsAndOutputs: true, context.CancellationToken); + Assert.Equal(OrchestrationRuntimeStatus.Failed, metadata.RuntimeStatus); + Assert.Contains("Non-Deterministic workflow detected", metadata.FailureDetails?.ErrorMessage); + + // check the status of the entity to confirm that the lock is no longer held + EntityMetadata? entityMetadata = await context.Client.Entities.GetEntityAsync(orphanedEntityId, context.CancellationToken); + Assert.NotNull(entityMetadata); + Assert.Equal(orphanedEntityId, entityMetadata.Id); + Assert.True(entityMetadata.IncludesState); + Assert.Equal(1, entityMetadata.State.ReadAs()); + Assert.True(entityMetadata.LastModifiedTime > startTime); + Assert.Null(entityMetadata.LockedBy); + Assert.Equal(0, entityMetadata.BacklogQueueSize); + + // purge instances from storage + PurgeResult purgeResult = await context.Client.PurgeInstanceAsync(orchestrationA); + Assert.Equal(1, purgeResult.PurgedInstanceCount); + purgeResult = await context.Client.PurgeInstanceAsync(orchestrationB); + Assert.Equal(1, purgeResult.PurgedInstanceCount); + purgeResult = await context.Client.PurgeInstanceAsync(orphanedEntityId.ToString()); + Assert.Equal(1, purgeResult.PurgedInstanceCount); + + // test that purge worked + purgeResult = await context.Client.PurgeInstanceAsync(orchestrationA); + Assert.Equal(0, purgeResult.PurgedInstanceCount); + purgeResult = await context.Client.PurgeInstanceAsync(orchestrationB); + Assert.Equal(0, purgeResult.PurgedInstanceCount); + purgeResult = await context.Client.PurgeInstanceAsync(orphanedEntityId.ToString()); + Assert.Equal(0, purgeResult.PurgedInstanceCount); + } + + public override void Register(DurableTaskRegistry registry, IServiceCollection services) + { + registry.AddOrchestrator(); + registry.AddOrchestrator(); + } + + class NondeterministicLocker : TaskOrchestrator + { + public override async Task RunAsync(TaskOrchestrationContext context, EntityInstanceId entityId) + { + if (!context.IsReplaying) // replay will encounter nondeterminism before replaying the lock + { + await context.Entities.LockEntitiesAsync(entityId); + } + + return "nondeterminstic"; + } + } + + class LockingIncrementor2 : TaskOrchestrator + { + public override async Task RunAsync(TaskOrchestrationContext context, EntityInstanceId entityId) + { + await using (await context.Entities.LockEntitiesAsync(entityId)) + { + await context.Entities.CallEntityAsync(entityId, "increment"); + + // we got the entity + return "ok"; + } + } + } +} \ No newline at end of file diff --git a/samples/DtsPortableSdkEntityTests/tests/NoOrphanedLockAfterTermination.cs b/samples/DtsPortableSdkEntityTests/tests/NoOrphanedLockAfterTermination.cs new file mode 100644 index 000000000..ae67b9156 --- /dev/null +++ b/samples/DtsPortableSdkEntityTests/tests/NoOrphanedLockAfterTermination.cs @@ -0,0 +1,172 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Net.Mime; +using System.Text; +using System.Threading.Tasks; +using Microsoft.DurableTask; +using Microsoft.DurableTask.Client; +using Microsoft.DurableTask.Client.Entities; +using Microsoft.DurableTask.Entities; +using Xunit; + + +namespace DtsPortableSdkEntityTests; + +class NoOrphanedLockAfterTermination : Test +{ + + public override async Task RunAsync(TestContext context) + { + DateTime startTime = DateTime.UtcNow; + + // construct unique names for this test + string prefix = Guid.NewGuid().ToString("N").Substring(0, 6); + var orphanedEntityId = new EntityInstanceId(nameof(Counter), $"{prefix}-orphaned"); + var orchestrationA = $"{prefix}-A"; + var orchestrationB = $"{prefix}-B"; + + // start an orchestration A that acquires the lock and then waits forever + await context.Client.ScheduleNewOrchestrationInstanceAsync( + nameof(InfiniteLocker), + orphanedEntityId, + new StartOrchestrationOptions() { InstanceId = orchestrationA }, + context.CancellationToken); + await context.Client.WaitForInstanceStartAsync(orchestrationA, context.CancellationToken); + + // start an orchestration B that queues behind A for the lock + await context.Client.ScheduleNewOrchestrationInstanceAsync( + nameof(LockingIncrementor), + orphanedEntityId, + new StartOrchestrationOptions() { InstanceId = orchestrationB }, + context.CancellationToken); + await context.Client.WaitForInstanceStartAsync(orchestrationB, context.CancellationToken); + + // try to get the entity using a point query. The result is null because the entitiy is transient. + EntityMetadata? entityMetadata = await context.Client.Entities.GetEntityAsync(orphanedEntityId, context.CancellationToken); + Assert.Null(entityMetadata); + + // try to get the entity state using a query that does not include transient states. SHould not return anything. + List results = context.Client.Entities.GetAllEntitiesAsync( + new Microsoft.DurableTask.Client.Entities.EntityQuery + { + InstanceIdStartsWith = orphanedEntityId.ToString(), + IncludeTransient = false, + IncludeState = true, + }).ToBlockingEnumerable().ToList(); + Assert.Empty(results); + + // try to get the entity state using a query that includes transient states. This should return the entity. + results = context.Client.Entities.GetAllEntitiesAsync( + new Microsoft.DurableTask.Client.Entities.EntityQuery { + InstanceIdStartsWith = orphanedEntityId.ToString(), + IncludeTransient = true, + IncludeState = true, + }).ToBlockingEnumerable().ToList(); + Assert.Single(results); + Assert.Equal(orphanedEntityId, results[0].Id); + Assert.False(results[0].IncludesState); + Assert.True(results[0].LastModifiedTime > startTime); + Assert.Equal(orchestrationA, results[0].LockedBy); + //Assert.Equal(1, results[0].BacklogQueueSize); //TODO implement this + + // check that purge on the entity is rejected (because the entity is locked) + PurgeResult purgeResult = await context.Client.PurgeInstanceAsync(orphanedEntityId.ToString(), context.CancellationToken); + Assert.Equal(0, purgeResult.PurgedInstanceCount); + + // check that purge on the orchestration is rejected (because it is not in a completed state) + purgeResult = await context.Client.PurgeInstanceAsync(orchestrationA.ToString(), context.CancellationToken); + Assert.Equal(0, purgeResult.PurgedInstanceCount); + + // NOW, we terminate orchestration A, which should implicitly release the lock + DateTime terminationTime = DateTime.UtcNow; + await context.Client.TerminateInstanceAsync(orchestrationA, context.CancellationToken); + + // wait for orchestration B to finish + OrchestrationMetadata metadata = await context.Client.WaitForInstanceCompletionAsync(orchestrationB, getInputsAndOutputs: true, context.CancellationToken); + Assert.Equal(OrchestrationRuntimeStatus.Completed, metadata.RuntimeStatus); + Assert.Equal("ok", metadata.ReadOutputAs()); + + // check that orchestration A is reported as terminated + metadata = await context.Client.WaitForInstanceCompletionAsync(orchestrationA, getInputsAndOutputs: true, context.CancellationToken); + Assert.Equal(OrchestrationRuntimeStatus.Terminated, metadata.RuntimeStatus); + + // check the status of the entity to confirm that the lock is no longer held + entityMetadata = await context.Client.Entities.GetEntityAsync(orphanedEntityId, context.CancellationToken); + Assert.NotNull(entityMetadata); + Assert.Equal(orphanedEntityId, entityMetadata.Id); + Assert.True(entityMetadata.IncludesState); + Assert.Equal(1, entityMetadata.State.ReadAs()); + Assert.True(entityMetadata.LastModifiedTime > terminationTime); + Assert.Null(entityMetadata.LockedBy); + Assert.Equal(0, entityMetadata.BacklogQueueSize); + + // same, but using a query + results = context.Client.Entities.GetAllEntitiesAsync( + new Microsoft.DurableTask.Client.Entities.EntityQuery() + { + InstanceIdStartsWith = orphanedEntityId.ToString(), + IncludeTransient = false, + IncludeState = false, + }).ToBlockingEnumerable().ToList(); + Assert.Single(results); + Assert.Equal(orphanedEntityId, results[0].Id); + Assert.False(results[0].IncludesState); + Assert.True(results[0].LastModifiedTime > terminationTime); + Assert.Null(results[0].LockedBy); + Assert.Equal(0, results[0].BacklogQueueSize); + + // purge instances from storage + purgeResult = await context.Client.PurgeInstanceAsync(orchestrationA); + Assert.Equal(1, purgeResult.PurgedInstanceCount); + purgeResult = await context.Client.PurgeInstanceAsync(orchestrationB); + Assert.Equal(1, purgeResult.PurgedInstanceCount); + purgeResult = await context.Client.PurgeInstanceAsync(orphanedEntityId.ToString()); + Assert.Equal(1, purgeResult.PurgedInstanceCount); + + // test that purge worked + purgeResult = await context.Client.PurgeInstanceAsync(orchestrationA); + Assert.Equal(0, purgeResult.PurgedInstanceCount); + purgeResult = await context.Client.PurgeInstanceAsync(orchestrationB); + Assert.Equal(0, purgeResult.PurgedInstanceCount); + purgeResult = await context.Client.PurgeInstanceAsync(orphanedEntityId.ToString()); + Assert.Equal(0, purgeResult.PurgedInstanceCount); + } + + public override void Register(DurableTaskRegistry registry, IServiceCollection services) + { + registry.AddOrchestrator(); + registry.AddOrchestrator(); + } + + class InfiniteLocker : TaskOrchestrator + { + public override async Task RunAsync(TaskOrchestrationContext context, EntityInstanceId entityId) + { + await using (await context.Entities.LockEntitiesAsync(entityId)) + { + await context.CreateTimer(DateTime.UtcNow + TimeSpan.FromDays(365), CancellationToken.None); + } + + // will never reach the end here because we get purged in the middle + return "ok"; + } + } + + class LockingIncrementor : TaskOrchestrator + { + public override async Task RunAsync(TaskOrchestrationContext context, EntityInstanceId entityId) + { + await using (await context.Entities.LockEntitiesAsync(entityId)) + { + await context.Entities.CallEntityAsync(entityId, "increment"); + + // we got the entity + return "ok"; + } + } + } +} \ No newline at end of file diff --git a/samples/DtsPortableSdkEntityTests/tests/SelfScheduling.cs b/samples/DtsPortableSdkEntityTests/tests/SelfScheduling.cs new file mode 100644 index 000000000..330dac9cc --- /dev/null +++ b/samples/DtsPortableSdkEntityTests/tests/SelfScheduling.cs @@ -0,0 +1,34 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Net.Mime; +using System.Text; +using System.Threading.Tasks; +using Microsoft.DurableTask; +using Microsoft.DurableTask.Client; +using Microsoft.DurableTask.Entities; +using Microsoft.Extensions.Logging; +using Xunit; + +namespace DtsPortableSdkEntityTests; + +class SelfScheduling : Test +{ + public override async Task RunAsync(TestContext context) + { + var entityId = new EntityInstanceId(nameof(SelfSchedulingEntity), Guid.NewGuid().ToString().Substring(0,8)); + + await context.Client.Entities.SignalEntityAsync(entityId, "start"); + + var result = await context.WaitForEntityStateAsync( + entityId, + timeout: default, + entityState => entityState.Value.Length == 4 ? null : "expect 4 letters"); + + Assert.NotNull(result); + Assert.Equal("ABCD", result.Value); + } +} diff --git a/samples/DtsPortableSdkEntityTests/tests/SetAndGet.cs b/samples/DtsPortableSdkEntityTests/tests/SetAndGet.cs new file mode 100644 index 000000000..8f665e9e9 --- /dev/null +++ b/samples/DtsPortableSdkEntityTests/tests/SetAndGet.cs @@ -0,0 +1,45 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Net.Mime; +using System.Text; +using System.Threading.Tasks; +using Microsoft.DurableTask; +using Microsoft.DurableTask.Client; +using Microsoft.DurableTask.Client.Entities; +using Microsoft.DurableTask.Entities; +using Xunit; + +namespace DtsPortableSdkEntityTests; + +class SetAndGet : Test +{ + public override async Task RunAsync(TestContext context) + { + var entityId = new EntityInstanceId(nameof(Counter), Guid.NewGuid().ToString()); + + // entity should not yet exist + EntityMetadata? result = await context.Client.Entities.GetEntityAsync(entityId); + Assert.Null(result); + + // entity should still not exist + result = await context.Client.Entities.GetEntityAsync(entityId, includeState:true); + Assert.Null(result); + + // send one signal + await context.Client.Entities.SignalEntityAsync(entityId, "Set", 1); + + // wait for state + int state = await context.WaitForEntityStateAsync(entityId); + Assert.Equal(1, state); + + // if we query the entity state again it should still be the same + result = await context.Client.Entities.GetEntityAsync(entityId); + + Assert.NotNull(result); + Assert.Equal(1,result!.State); + } +} diff --git a/samples/DtsPortableSdkEntityTests/tests/SignalAndCall.cs b/samples/DtsPortableSdkEntityTests/tests/SignalAndCall.cs new file mode 100644 index 000000000..15350666f --- /dev/null +++ b/samples/DtsPortableSdkEntityTests/tests/SignalAndCall.cs @@ -0,0 +1,70 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Net.Mime; +using System.Text; +using System.Threading.Tasks; +using Microsoft.DurableTask; +using Microsoft.DurableTask.Client; +using Microsoft.DurableTask.Client.Entities; +using Microsoft.DurableTask.Entities; +using Xunit; + +namespace DtsPortableSdkEntityTests; + +class SignalAndCall : Test +{ + readonly Type entityType; + + public SignalAndCall(Type entityType) + { + this.entityType = entityType; + } + + public override string Name => $"{base.Name}.{entityType.Name}"; + + public override async Task RunAsync(TestContext context) + { + var entityId = new EntityInstanceId(this.entityType.Name, Guid.NewGuid().ToString().Substring(0, 8)); + + string instanceId = await context.Client.ScheduleNewOrchestrationInstanceAsync(nameof(SignalAndCallOrchestration), entityId, context.CancellationToken); + OrchestrationMetadata metadata = await context.Client.WaitForInstanceCompletionAsync(instanceId, getInputsAndOutputs:true, context.CancellationToken); + + Assert.Equal(OrchestrationRuntimeStatus.Completed, metadata.RuntimeStatus); + Assert.Equal("ok", metadata.ReadOutputAs()); + } + + public override void Register(DurableTaskRegistry registry, IServiceCollection services) + { + registry.AddOrchestrator(); + } + + public class SignalAndCallOrchestration : TaskOrchestrator + { + public override async Task RunAsync(TaskOrchestrationContext context, EntityInstanceId entity) + { + // signal and call (both of these will be delivered close together, typically in the same batch, and always in order) + await context.Entities.SignalEntityAsync(entity, "set", "333"); + + string? result = await context.Entities.CallEntityAsync(entity, "get"); + + if (result != "333") + { + return $"fail: wrong entity state: expected 333, got {result}"; + } + + // make another call to see if the state survives replay + result = await context.Entities.CallEntityAsync(entity, "get"); + + if (result != "333") + { + return $"fail: wrong entity state: expected 333 still, but got {result}"; + } + + return "ok"; + } + } +} diff --git a/samples/DtsPortableSdkEntityTests/tests/SignalThenPoll.cs b/samples/DtsPortableSdkEntityTests/tests/SignalThenPoll.cs new file mode 100644 index 000000000..2c2976ec0 --- /dev/null +++ b/samples/DtsPortableSdkEntityTests/tests/SignalThenPoll.cs @@ -0,0 +1,106 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Runtime.CompilerServices; +using System.Text; +using System.Threading.Tasks; +using Microsoft.DurableTask; +using Microsoft.DurableTask.Client; +using Microsoft.DurableTask.Entities; +using Xunit; + +namespace DtsPortableSdkEntityTests; + +class SignalThenPoll : Test +{ + private readonly bool direct; + private readonly bool delayed; + + public SignalThenPoll(bool direct, bool delayed) + { + this.direct = direct; + this.delayed = delayed; + } + + public override string Name => $"{base.Name}.{(this.direct ? "Direct" : "Indirect")}.{(this.delayed ? "Delayed" : "Immediately")}"; + + public override async Task RunAsync(TestContext context) + { + var counterEntityId = new EntityInstanceId(nameof(Counter), Guid.NewGuid().ToString().Substring(0, 8)); + var relayEntityId = new EntityInstanceId("Relay", ""); + + string pollingInstance = await context.Client.ScheduleNewOrchestrationInstanceAsync(nameof(PollingOrchestration), counterEntityId, context.CancellationToken); + DateTimeOffset? scheduledTime = this.delayed ? DateTime.UtcNow + TimeSpan.FromSeconds(5) : null; + + if (this.direct) + { + await context.Client.Entities.SignalEntityAsync( + counterEntityId, + "set", + 333, + new SignalEntityOptions() { SignalTime = scheduledTime }, + context.CancellationToken); + } + else + { + await context.Client.Entities.SignalEntityAsync( + relayEntityId, + operationName: "", + input: new Relay.Input(counterEntityId, "set", 333, scheduledTime), + options: null, + context.CancellationToken); + } + + var metadata = await context.Client.WaitForInstanceCompletionAsync(pollingInstance, true, context.CancellationToken); + Assert.Equal(OrchestrationRuntimeStatus.Completed, metadata.RuntimeStatus); + Assert.Equal("ok", metadata.ReadOutputAs()); + + if (this.delayed) + { + Assert.True(metadata.LastUpdatedAt > scheduledTime - TimeSpan.FromMilliseconds(100)); + } + + int counterState = await context.WaitForEntityStateAsync( + counterEntityId, + timeout: default); + + Assert.Equal(333, counterState); + } + + public override void Register(DurableTaskRegistry registry, IServiceCollection services) + { + registry.AddOrchestrator(); + } + + public class PollingOrchestration : TaskOrchestrator + { + public override async Task RunAsync(TaskOrchestrationContext context, EntityInstanceId entityId) + { + DateTime startTime = context.CurrentUtcDateTime; + + while (context.CurrentUtcDateTime < startTime + TimeSpan.FromSeconds(30)) + { + var result = await context.Entities.CallEntityAsync(entityId, "get"); + + if (result != 0) + { + if (result == 333) + { + return "ok"; + } + else + { + return $"fail: wrong entity state: expected 333, got {result}"; + } + } + + await context.CreateTimer(DateTime.UtcNow + TimeSpan.FromSeconds(1), CancellationToken.None); + } + + return "timed out while waiting for entity to have state"; + } + } +} \ No newline at end of file diff --git a/samples/DtsPortableSdkEntityTests/tests/SingleLockedTransfer.cs b/samples/DtsPortableSdkEntityTests/tests/SingleLockedTransfer.cs new file mode 100644 index 000000000..530b488e3 --- /dev/null +++ b/samples/DtsPortableSdkEntityTests/tests/SingleLockedTransfer.cs @@ -0,0 +1,107 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Net.Mime; +using System.Text; +using System.Threading.Tasks; +using Microsoft.DurableTask; +using Microsoft.DurableTask.Client; +using Microsoft.DurableTask.Client.Entities; +using Microsoft.DurableTask.Entities; +using Xunit; + +namespace DtsPortableSdkEntityTests; + +class SingleLockedTransfer : Test +{ + public override async Task RunAsync(TestContext context) + { + var counter1 = new EntityInstanceId("Counter", Guid.NewGuid().ToString().Substring(0, 8)); + var counter2 = new EntityInstanceId("Counter", Guid.NewGuid().ToString().Substring(0, 8)); + + string instanceId = await context.Client.ScheduleNewOrchestrationInstanceAsync(nameof(LockedTransferOrchestration), new[] { counter1, counter2 }, context.CancellationToken); + OrchestrationMetadata metadata = await context.Client.WaitForInstanceCompletionAsync(instanceId, getInputsAndOutputs: true, context.CancellationToken); + Assert.Equal(OrchestrationRuntimeStatus.Completed, metadata.RuntimeStatus); + Assert.Equal(new[] { -1, 1 }, metadata.ReadOutputAs()); + + // validate the state of the counters + EntityMetadata? response1 = await context.Client.Entities.GetEntityAsync(counter1, true, context.CancellationToken); + EntityMetadata? response2 = await context.Client.Entities.GetEntityAsync(counter2, true, context.CancellationToken); + Assert.NotNull(response1); + Assert.NotNull(response2); + Assert.Equal(-1, response1!.State); + Assert.Equal(1, response2!.State); + } + + public override void Register(DurableTaskRegistry registry, IServiceCollection services) + { + registry.AddOrchestrator(); + } + + public class LockedTransferOrchestration : TaskOrchestrator + { + public override async Task RunAsync(TaskOrchestrationContext context, EntityInstanceId[] entities) + { + var from = entities![0]; + var to = entities![1]; + + if (from.Equals(to)) + { + throw new ArgumentException("from and to must be distinct"); + } + + ExpectSynchState(false); + + int fromBalance; + int toBalance; + + await using (await context.Entities.LockEntitiesAsync(from, to)) + { + ExpectSynchState(true, from, to); + + // read balances in parallel + var t1 = context.Entities.CallEntityAsync(from, "get"); + ExpectSynchState(true, to); + var t2 = context.Entities.CallEntityAsync(to, "get"); + ExpectSynchState(true); + + + fromBalance = await t1; + toBalance = await t2; + ExpectSynchState(true, from, to); + + // modify + fromBalance--; + toBalance++; + + // write balances in parallel + var t3 = context.Entities.CallEntityAsync(from, "set", fromBalance); + ExpectSynchState(true, to); + var t4 = context.Entities.CallEntityAsync(to, "set", toBalance); + ExpectSynchState(true); + await t4; + await t3; + ExpectSynchState(true, to, from); + + } // lock is released here + + ExpectSynchState(false); + + return new int[] { fromBalance, toBalance }; + + void ExpectSynchState(bool inCriticalSection, params EntityInstanceId[] ids) + { + Assert.Equal(inCriticalSection, context.Entities.InCriticalSection(out var currentLocks)); + if (inCriticalSection) + { + Assert.Equal( + ids.Select(i => i.ToString()).OrderBy(s => s), + currentLocks!.Select(i => i.ToString()).OrderBy(s => s)); + } + } + } + } +} \ No newline at end of file diff --git a/samples/DtsPortableSdkEntityTests/tests/TwoCriticalSections.cs b/samples/DtsPortableSdkEntityTests/tests/TwoCriticalSections.cs new file mode 100644 index 000000000..2db9aa2e1 --- /dev/null +++ b/samples/DtsPortableSdkEntityTests/tests/TwoCriticalSections.cs @@ -0,0 +1,64 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Net.Mime; +using System.Text; +using System.Threading.Tasks; +using Microsoft.DurableTask; +using Microsoft.DurableTask.Client; +using Microsoft.DurableTask.Client.Entities; +using Microsoft.DurableTask.Entities; +using Xunit; + +namespace DtsPortableSdkEntityTests; + +class TwoCriticalSections : Test +{ + readonly bool sameEntity; + + public TwoCriticalSections(bool sameEntity) + { + this.sameEntity = sameEntity; + } + + public override string Name => $"{base.Name}.{this.sameEntity}"; + + public override async Task RunAsync(TestContext context) + { + var key1 = Guid.NewGuid().ToString().Substring(0, 8); + var key2 = this.sameEntity ? key1 : Guid.NewGuid().ToString().Substring(0, 8); + + string instanceId = await context.Client.ScheduleNewOrchestrationInstanceAsync(nameof(TwoCriticalSectionsOrchestration), new[] { key1, key2 }, context.CancellationToken); + OrchestrationMetadata metadata = await context.Client.WaitForInstanceCompletionAsync(instanceId, getInputsAndOutputs: true, context.CancellationToken); + Assert.Equal(OrchestrationRuntimeStatus.Completed, metadata.RuntimeStatus); + Assert.Equal("ok", metadata.ReadOutputAs()); + } + + public override void Register(DurableTaskRegistry registry, IServiceCollection services) + { + registry.AddOrchestrator(); + } + + public class TwoCriticalSectionsOrchestration : TaskOrchestrator + { + public override async Task RunAsync(TaskOrchestrationContext context, string[] entityKeys) + { + string key1 = entityKeys![0]; + string key2 = entityKeys![1]; + + await using (await context.Entities.LockEntitiesAsync([new EntityInstanceId(nameof(Counter), key1)])) + { + await context.Entities.CallEntityAsync(new EntityInstanceId(nameof(Counter), key1), "add", 1); + } + await using (await context.Entities.LockEntitiesAsync([new EntityInstanceId(nameof(Counter), key2)])) + { + await context.Entities.CallEntityAsync(new EntityInstanceId(nameof(Counter), key2), "add", 1); + } + + return "ok"; + } + } +} \ No newline at end of file