Skip to content

Commit 0826b7e

Browse files
committed
Add UniTask.WhenEach
1 parent 87e164e commit 0826b7e

File tree

8 files changed

+310
-21
lines changed

8 files changed

+310
-21
lines changed

README.md

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -160,7 +160,7 @@ UniTask provides three pattern of extension methods.
160160
161161
> Note: AssetBundleRequest has `asset` and `allAssets`, default await returns `asset`. If you want to get `allAssets`, you can use `AwaitForAllAssets()` method.
162162
163-
The type of `UniTask` can use utilities like `UniTask.WhenAll`, `UniTask.WhenAny`. They are like `Task.WhenAll`/`Task.WhenAny` but the return type is more useful. They return value tuples so you can deconstruct each result and pass multiple types.
163+
The type of `UniTask` can use utilities like `UniTask.WhenAll`, `UniTask.WhenAny`, `UniTask.WhenEach`. They are like `Task.WhenAll`/`Task.WhenAny` but the return type is more useful. They return value tuples so you can deconstruct each result and pass multiple types.
164164

165165
```csharp
166166
public async UniTaskVoid LoadManyAsync()
@@ -716,6 +716,19 @@ await UniTaskAsyncEnumerable.EveryUpdate().ForEachAsync(_ =>
716716
}, token);
717717
```
718718

719+
`UniTask.WhenEach` that is similar to .NET 9's `Task.WhenEach` can consume new way for await multiple tasks.
720+
721+
```csharp
722+
await foreach (var result in UniTask.WhenEach(task1, task2, task3))
723+
{
724+
// The result is of type WhenEachResult<T>.
725+
// It contains either `T Result` or `Exception Exception`.
726+
// You can check `IsCompletedSuccessfully` or `IsFaulted` to determine whether to access `.Result` or `.Exception`.
727+
// If you want to throw an exception when `IsFaulted` and retrieve the result when successful, use `GetResult()`.
728+
Debug.Log(result.GetResult());
729+
}
730+
```
731+
719732
UniTaskAsyncEnumerable implements asynchronous LINQ, similar to LINQ in `IEnumerable<T>` or Rx in `IObservable<T>`. All standard LINQ query operators can be applied to asynchronous streams. For example, the following code shows how to apply a Where filter to a button-click asynchronous stream that runs once every two clicks.
720733

