Skip to content

Commit e07fc54

Browse files
author
Meyn
committed
Implement IAsyncEnumerable in RequestHandler
1 parent f103b0b commit e07fc54

File tree

1 file changed

+171
-32
lines changed

1 file changed

+171
-32
lines changed

Requests/RequestHandler.cs

Lines changed: 171 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ namespace Requests
99
/// The <see cref="RequestHandler"/> class is responsible for executing instances of the <see cref="IRequest"/> interface.
1010
/// Optimized for high-performance scenarios with minimal allocations and thread-safe state management.
1111
/// </summary>
12-
public partial class RequestHandler : IRequestContainer<IRequest>
12+
public class RequestHandler : IRequestContainer<IRequest>, IAsyncEnumerable<IRequest>
1313
{
1414
private readonly IPriorityChannel<IRequest> _requestsChannel;
1515
private readonly RequestContainerStateMachine _stateMachine;
@@ -20,6 +20,7 @@ public partial class RequestHandler : IRequestContainer<IRequest>
2020
private CancellationTokenSource _cts = new();
2121
private readonly PauseTokenSource _pts = new();
2222
private Task? _task;
23+
private Exception? _unhandledException;
2324

2425
// Cached delegate to avoid allocations
2526
private static readonly SendOrPostCallback s_stateChangedCallback = static state =>
@@ -31,13 +32,18 @@ public partial class RequestHandler : IRequestContainer<IRequest>
3132
/// <summary>
3233
/// Represents the current state of this <see cref="RequestHandler"/>.
3334
/// </summary>
34-
public RequestState State { get => _stateMachine.Current; private set => _stateMachine.TryTransition(value); }
35+
public RequestState State => _stateMachine.Current;
3536

3637
/// <summary>
3738
/// Event triggered when the <see cref="State"/> of this object changes.
3839
/// </summary>
3940
public event EventHandler<RequestState>? StateChanged;
4041

42+
/// <summary>
43+
/// Event triggered when an unhandled exception occurs in the handler.
44+
/// </summary>
45+
public event EventHandler<Exception>? UnhandledException;
46+
4147
/// <summary>
4248
/// The priority of this request handler.
4349
/// </summary>
@@ -50,9 +56,9 @@ public partial class RequestHandler : IRequestContainer<IRequest>
5056

5157
/// <summary>
5258
/// Gets the aggregate exception associated with the <see cref="RequestHandler"/> instance.
53-
/// Currently, this property always returns <c>null</c>, indicating that no exceptions are associated with the handler.
59+
/// Returns the last unhandled exception if any occurred.
5460
/// </summary>
55-
public AggregateException? Exception => null;
61+
public AggregateException? Exception => _unhandledException != null ? new AggregateException(_unhandledException) : null;
5662

5763
/// <summary>
5864
/// Property that sets the degree of parallel execution of instances of the <see cref="IRequest"/> interface.
@@ -107,13 +113,55 @@ public int MaxParallelism
107113
public int Count => _requestsChannel.Count;
108114

109115
/// <summary>
110-
/// Represents a task that completes when all the requests currently present in the handler have finished processing.
111-
/// This task does not account for any requests that may be added to the handler after its creation.
116+
/// Asynchronously enumerates all currently pending requests in the handler.
117+
/// </summary>
118+
/// <param name="cancellationToken">Cancellation token to stop enumeration.</param>
119+
/// <returns>An async enumerator of pending requests.</returns>
120+
/// <remarks>
121+
/// <strong>Warning:</strong> This operation may block the handler for a period of time.
122+
/// </remarks>
123+
public async IAsyncEnumerator<IRequest> GetAsyncEnumerator(CancellationToken cancellationToken = default)
124+
{
125+
// Take a snapshot to avoid blocking
126+
PriorityItem<IRequest>[] snapshot;
127+
try
128+
{
129+
snapshot = _requestsChannel.ToArray();
130+
}
131+
catch
132+
{
133+
yield break;
134+
}
135+
136+
foreach (PriorityItem<IRequest> item in snapshot)
137+
{
138+
cancellationToken.ThrowIfCancellationRequested();
139+
yield return item.Item;
140+
await Task.Yield();
141+
}
142+
}
143+
144+
/// <summary>
145+
/// Waits for all currently pending requests to complete.
146+
/// Equivalent to awaiting all requests from 'await foreach (var request in handler)'.
112147
/// </summary>
148+
/// <param name="cancellationToken">Cancellation token.</param>
149+
/// <returns>Task that completes when all current requests are done.</returns>
113150
/// <remarks>
114151
/// <strong>Warning:</strong> This operation may block the handler for a period of time.
115152
/// </remarks>
116-
public Task CurrentTask => Task.WhenAll(_requestsChannel.ToArray().Select(requestPair => requestPair.Item.Task));
153+
public async Task WaitForCurrentRequestsAsync(CancellationToken cancellationToken = default)
154+
{
155+
List<Task> tasks = [];
156+
157+
await foreach (IRequest request in this.WithCancellation(cancellationToken).ConfigureAwait(false))
158+
{
159+
tasks.Add(request.Task);
160+
}
161+
162+
if (tasks.Count > 0)
163+
await Task.WhenAll(tasks).ConfigureAwait(false);
164+
}
117165

118166
/// <summary>
119167
/// Specifies a request that should be executed immediately after this request completes, bypassing the queue.
@@ -148,19 +196,64 @@ private void OnStateChanged(RequestState oldState, RequestState newState)
148196
=> DefaultSynchronizationContext.Post(s_stateChangedCallback, (this, newState));
149197

150198
/// <summary>
151-
/// Method to add a single instance of the <see cref="IRequest"/> interface to the handler.
199+
/// Attempts to transition to a new state.
200+
/// Throws if transition is invalid.
201+
/// </summary>
202+
private void SetState(RequestState newState)
203+
{
204+
if (!_stateMachine.TryTransition(newState))
205+
{
206+
throw new InvalidOperationException(
207+
$"Invalid state transition from {State} to {newState}");
208+
}
209+
}
210+
211+
/// <summary>
212+
/// Handles unhandled exceptions from the handler's execution.
213+
/// </summary>
214+
private void OnUnhandledExceptionOccurred(Exception ex)
215+
{
216+
_unhandledException = ex;
217+
DefaultSynchronizationContext.Post(static state =>
218+
{
219+
(RequestHandler handler, Exception exception) = ((RequestHandler, Exception))state!;
220+
handler.UnhandledException?.Invoke(handler, exception);
221+
}, (this, ex));
222+
}
223+
224+
/// <summary>
225+
/// Synchronously adds a request to the handler.
226+
/// Throws if the channel is closed or the request is null.
152227
/// </summary>
153228
/// <param name="request">The instance of the <see cref="IRequest"/> interface that should be added.</param>
154-
[MethodImpl(MethodImplOptions.AggressiveInlining)]
229+
/// <exception cref="ArgumentNullException">Thrown if request is null.</exception>
230+
/// <exception cref="InvalidOperationException">Thrown if the channel is closed.</exception>
155231
public void Add(IRequest request)
156-
=> _ = _requestsChannel.Writer.WriteAsync(new(request.Priority, request)).AsTask();
232+
{
233+
ArgumentNullException.ThrowIfNull(request);
234+
235+
if (!_requestsChannel.Writer.TryWrite(new(request.Priority, request)))
236+
throw new InvalidOperationException("Failed to add request, channel may be closed or full");
237+
}
238+
239+
/// <summary>
240+
/// Asynchronously adds a request to the handler.
241+
/// </summary>
242+
/// <param name="request">The request to add.</param>
243+
/// <param name="cancellationToken">Cancellation token.</param>
244+
public async ValueTask AddAsync(IRequest request, CancellationToken cancellationToken = default)
245+
{
246+
ArgumentNullException.ThrowIfNull(request);
247+
await _requestsChannel.Writer.WriteAsync(new(request.Priority, request), cancellationToken).ConfigureAwait(false);
248+
}
157249

158250
/// <summary>
159251
/// Method to add multiple instances of the <see cref="IRequest"/> interface to the handler.
160252
/// </summary>
161253
/// <param name="requests">The instances of the <see cref="IRequest"/> interface that should be added.</param>
162254
public void AddRange(params IRequest[] requests)
163255
{
256+
ArgumentNullException.ThrowIfNull(requests);
164257
foreach (IRequest request in requests)
165258
Add(request);
166259
}
@@ -190,13 +283,15 @@ public void RunRequests(params IRequest[] requests)
190283
/// </summary>
191284
public void Start()
192285
{
193-
if (!_requestsChannel.Options.EasyEndToken.IsPaused)
286+
if (!_pts.IsPaused)
287+
return;
288+
289+
if (!_stateMachine.TryTransition(RequestState.Idle))
194290
return;
195291

196-
State = RequestState.Idle;
197292
_pts.Resume();
198293

199-
if (Count > 0)
294+
if (_requestsChannel.Reader.Count > 0)
200295
RunRequests();
201296
}
202297

@@ -206,7 +301,7 @@ public void Start()
206301
public void Pause()
207302
{
208303
_pts.Pause();
209-
State = RequestState.Paused;
304+
_stateMachine.TryTransition(RequestState.Paused);
210305
}
211306

212307
/// <summary>
@@ -223,6 +318,7 @@ public void CreateCTS()
223318
_cts = new CancellationTokenSource();
224319
_requestsChannel.Options.CancellationToken = CancellationToken;
225320

321+
_stateMachine.TryTransition(RequestState.Idle);
226322
if (Count > 0)
227323
RunRequests();
228324
}
@@ -243,7 +339,7 @@ public static void CreateMainCTS()
243339
public void Cancel()
244340
{
245341
_cts.Cancel();
246-
State = RequestState.Cancelled;
342+
_stateMachine.TryTransition(RequestState.Cancelled);
247343
}
248344

