Skip to content

Commit 72d7899

Browse files
author
Meyn
committed
Fix reusing TaskCompletionSource
1 parent 62873c4 commit 72d7899

File tree

2 files changed

+39
-47
lines changed

2 files changed

+39
-47
lines changed

Requests/Request.Static.cs

Lines changed: 25 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
21
using System.Runtime.CompilerServices;
32

43
namespace Requests
@@ -14,7 +13,7 @@ public static class Request
1413
/// <summary>
1514
/// Gets the current request context.
1615
/// </summary>
17-
internal static IRequest? Current => _current.Value;
16+
public static IRequest? Current => _current.Value;
1817

1918
/// <summary>
2019
/// Sets the current request context.
@@ -35,16 +34,26 @@ public static class Request
3534
/// <summary>
3635
/// Provides an awaitable type that enables yielding in request contexts.
3736
/// </summary>
38-
public readonly struct YieldAwaitable : ICriticalNotifyCompletion
37+
public readonly struct YieldAwaitable
3938
{
4039
private readonly IRequest? _context;
4140

4241
internal YieldAwaitable(IRequest? context) => _context = context;
4342

4443
/// <summary>
45-
/// Gets an awaiter for this awaitable (returns itself).
44+
/// Gets an awaiter for this awaitable.
4645
/// </summary>
47-
public YieldAwaitable GetAwaiter() => this;
46+
public YieldAwaiter GetAwaiter() => new(_context);
47+
}
48+
49+
/// <summary>
50+
/// Provides an awaiter that handles yielding in request contexts.
51+
/// </summary>
52+
public readonly struct YieldAwaiter : ICriticalNotifyCompletion
53+
{
54+
private readonly IRequest? _context;
55+
56+
internal YieldAwaiter(IRequest? context) => _context = context;
4857

4958
/// <summary>
5059
/// Returns true if there's no context (completes synchronously as no-op).
@@ -64,7 +73,16 @@ public void GetResult() { }
6473
public void OnCompleted(Action continuation)
6574
{
6675
ArgumentNullException.ThrowIfNull(continuation);
67-
_ = ScheduleContinuationAsync(_context!, continuation);
76+
77+
ExecutionContext? ctx = ExecutionContext.Capture();
78+
if (ctx == null)
79+
{
80+
UnsafeOnCompleted(continuation);
81+
}
82+
else
83+
{
84+
UnsafeOnCompleted(() => ExecutionContext.Run(ctx, static s => ((Action)s!)(), continuation));
85+
}
6886
}
6987

7088
/// <summary>
@@ -74,24 +92,7 @@ public void OnCompleted(Action continuation)
7492
public void UnsafeOnCompleted(Action continuation)
7593
{
7694
ArgumentNullException.ThrowIfNull(continuation);
77-
_ = ScheduleContinuationAsync(_context!, continuation);
78-
}
79-
80-
/// <summary>
81-
/// Schedules the continuation after yielding through the request context.
82-
/// </summary>
83-
private static async Task ScheduleContinuationAsync(IRequest context, Action continuation)
84-
{
85-
try
86-
{
87-
await context.YieldAsync().ConfigureAwait(false);
88-
continuation();
89-
}
90-
catch
91-
{
92-
// Still invoke continuation to prevent deadlock
93-
continuation();
94-
}
95+
_context!.YieldAsync().ConfigureAwait(false).GetAwaiter().UnsafeOnCompleted(continuation);
9596
}
9697
}
9798
}

Requests/Request.cs

Lines changed: 14 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@ public abstract partial class Request<TOptions, TCompleted, TFailed> : IRequest,
2020
private CancellationTokenRegistration _ctr;
2121

2222
private TaskCompletionSource<bool>? _pauseTcs;
23-
private volatile bool _hasPausedExecution;
2423

2524
private readonly RequestStateMachine _stateMachine;
2625

@@ -291,27 +290,21 @@ async Task IRequest.StartRequestAsync()
291290
return;
292291

293292
// Check if we're resuming a paused execution
294-
if (_hasPausedExecution)
293+
TaskCompletionSource<bool>? tcs = Interlocked.Exchange(ref _pauseTcs, null);
294+
if (tcs != null)
295295
{
296-
// Get and clear the pause TCS atomically
297-
TaskCompletionSource<bool>? tcs = Interlocked.Exchange(ref _pauseTcs, null);
298-
if (tcs != null)
299-
{
300-
_hasPausedExecution = false;
301-
302-
if (!_stateMachine.TryTransition(RequestState.Running))
303-
return;
296+
if (!_stateMachine.TryTransition(RequestState.Running))
297+
return;
304298

305-
_runningSourceVersion++;
306-
_runningSource.Reset();
299+
_runningSourceVersion++;
300+
_runningSource.Reset();
307301

308-
// Resume the paused execution
309-
tcs.TrySetResult(true);
302+
// Resume the paused execution
303+
tcs.TrySetResult(true);
310304

311-
// Wait for the resumed execution to stop running
312-
await new ValueTask(this, _runningSourceVersion).ConfigureAwait(false);
313-
return;
314-
}
305+
// Wait for the resumed execution to stop running
306+
await new ValueTask(this, _runningSourceVersion).ConfigureAwait(false);
307+
return;
315308
}
316309

317310
if (!_stateMachine.TryTransition(RequestState.Running))
@@ -492,10 +485,8 @@ private async ValueTask YieldAsyncSlow()
492485

493486
if (State == RequestState.Paused)
494487
{
495-
// Lazy-create pause TCS only when actually pausing
496-
TaskCompletionSource<bool> tcs = new(TaskCreationOptions.RunContinuationsAsynchronously);
497-
_pauseTcs = tcs;
498-
_hasPausedExecution = true;
488+
TaskCompletionSource<bool> tcs = new TaskCompletionSource<bool>(TaskCreationOptions.RunContinuationsAsynchronously);
489+
tcs = Interlocked.CompareExchange(ref _pauseTcs, tcs, null) ?? tcs;
499490

500491
// Wait for resume
501492
await tcs.Task.ConfigureAwait(false);
@@ -573,7 +564,7 @@ public virtual void Dispose()
573564
_requestCts?.Dispose();
574565
_ctr.Dispose();
575566
_completionSource.TrySetCanceled();
576-
_pauseTcs?.TrySetCanceled();
567+
Interlocked.Exchange(ref _pauseTcs, null)?.TrySetCanceled();
577568

578569
GC.SuppressFinalize(this);
579570
}

0 commit comments

Comments
 (0)