diff --git a/src/Proto.TestKit/ITestProbe.cs b/src/Proto.TestKit/ITestProbe.cs index e337fbb736..bfc3c75d3e 100644 --- a/src/Proto.TestKit/ITestProbe.cs +++ b/src/Proto.TestKit/ITestProbe.cs @@ -57,6 +57,34 @@ public interface ITestProbe /// T GetNextMessage(Func when, TimeSpan? timeAllowed = null); + /// + /// asynchronously gets the next message from the test probe + /// + /// + /// + /// + Task GetNextMessageAsync(TimeSpan? timeAllowed = null, CancellationToken cancellationToken = default); + + /// + /// asynchronously gets the next message from the test probe + /// + /// + /// + /// + /// + Task GetNextMessageAsync(TimeSpan? timeAllowed = null, CancellationToken cancellationToken = default); + + /// + /// asynchronously gets the next message from the test probe + /// + /// + /// + /// + /// + /// + Task GetNextMessageAsync(Func when, TimeSpan? timeAllowed = null, + CancellationToken cancellationToken = default); + /// /// keeps returning messages until the interval between messages exceeds the time allowed /// @@ -81,6 +109,36 @@ public interface ITestProbe /// IEnumerable ProcessMessages(Func when, TimeSpan? timeAllowed = null); + /// + /// asynchronously processes messages until the interval between messages exceeds the time allowed + /// + /// + /// + /// + IAsyncEnumerable ProcessMessagesAsync(TimeSpan? timeAllowed = null, + CancellationToken cancellationToken = default); + + /// + /// asynchronously processes messages until the interval between messages exceeds the time allowed + /// + /// + /// + /// + /// + IAsyncEnumerable ProcessMessagesAsync(TimeSpan? timeAllowed = null, + CancellationToken cancellationToken = default); + + /// + /// asynchronously processes messages until the interval between messages exceeds the time allowed + /// + /// + /// + /// + /// + /// + IAsyncEnumerable ProcessMessagesAsync(Func when, TimeSpan? timeAllowed = null, + CancellationToken cancellationToken = default); + /// /// fishes for the next message of a given type from the test probe /// @@ -98,6 +156,26 @@ public interface ITestProbe /// T FishForMessage(Func when, TimeSpan? timeAllowed = null); + /// + /// asynchronously fishes for the next message of a given type from the test probe + /// + /// + /// + /// + /// + Task FishForMessageAsync(TimeSpan? timeAllowed = null, CancellationToken cancellationToken = default); + + /// + /// asynchronously fishes for the next message of a given type from the test probe + /// + /// + /// + /// + /// + /// + Task FishForMessageAsync(Func when, TimeSpan? timeAllowed = null, + CancellationToken cancellationToken = default); + /// /// sends a message from the test probe to the target /// diff --git a/src/Proto.TestKit/TestKitBase.cs b/src/Proto.TestKit/TestKitBase.cs index 89141af211..667dc5c740 100644 --- a/src/Proto.TestKit/TestKitBase.cs +++ b/src/Proto.TestKit/TestKitBase.cs @@ -64,6 +64,21 @@ public PID SpawnNamed(Props props, string name, Action? callback = nul public T GetNextMessage(Func when, TimeSpan? timeAllowed = null) => Probe.GetNextMessage(when, timeAllowed); + /// + public Task GetNextMessageAsync(TimeSpan? timeAllowed = null, + CancellationToken cancellationToken = default) => + Probe.GetNextMessageAsync(timeAllowed, cancellationToken); + + /// + public Task GetNextMessageAsync(TimeSpan? timeAllowed = null, + CancellationToken cancellationToken = default) => + Probe.GetNextMessageAsync(timeAllowed, cancellationToken); + + /// + public Task GetNextMessageAsync(Func when, TimeSpan? timeAllowed = null, + CancellationToken cancellationToken = default) => + Probe.GetNextMessageAsync(when, timeAllowed, cancellationToken); + /// public IEnumerable ProcessMessages(TimeSpan? timeAllowed = null) => Probe.ProcessMessages(timeAllowed); @@ -74,6 +89,21 @@ public T GetNextMessage(Func when, TimeSpan? timeAllowed = null) => public IEnumerable ProcessMessages(Func when, TimeSpan? timeAllowed = null) => Probe.ProcessMessages(when, timeAllowed); + /// + public IAsyncEnumerable ProcessMessagesAsync(TimeSpan? timeAllowed = null, + CancellationToken cancellationToken = default) => + Probe.ProcessMessagesAsync(timeAllowed, cancellationToken); + + /// + public IAsyncEnumerable ProcessMessagesAsync(TimeSpan? timeAllowed = null, + CancellationToken cancellationToken = default) => + Probe.ProcessMessagesAsync(timeAllowed, cancellationToken); + + /// + public IAsyncEnumerable ProcessMessagesAsync(Func when, TimeSpan? timeAllowed = null, + CancellationToken cancellationToken = default) => + Probe.ProcessMessagesAsync(when, timeAllowed, cancellationToken); + /// public T FishForMessage(TimeSpan? timeAllowed = null) => Probe.FishForMessage(timeAllowed); @@ -81,6 +111,16 @@ public IEnumerable ProcessMessages(Func when, TimeSpan? timeAllow public T FishForMessage(Func when, TimeSpan? timeAllowed = null) => Probe.FishForMessage(when, timeAllowed); + /// + public Task FishForMessageAsync(TimeSpan? timeAllowed = null, + CancellationToken cancellationToken = default) => + Probe.FishForMessageAsync(timeAllowed, cancellationToken); + + /// + public Task FishForMessageAsync(Func when, TimeSpan? timeAllowed = null, + CancellationToken cancellationToken = default) => + Probe.FishForMessageAsync(when, timeAllowed, cancellationToken); + /// public void Send(PID target, object message) => Probe.Send(target, message); diff --git a/src/Proto.TestKit/TestProbe.cs b/src/Proto.TestKit/TestProbe.cs index bf00e81360..bf096ebb52 100644 --- a/src/Proto.TestKit/TestProbe.cs +++ b/src/Proto.TestKit/TestProbe.cs @@ -6,10 +6,11 @@ using System; using System.Collections; -using System.Collections.Concurrent; using System.Collections.Generic; using System.Threading; using System.Threading.Tasks; +using System.Threading.Channels; +using System.Runtime.CompilerServices; using Proto.Mailbox; namespace Proto.TestKit; @@ -17,8 +18,7 @@ namespace Proto.TestKit; /// public class TestProbe : IActor, ITestProbe { - private readonly BlockingCollection - _messageQueue = new(); + private readonly Channel _messageChannel = Channel.CreateUnbounded(); private IContext? _context; @@ -39,12 +39,12 @@ public Task ReceiveAsync(IContext context) break; case Terminated _: - _messageQueue.Add(new MessageAndSender(context)); + _messageChannel.Writer.TryWrite(new MessageAndSender(context)); break; case SystemMessage _: return Task.CompletedTask; default: - _messageQueue.Add(new MessageAndSender(context)); + _messageChannel.Writer.TryWrite(new MessageAndSender(context)); break; } @@ -74,8 +74,15 @@ public IContext Context public void ExpectNoMessage(TimeSpan? timeAllowed = null) { var time = timeAllowed ?? TimeSpan.FromSeconds(1); + using var cts = new CancellationTokenSource(time); - if (_messageQueue.TryTake(out var o, time)) + try + { + var item = _messageChannel.Reader.ReadAsync(cts.Token).AsTask().GetAwaiter().GetResult(); + throw new TestKitException( + $"Waited {time.Seconds} seconds and received a message of type {item.Message?.GetType()}"); + } + catch (OperationCanceledException) { var seconds = time.TotalSeconds.ToString("0.###"); throw new TestKitException($"Waited {seconds} seconds and received a message of type {o.GetType()}"); @@ -86,16 +93,19 @@ public void ExpectNoMessage(TimeSpan? timeAllowed = null) public object? GetNextMessage(TimeSpan? timeAllowed = null) { var time = timeAllowed ?? TimeSpan.FromSeconds(1); + using var cts = new CancellationTokenSource(time); - if (!_messageQueue.TryTake(out var output, time)) + try + { + var output = _messageChannel.Reader.ReadAsync(cts.Token).AsTask().GetAwaiter().GetResult(); + Sender = output.Sender; + return output.Message; + } + catch (OperationCanceledException) { var seconds = time.TotalSeconds.ToString("0.###"); throw new TestKitException($"Waited {seconds} seconds but failed to receive a message"); } - - Sender = output?.Sender; - - return output?.Message; } /// @@ -124,6 +134,54 @@ public T GetNextMessage(Func when, TimeSpan? timeAllowed = null) return output; } + /// + public async Task GetNextMessageAsync(TimeSpan? timeAllowed = null, + CancellationToken cancellationToken = default) + { + var time = timeAllowed ?? TimeSpan.FromSeconds(1); + using var cts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken); + cts.CancelAfter(time); + + try + { + var output = await _messageChannel.Reader.ReadAsync(cts.Token).ConfigureAwait(false); + Sender = output.Sender; + return output.Message; + } + catch (OperationCanceledException) + { + throw new TestKitException($"Waited {time.Seconds} seconds but failed to receive a message"); + } + } + + /// + public async Task GetNextMessageAsync(TimeSpan? timeAllowed = null, + CancellationToken cancellationToken = default) + { + var output = await GetNextMessageAsync(timeAllowed, cancellationToken).ConfigureAwait(false); + + if (output is not T typed) + { + throw new TestKitException($"Message expected type {typeof(T)}, actual type {output?.GetType()}"); + } + + return typed; + } + + /// + public async Task GetNextMessageAsync(Func when, TimeSpan? timeAllowed = null, + CancellationToken cancellationToken = default) + { + var output = await GetNextMessageAsync(timeAllowed, cancellationToken).ConfigureAwait(false); + + if (!when(output)) + { + throw new TestKitException("Condition not met"); + } + + return output; + } + /// public IEnumerable ProcessMessages(TimeSpan? timeAllowed = null) { @@ -185,7 +243,71 @@ public IEnumerable ProcessMessages(Func when, TimeSpan? timeAllow } /// - public T FishForMessage(TimeSpan? timeAllowed = null) => FishForMessage(x => true, timeAllowed); + public async IAsyncEnumerable ProcessMessagesAsync(TimeSpan? timeAllowed = null, + [EnumeratorCancellation] CancellationToken cancellationToken = default) + { + while (true) + { + object? message; + + try + { + message = await GetNextMessageAsync(timeAllowed, cancellationToken).ConfigureAwait(false); + } + catch (TestKitException) + { + yield break; + } + + yield return message; + } + } + + /// + public async IAsyncEnumerable ProcessMessagesAsync(TimeSpan? timeAllowed = null, + [EnumeratorCancellation] CancellationToken cancellationToken = default) + { + while (true) + { + T message; + + try + { + message = await FishForMessageAsync(timeAllowed, cancellationToken).ConfigureAwait(false); + } + catch (TestKitException) + { + yield break; + } + + yield return message; + } + } + + /// + public async IAsyncEnumerable ProcessMessagesAsync(Func when, + TimeSpan? timeAllowed = null, [EnumeratorCancellation] CancellationToken cancellationToken = default) + { + while (true) + { + T message; + + try + { + message = await FishForMessageAsync(when, timeAllowed, cancellationToken) + .ConfigureAwait(false); + } + catch (TestKitException) + { + yield break; + } + + yield return message; + } + } + + /// + public T FishForMessage(TimeSpan? timeAllowed = null) => FishForMessage(_ => true, timeAllowed); /// public T FishForMessage(Func when, TimeSpan? timeAllowed = null) @@ -194,12 +316,59 @@ public T FishForMessage(Func when, TimeSpan? timeAllowed = null) while (DateTime.UtcNow < endTime) { - if (_messageQueue.TryTake(out var item, endTime - DateTime.UtcNow) && - item.Message is T typed && when(typed)) + var remaining = endTime - DateTime.UtcNow; + using var cts = new CancellationTokenSource(remaining); + + try + { + var item = _messageChannel.Reader.ReadAsync(cts.Token).AsTask().GetAwaiter().GetResult(); + + if (item.Message is T typed && when(typed)) + { + Sender = item.Sender; + + return typed; + } + } + catch (OperationCanceledException) { - Sender = item.Sender; + // try again until timeout + } + } + + throw new TestKitException("Message not found"); + } + + /// + public Task FishForMessageAsync(TimeSpan? timeAllowed = null, + CancellationToken cancellationToken = default) => + FishForMessageAsync(_ => true, timeAllowed, cancellationToken); - return typed; + /// + public async Task FishForMessageAsync(Func when, TimeSpan? timeAllowed = null, + CancellationToken cancellationToken = default) + { + var endTime = DateTime.UtcNow + (timeAllowed ?? TimeSpan.FromSeconds(1)); + + while (DateTime.UtcNow < endTime) + { + var remaining = endTime - DateTime.UtcNow; + using var cts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken); + cts.CancelAfter(remaining); + + try + { + var item = await _messageChannel.Reader.ReadAsync(cts.Token).ConfigureAwait(false); + + if (item.Message is T typed && when(typed)) + { + Sender = item.Sender; + return typed; + } + } + catch (OperationCanceledException) + { + // loop until timeout } } diff --git a/tests/Proto.TestKit.Tests/AsyncProbeTests.cs b/tests/Proto.TestKit.Tests/AsyncProbeTests.cs new file mode 100644 index 0000000000..ae810fb737 --- /dev/null +++ b/tests/Proto.TestKit.Tests/AsyncProbeTests.cs @@ -0,0 +1,36 @@ +using System; +using System.Collections.Generic; +using System.Threading.Tasks; +using FluentAssertions; +using Xunit; + +namespace Proto.TestKit.Tests +{ + public class AsyncProbeTests : TestKitBase + { + public AsyncProbeTests() => SetUp(); + + [Fact] + public async Task GetNextMessageAsync_receives_message() + { + Send(Probe, "hello"); + var msg = await GetNextMessageAsync(); + msg.Should().Be("hello"); + } + + [Fact] + public async Task ProcessMessagesAsync_returns_all_until_timeout() + { + Send(Probe, "a"); + Send(Probe, "b"); + + var received = new List(); + await foreach (var m in ProcessMessagesAsync(TimeSpan.FromMilliseconds(100))) + { + received.Add(m); + } + + received.Should().BeEquivalentTo(new[] { "a", "b" }); + } + } +} diff --git a/tests/Proto.TestKit.Tests/MessageFiltering.cs b/tests/Proto.TestKit.Tests/MessageFiltering.cs index 6a4ea3bf1a..72f18c2393 100644 --- a/tests/Proto.TestKit.Tests/MessageFiltering.cs +++ b/tests/Proto.TestKit.Tests/MessageFiltering.cs @@ -119,6 +119,7 @@ public void ExpectNoMessageFails() var seconds = TimeSpan.FromSeconds(1).TotalSeconds.ToString("0.###"); this.Invoking(_ => ExpectNoMessage()) .Should().Throw().WithMessage($"Waited {seconds} seconds and received a message of type Proto.TestKit.MessageAndSender"); + } [Fact]