249345
/// <summary>
@@ -282,7 +378,17 @@ public void RunRequests()
282378
if (State != RequestState.Idle)
283379
return;
284380

285-
_ = Task.Run(async () => await ((IRequest)this).StartRequestAsync().ConfigureAwait(false));
381+
_task = Task.Run(async () =>
382+
{
383+
try
384+
{
385+
await ((IRequest)this).StartRequestAsync().ConfigureAwait(false);
386+
}
387+
catch (Exception ex)
388+
{
389+
OnUnhandledExceptionOccurred(ex);
390+
}
391+
});
286392
}
287393

288394
/// <summary>
@@ -294,8 +400,7 @@ async Task IRequest.StartRequestAsync()
294400
if (State != RequestState.Idle || CancellationToken.IsCancellationRequested || _pts.IsPaused)
295401
return;
296402

297-
_task = RunChannelAsync();
298-
await Task.ConfigureAwait(false);
403+
await RunChannelAsync().ConfigureAwait(false);
299404
}
300405

301406
/// <summary>
@@ -304,12 +409,19 @@ async Task IRequest.StartRequestAsync()
304409
/// <returns>async Task to await</returns>
305410
private async Task RunChannelAsync()
306411
{
307-
State = RequestState.Running;
412+
SetState(RequestState.Running);
308413
UpdateAutoParallelism();
309414

310-
await _requestsChannel.RunParallelReader(async (pair, ct) => await HandleRequestAsync(pair).ConfigureAwait(false)).ConfigureAwait(false);
311-
312-
State = RequestState.Idle;
415+
try
416+
{
417+
await _requestsChannel.RunParallelReader(async (pair, ct) =>
418+
await HandleRequestAsync(pair).ConfigureAwait(false))
419+
.ConfigureAwait(false);
420+
}
421+
finally
422+
{
423+
_stateMachine.TryTransition(RequestState.Idle);
424+
}
313425

314426
if (_requestsChannel.Reader.Count > 0)
315427
await ((IRequest)this).StartRequestAsync().ConfigureAwait(false);
@@ -333,7 +445,7 @@ private async Task HandleRequestAsync(PriorityItem<IRequest> pair)
333445
}
334446
else if (request.State == RequestState.Idle)
335447
{
336-
await _requestsChannel.Writer.WriteAsync(pair).ConfigureAwait(false);
448+
await _requestsChannel.Writer.WriteAsync(pair, CancellationToken).ConfigureAwait(false);
337449
}
338450
}
339451