721734
```csharp
@@ -1026,6 +1039,7 @@ Use UniTask type.
10261039
| `Task.Run` | `UniTask.RunOnThreadPool` |
10271040
| `Task.WhenAll` | `UniTask.WhenAll` |
10281041
| `Task.WhenAny` | `UniTask.WhenAny` |
1042+
| `Task.WhenEach` | `UniTask.WhenEach` |
10291043
| `Task.CompletedTask` | `UniTask.CompletedTask` |
10301044
| `Task.FromException` | `UniTask.FromException` |
10311045
| `Task.FromResult` | `UniTask.FromResult` |
Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
using Cysharp.Threading.Tasks;
2+
using FluentAssertions;
3+
using System;
4+
using System.Collections.Generic;
5+
using System.Linq;
6+
using System.Text;
7+
using System.Threading.Tasks;
8+
using Xunit;
9+
10+
namespace NetCoreTests
11+
{
12+
public class WhenEachTest
13+
{
14+
[Fact]
15+
public async Task Each()
16+
{
17+
var a = Delay(1, 3000);
18+
var b = Delay(2, 1000);
19+
var c = Delay(3, 2000);
20+
21+
var l = new List<int>();
22+
await foreach (var item in UniTask.WhenEach(a, b, c))
23+
{
24+
l.Add(item.Result);
25+
}
26+
27+
l.Should().Equal(2, 3, 1);
28+
}
29+
30+
[Fact]
31+
public async Task Error()
32+
{
33+
var a = Delay2(1, 3000);
34+
var b = Delay2(2, 1000);
35+
var c = Delay2(3, 2000);
36+
37+
var l = new List<WhenEachResult<int>>();
38+
await foreach (var item in UniTask.WhenEach(a, b, c))
39+
{
40+
l.Add(item);
41+
}
42+
43+
l[0].IsCompletedSuccessfully.Should().BeTrue();
44+
l[0].IsFaulted.Should().BeFalse();
45+
l[0].Result.Should().Be(2);
46+
47+
l[1].IsCompletedSuccessfully.Should().BeFalse();
48+
l[1].IsFaulted.Should().BeTrue();
49+
l[1].Exception.Message.Should().Be("ERROR");
50+
51+
l[2].IsCompletedSuccessfully.Should().BeTrue();
52+
l[2].IsFaulted.Should().BeFalse();
53+
l[2].Result.Should().Be(1);
54+
}
55+
56+
async UniTask<int> Delay(int id, int sleep)
57+
{
58+
await Task.Delay(sleep);
59+
return id;
60+
}
61+
62+
async UniTask<int> Delay2(int id, int sleep)
63+
{
64+
await Task.Delay(sleep);
65+
if (id == 3) throw new Exception("ERROR");
66+
return id;
67+
}
68+
}
69+
}
Lines changed: 183 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,183 @@
1+
using Cysharp.Threading.Tasks.Internal;
2+
using System;
3+
using System.Collections.Generic;
4+
using System.Runtime.ExceptionServices;
5+
using System.Threading;
6+
7+
namespace Cysharp.Threading.Tasks
8+
{
9+
public partial struct UniTask
10+
{
11+
public static IUniTaskAsyncEnumerable<WhenEachResult<T>> WhenEach<T>(IEnumerable<UniTask<T>> tasks)
12+
{
13+
return new WhenEachEnumerable<T>(tasks);
14+
}
15+
16+
public static IUniTaskAsyncEnumerable<WhenEachResult<T>> WhenEach<T>(params UniTask<T>[] tasks)
17+
{
18+
return new WhenEachEnumerable<T>(tasks);
19+
}
20+
}
21+
22+
public readonly struct WhenEachResult<T>
23+
{
24+
public T Result { get; }
25+
public Exception Exception { get; }
26+
27+
//[MemberNotNullWhen(false, nameof(Exception))]
28+
public bool IsCompletedSuccessfully => Exception == null;
29+
30+
//[MemberNotNullWhen(true, nameof(Exception))]
31+
public bool IsFaulted => Exception != null;
32+
33+
public WhenEachResult(T result)
34+
{
35+
this.Result = result;
36+
this.Exception = null;
37+
}
38+
39+
public WhenEachResult(Exception exception)
40+
{
41+
if (exception == null) throw new ArgumentNullException(nameof(exception));
42+
this.Result = default!;
43+
this.Exception = exception;
44+
}
45+
46+
public void TryThrow()
47+
{
48+
if (IsFaulted)
49+
{
50+
ExceptionDispatchInfo.Capture(Exception).Throw();
51+
}
52+
}
53+
54+
public T GetResult()
55+
{
56+
if (IsFaulted)
57+
{
58+
ExceptionDispatchInfo.Capture(Exception).Throw();
59+
}
60+
return Result;
61+
}
62+
63+
public override string ToString()
64+
{
65+
if (IsCompletedSuccessfully)
66+
{
67+
return Result?.ToString() ?? "";
68+
}
69+
else
70+
{
71+
return $"Exception{{{Exception.Message}}}";
72+
}
73+
}
74+
}
75+
76+
internal enum WhenEachState : byte
77+
{
78+
NotRunning,
79+
Running,
80+
Completed
81+
}
82+
83+
internal sealed class WhenEachEnumerable<T> : IUniTaskAsyncEnumerable<WhenEachResult<T>>
84+
{
85+
IEnumerable<UniTask<T>> source;
86+
87+
public WhenEachEnumerable(IEnumerable<UniTask<T>> source)
88+
{
89+
this.source = source;
90+
}
91+
92+
public IUniTaskAsyncEnumerator<WhenEachResult<T>> GetAsyncEnumerator(CancellationToken cancellationToken = default)
93+
{
94+
return new Enumerator(source, cancellationToken);
95+
}
96+
97+
sealed class Enumerator : IUniTaskAsyncEnumerator<WhenEachResult<T>>
98+
{
99+
readonly IEnumerable<UniTask<T>> source;
100+
CancellationToken cancellationToken;
101+
102+
Channel<WhenEachResult<T>> channel;
103+
IUniTaskAsyncEnumerator<WhenEachResult<T>> channelEnumerator;
104+
int completeCount;
105+
WhenEachState state;
106+
107+
public Enumerator(IEnumerable<UniTask<T>> source, CancellationToken cancellationToken)
108+
{
109+
this.source = source;
110+
this.cancellationToken = cancellationToken;
111+
}
112+
113+
public WhenEachResult<T> Current => channelEnumerator.Current;
114+
115+
public UniTask<bool> MoveNextAsync()
116+
{
117+
cancellationToken.ThrowIfCancellationRequested();
118+
119+
if (state == WhenEachState.NotRunning)
120+
{
121+
state = WhenEachState.Running;
122+
channel = Channel.CreateSingleConsumerUnbounded<WhenEachResult<T>>();
123+
channelEnumerator = channel.Reader.ReadAllAsync().GetAsyncEnumerator(cancellationToken);
124+
125+
if (source is UniTask<T>[] array)
126+
{
127+
ConsumeAll(this, array, array.Length);
128+
}
129+
else
130+
{
131+
using (var rentArray = ArrayPoolUtil.Materialize(source))
132+
{
133+
ConsumeAll(this, rentArray.Array, rentArray.Length);
134+
}
135+
}
136+
}
137+
138+
return channelEnumerator.MoveNextAsync();
139+
}
140+
141+
static void ConsumeAll(Enumerator self, UniTask<T>[] array, int length)
142+
{
143+
for (int i = 0; i < length; i++)
144+
{
145+
RunWhenEachTask(self, array[i], length).Forget();
146+
}
147+
148+
static async UniTaskVoid RunWhenEachTask(Enumerator self, UniTask<T> task, int length)
149+
{
150+
try
151+
{
152+
var result = await task;
153+
self.channel.Writer.TryWrite(new WhenEachResult<T>(result));
154+
}
155+
catch (Exception ex)
156+
{
157+
self.channel.Writer.TryWrite(new WhenEachResult<T>(ex));
158+
}
159+
160+
if (Interlocked.Increment(ref self.completeCount) == length)
161+
{
162+
self.state = WhenEachState.Completed;
163+
self.channel.Writer.TryComplete();
164+
}
165+
}
166+
}
167+
168+
public async UniTask DisposeAsync()
169+
{
170+
if (channelEnumerator != null)
171+
{
172+
await channelEnumerator.DisposeAsync();
173+
}
174+
175+
if (state != WhenEachState.Completed)
176+
{
177+
state = WhenEachState.Completed;
178+
channel.Writer.TryComplete(new OperationCanceledException());
179+
}
180+
}
181+
}
182+
}
183+
}

src/UniTask/Assets/Plugins/UniTask/Runtime/UniTask.WhenEach.cs.meta

Lines changed: 2 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

src/UniTask/Assets/Scenes/SandboxMain.cs

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,28 @@ public void Dispose()
119119
connection.Dispose();
120120
}
121121
}
122+
public class WhenEachTest
123+
{
124+
public async UniTask Each()
125+
{
126+
var a = Delay(1, 3000);
127+
var b = Delay(2, 1000);
128+
var c = Delay(3, 2000);
129+
130+
var l = new List<int>();
131+
await foreach (var item in UniTask.WhenEach(a, b, c))
132+
{
133+
Debug.Log(item.Result);
134+
}
135+
}
122136

137+
async UniTask<int> Delay(int id, int sleep)
138+
{
139+
await UniTask.Delay(sleep);
140+
return id;
141+
}
142+
143+
}
123144

124145
public class SandboxMain : MonoBehaviour
125146
{
@@ -147,6 +168,18 @@ async UniTask<int> FooAsync()
147168

148169
Debug.Log("Again");
149170

171+
172+
// var foo = InstantiateAsync<SandboxMain>(this).ToUniTask();
173+
174+
175+
176+
177+
178+
// var tako = await foo;
179+
180+
181+
182+
150183
return 10;
151184
}
152185

@@ -557,6 +590,7 @@ private static async UniTask TestAsync(CancellationToken ct)
557590

558591
async UniTaskVoid Start()
559592
{
593+
await new WhenEachTest().Each();
560594

561595

562596
// UniTask.Delay(TimeSpan.FromSeconds(1)).TimeoutWithoutException

0 commit comments

Comments
 (0)