diff --git a/samples/AzureFunctionsApp/Entities/Lifetime.cs b/samples/AzureFunctionsApp/Entities/Lifetime.cs index 5a219bac2..0b1083ef0 100644 --- a/samples/AzureFunctionsApp/Entities/Lifetime.cs +++ b/samples/AzureFunctionsApp/Entities/Lifetime.cs @@ -75,6 +75,12 @@ public void Delete() protected override MyState InitializeState(TaskEntityOperation operation) { // This method allows for customizing the default state value for a new entity. + // For async initialization (e.g., loading from a database or external storage), override InitializeStateAsync instead: + // protected override async ValueTask InitializeStateAsync(TaskEntityOperation operation) + // { + // var data = await LoadFromDatabaseAsync(); + // return new MyState(data.PropA, data.PropB); + // } return new(Guid.NewGuid().ToString("N"), Random.Shared.Next(0, 1000)); } } diff --git a/src/Abstractions/Entities/TaskEntity.cs b/src/Abstractions/Entities/TaskEntity.cs index d8fee1ee2..87ffcd3a2 100644 --- a/src/Abstractions/Entities/TaskEntity.cs +++ b/src/Abstractions/Entities/TaskEntity.cs @@ -119,25 +119,40 @@ public abstract class TaskEntity : ITaskEntity protected TaskEntityContext Context { get; private set; } = null!; /// - public ValueTask RunAsync(TaskEntityOperation operation) + public async ValueTask RunAsync(TaskEntityOperation operation) { Check.NotNull(operation); this.Context = operation.Context; object? state = operation.State.GetState(typeof(TState)); - this.State = state is null ? this.InitializeState(operation) : (TState)state; + this.State = state is null ? await this.InitializeStateAsync(operation) : (TState)state; if (!operation.TryDispatch(this, out object? result, out Type returnType) && !this.TryDispatchState(operation, out result, out returnType)) { if (TryDispatchImplicit(operation, out ValueTask task)) { // We do not go into UnwrapAsync for implicit tasks - return task; + return await task; } throw new NotSupportedException($"No suitable method found for entity operation '{operation}'."); } - return TaskEntityHelpers.UnwrapAsync(operation.State, () => this.State, result, returnType); + return await TaskEntityHelpers.UnwrapAsync(operation.State, () => this.State, result, returnType); + } + + /// + /// Initializes the entity state asynchronously. This is only called when there is no current state for this entity. + /// + /// The entity operation to be executed. + /// A task that resolves to the entity state. + /// + /// The default implementation calls for backward compatibility. + /// Override this method to perform async initialization operations, such as loading state from a database or + /// external storage. + /// + protected virtual ValueTask InitializeStateAsync(TaskEntityOperation entityOperation) + { + return new ValueTask(this.InitializeState(entityOperation)); } /// diff --git a/test/Abstractions.Tests/Entities/StateTaskEntityTests.cs b/test/Abstractions.Tests/Entities/StateTaskEntityTests.cs index 06d16a4f4..fb2205112 100644 --- a/test/Abstractions.Tests/Entities/StateTaskEntityTests.cs +++ b/test/Abstractions.Tests/Entities/StateTaskEntityTests.cs @@ -172,6 +172,53 @@ public async Task ExplicitDelete_Overridden(string op) operation.State.GetState(typeof(TestState)).Should().BeOfType().Which.Value.Should().Be(0); } + [Fact] + public async Task AsyncInitializeState_Called() + { + TestEntityOperation operation = new("get0", new TestEntityState(null), default); + AsyncInitEntity entity = new(); + + object? result = await entity.RunAsync(operation); + + result.Should().BeOfType().Which.Should().Be(42); + entity.InitializeAsyncCalled.Should().BeTrue(); + } + + [Fact] + public async Task AsyncInitializeState_WithYield_Succeeds() + { + TestEntityOperation operation = new("get0", new TestEntityState(null), default); + AsyncInitWithYieldEntity entity = new(); + + object? result = await entity.RunAsync(operation); + + result.Should().BeOfType().Which.Should().Be(99); + entity.InitializeAsyncCalled.Should().BeTrue(); + } + + [Fact] + public async Task AsyncInitializeState_ValueTask_Succeeds() + { + TestEntityOperation operation = new("get0", new TestEntityState(null), default); + AsyncInitValueTaskEntity entity = new(); + + object? result = await entity.RunAsync(operation); + + result.Should().BeOfType().Which.Should().Be(77); + } + + [Fact] + public async Task SyncInitializeState_StillWorks() + { + TestEntityOperation operation = new("get0", new TestEntityState(null), default); + SyncInitEntity entity = new(); + + object? result = await entity.RunAsync(operation); + + result.Should().BeOfType().Which.Should().Be(55); + entity.SyncInitCalled.Should().BeTrue(); + } + static TestState State(int value) => new() { Value = value }; class NullStateEntity : TestEntity @@ -179,6 +226,49 @@ class NullStateEntity : TestEntity protected override TestState InitializeState(TaskEntityOperation entityOperation) => null!; } + class AsyncInitEntity : TestEntity + { + public bool InitializeAsyncCalled { get; private set; } + + protected override async ValueTask InitializeStateAsync(TaskEntityOperation entityOperation) + { + await Task.CompletedTask; + this.InitializeAsyncCalled = true; + return new TestState { Value = 42 }; + } + } + + class AsyncInitWithYieldEntity : TestEntity + { + public bool InitializeAsyncCalled { get; private set; } + + protected override async ValueTask InitializeStateAsync(TaskEntityOperation entityOperation) + { + await Task.Yield(); + this.InitializeAsyncCalled = true; + return new TestState { Value = 99 }; + } + } + + class AsyncInitValueTaskEntity : TestEntity + { + protected override ValueTask InitializeStateAsync(TaskEntityOperation entityOperation) + { + return new ValueTask(new TestState { Value = 77 }); + } + } + + class SyncInitEntity : TestEntity + { + public bool SyncInitCalled { get; private set; } + + protected override TestState InitializeState(TaskEntityOperation entityOperation) + { + this.SyncInitCalled = true; + return new TestState { Value = 55 }; + } + } + class TestEntity : TaskEntity { readonly bool allowStateDispatch;