@@ -377,25 +489,47 @@ public void UpdateAutoParallelism()
377489

378490
/// <summary>
379491
/// Attempts to set all <see cref="IRequest"/> objects in the container's <see cref="State"/> to idle.
380-
/// No new requests will be started or read while processing. And the running requests will be paused.
492+
/// Pauses the handler during this operation and returns it to the previous state afterward.
381493
/// </summary>
382494
/// <returns>True if all <see cref="IRequest"/> objects are in an idle <see cref="RequestState"/>, otherwise false.</returns>
383495
public bool TrySetIdle()
384496
{
385-
Pause();
386-
PriorityItem<IRequest>[] requests = _requestsChannel.ToArray();
497+
RequestState previousState = State;
498+
499+
if (!_stateMachine.TryTransition(RequestState.Paused))
500+
return false;
387501

388-
foreach (PriorityItem<IRequest> priorityItem in requests)
389-
_ = priorityItem.Item.TrySetIdle();
502+
try
503+
{
504+
PriorityItem<IRequest>[] requests = _requestsChannel.ToArray();
505+
506+
foreach (PriorityItem<IRequest> priorityItem in requests)
507+
_ = priorityItem.Item.TrySetIdle();
508+
509+
bool allIdle = requests.All(x => x.Item.State == RequestState.Idle);
510+
511+
// Restore previous state if we paused it
512+
if (previousState != RequestState.Paused)
513+
_stateMachine.TryTransition(previousState);
390514

391-
return requests.All(x => x.Item.State == RequestState.Idle);
515+
return allIdle;
516+
}
517+
catch
518+
{
519+
// Restore state on error
520+
if (previousState != RequestState.Paused)
521+
_stateMachine.TryTransition(previousState);
522+
throw;
523+
}
392524
}
393525

394526
/// <summary>
395-
/// Checks whether the <see cref="RequestHandler"/> has reached a final state and will process <see cref="IRequest"/> objects.
527+
/// Checks whether the <see cref="RequestHandler"/> has completed all work.
396528
/// </summary>
397-
/// <returns><c>true</c> if the handler is in a final state; otherwise, <c>false</c>.</returns>
398-
public bool HasCompleted() => _requestsChannel.Reader.Completion.IsCompleted;
529+
/// <returns><c>true</c> if the handler is in a terminal state and has no pending requests; otherwise, <c>false</c>.</returns>
530+
public bool HasCompleted() =>
531+
State is RequestState.Completed or RequestState.Cancelled
532+
&& _requestsChannel.Reader.Count == 0;
399533

400534
/// <summary>
401535
/// Yield point for IRequest interface compatibility.
@@ -449,6 +583,8 @@ public override string ToString()
449583
/// Attempts to remove the specified requests from the priority channel.
450584
/// </summary>
451585
/// <param name="requests">The requests to remove.</param>
586+
/// <exception cref="ArgumentNullException">Thrown if requests is null or empty.</exception>
587+
/// <exception cref="InvalidOperationException">Thrown if a request cannot be removed.</exception>
452588
/// <remarks>
453589
/// <strong>Warning:</strong> This method can produce a significant amount of overhead, especially when dealing with a large number of requests.
454590
/// </remarks>
@@ -459,6 +595,9 @@ public void Remove(params IRequest[] requests)
459595

460596
foreach (IRequest request in requests)
461597
{
598+
if (request == null)
599+
throw new ArgumentNullException(nameof(requests), "Individual request cannot be null.");
600+
462601
if (!_requestsChannel.TryRemove(new(request.Priority, request)))
463602
throw new InvalidOperationException($"Failed to remove request: {request}");
464603
}

0 commit comments

Comments
 (0)