Skip to content

Commit 0ad8374

Browse files
authored
Avoid race that can cause Kestrel's RequestAborted to not fire (#62385)
1 parent 52453ff commit 0ad8374

File tree

5 files changed

+169
-15
lines changed

5 files changed

+169
-15
lines changed

src/Servers/Kestrel/Core/src/Internal/Http/HttpProtocol.cs

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -407,21 +407,20 @@ public void Reset()
407407

408408
_manuallySetRequestAbortToken = null;
409409

410-
// Lock to prevent CancelRequestAbortedToken from attempting to cancel a disposed CTS.
411-
CancellationTokenSource? localAbortCts = null;
412-
413410
lock (_abortLock)
414411
{
415412
_preventRequestAbortedCancellation = false;
416-
if (_abortedCts?.TryReset() == false)
413+
414+
// If the connection has already been aborted, allow that to be observed during the next request.
415+
if (!_connectionAborted && _abortedCts is not null)
417416
{
418-
localAbortCts = _abortedCts;
419-
_abortedCts = null;
417+
// _connectionAborted is terminal and only set inside the _abortLock, so if it isn't set here,
418+
// _abortedCts has not been canceled yet.
419+
var resetSuccess = _abortedCts.TryReset();
420+
Debug.Assert(resetSuccess);
420421
}
421422
}
422423

423-
localAbortCts?.Dispose();
424-
425424
Output?.Reset();
426425

427426
_requestHeadersParsed = 0;
@@ -760,7 +759,7 @@ private async Task ProcessRequests<TContext>(IHttpApplication<TContext> applicat
760759
}
761760
else if (!HasResponseStarted)
762761
{
763-
// If the request was aborted and no response was sent, we use status code 499 for logging
762+
// If the request was aborted and no response was sent, we use status code 499 for logging
764763
StatusCode = StatusCodes.Status499ClientClosedRequest;
765764
}
766765
}

src/Servers/Kestrel/Core/src/Internal/Http2/Http2Stream.cs

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -129,7 +129,6 @@ public bool ReceivedEmptyRequestBody
129129
protected override void OnReset()
130130
{
131131
_keepAlive = true;
132-
_connectionAborted = false;
133132
_userTrailers = null;
134133

135134
// Reset Http2 Features

src/Servers/Kestrel/Core/src/Internal/Http3/Http3Connection.cs

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -602,15 +602,16 @@ private async Task CreateHttp3Stream<TContext>(ConnectionContext streamContext,
602602

603603
// Check whether there is an existing HTTP/3 stream on the transport stream.
604604
// A stream will only be cached if the transport stream itself is reused.
605-
if (!persistentStateFeature.State.TryGetValue(StreamPersistentStateKey, out var s))
605+
if (!persistentStateFeature.State.TryGetValue(StreamPersistentStateKey, out var s) ||
606+
s is not Http3Stream<TContext> { CanReuse: true } reusableStream)
606607
{
607608
stream = new Http3Stream<TContext>(application, CreateHttpStreamContext(streamContext));
608-
persistentStateFeature.State.Add(StreamPersistentStateKey, stream);
609+
persistentStateFeature.State[StreamPersistentStateKey] = stream;
609610
}
610611
else
611612
{
612-
stream = (Http3Stream<TContext>)s!;
613-
stream.InitializeWithExistingContext(streamContext.Transport);
613+
stream = reusableStream;
614+
reusableStream.InitializeWithExistingContext(streamContext.Transport);
614615
}
615616

616617
_streamLifetimeHandler.OnStreamCreated(stream);

src/Servers/Kestrel/Core/src/Internal/Http3/Http3Stream.cs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,8 @@ internal abstract partial class Http3Stream : HttpProtocol, IHttp3Stream, IHttpS
6666
private bool IsAbortedRead => (_completionState & StreamCompletionFlags.AbortedRead) == StreamCompletionFlags.AbortedRead;
6767
public bool IsCompleted => (_completionState & StreamCompletionFlags.Completed) == StreamCompletionFlags.Completed;
6868

69+
public bool CanReuse => !_connectionAborted && HasResponseCompleted;
70+
6971
public bool ReceivedEmptyRequestBody
7072
{
7173
get
@@ -957,7 +959,6 @@ private Task ProcessDataFrameAsync(in ReadOnlySequence<byte> payload)
957959
protected override void OnReset()
958960
{
959961
_keepAlive = true;
960-
_connectionAborted = false;
961962
_userTrailers = null;
962963
_isWebTransportSessionAccepted = false;
963964
_isMethodConnect = false;

src/Servers/Kestrel/test/Interop.FunctionalTests/Http3/Http3RequestTests.cs

Lines changed: 154 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -879,6 +879,160 @@ public async Task GET_MultipleRequestsInSequence_ReusedState()
879879
}
880880
}
881881

882+
[ConditionalFact]
883+
[MsQuicSupported]
884+
public async Task GET_RequestAbortedByClient_StateNotReused()
885+
{
886+
// Arrange
887+
object persistedState = null;
888+
var requestCount = 0;
889+
var abortedTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously);
890+
var requestStartedTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously);
891+
892+
var builder = CreateHostBuilder(async context =>
893+
{
894+
requestCount++;
895+
var persistentStateCollection = context.Features.Get<IPersistentStateFeature>().State;
896+
if (persistentStateCollection.TryGetValue("Counter", out var value))
897+
{
898+
persistedState = value;
899+
}
900+
persistentStateCollection["Counter"] = requestCount;
901+
902+
if (requestCount == 1)
903+
{
904+
// For the first request, wait for RequestAborted to fire before returning
905+
context.RequestAborted.Register(() =>
906+
{
907+
Logger.LogInformation("Server received cancellation");
908+
abortedTcs.SetResult();
909+
});
910+
911+
// Signal that the request has started and is ready to be cancelled
912+
requestStartedTcs.SetResult();
913+
914+
// Wait for the request to be aborted
915+
await abortedTcs.Task;
916+
}
917+
});
918+
919+
using (var host = builder.Build())
920+
using (var client = HttpHelpers.CreateClient())
921+
{
922+
await host.StartAsync();
923+
924+
// Act - Send first request and cancel it
925+
var cts1 = new CancellationTokenSource();
926+
var request1 = new HttpRequestMessage(HttpMethod.Get, $"https://127.0.0.1:{host.GetPort()}/");
927+
request1.Version = HttpVersion.Version30;
928+
request1.VersionPolicy = HttpVersionPolicy.RequestVersionExact;
929+
930+
var responseTask1 = client.SendAsync(request1, cts1.Token);
931+
932+
// Wait for the server to start processing the request
933+
await requestStartedTcs.Task.DefaultTimeout();
934+
935+
// Cancel the first request
936+
cts1.Cancel();
937+
await Assert.ThrowsAnyAsync<OperationCanceledException>(() => responseTask1).DefaultTimeout();
938+
939+
// Wait for the server to process the abort
940+
await abortedTcs.Task.DefaultTimeout();
941+
942+
// Store the state from the first (aborted) request
943+
var firstRequestState = persistedState;
944+
945+
// Delay to ensure the stream has enough time to return to pool
946+
await Task.Delay(100);
947+
948+
// Send second request (should not reuse state from aborted request)
949+
var request2 = new HttpRequestMessage(HttpMethod.Get, $"https://127.0.0.1:{host.GetPort()}/");
950+
request2.Version = HttpVersion.Version30;
951+
request2.VersionPolicy = HttpVersionPolicy.RequestVersionExact;
952+
953+
var response2 = await client.SendAsync(request2, CancellationToken.None);
954+
response2.EnsureSuccessStatusCode();
955+
var secondRequestState = persistedState;
956+
957+
// Assert
958+
// First request has no persisted state (it was aborted)
959+
Assert.Null(firstRequestState);
960+
961+
// Second request should also have no persisted state since the first request was aborted
962+
// and state should not be reused from aborted requests
963+
Assert.Null(secondRequestState);
964+
965+
await host.StopAsync();
966+
}
967+
}
968+
969+
[ConditionalFact]
970+
[MsQuicSupported]
971+
public async Task GET_RequestAbortedByServer_StateNotReused()
972+
{
973+
// Arrange
974+
object persistedState = null;
975+
var requestCount = 0;
976+
977+
var builder = CreateHostBuilder(context =>
978+
{
979+
requestCount++;
980+
var persistentStateCollection = context.Features.Get<IPersistentStateFeature>().State;
981+
if (persistentStateCollection.TryGetValue("Counter", out var value))
982+
{
983+
persistedState = value;
984+
}
985+
persistentStateCollection["Counter"] = requestCount;
986+
987+
if (requestCount == 1)
988+
{
989+
context.Abort();
990+
}
991+
992+
return Task.CompletedTask;
993+
});
994+
995+
using (var host = builder.Build())
996+
using (var client = HttpHelpers.CreateClient())
997+
{
998+
await host.StartAsync();
999+
1000+
var request1 = new HttpRequestMessage(HttpMethod.Get, $"https://127.0.0.1:{host.GetPort()}/");
1001+
request1.Version = HttpVersion.Version30;
1002+
request1.VersionPolicy = HttpVersionPolicy.RequestVersionExact;
1003+
1004+
var responseTask1 = client.SendAsync(request1, CancellationToken.None);
1005+
var ex = await Assert.ThrowsAnyAsync<HttpRequestException>(() => responseTask1).DefaultTimeout();
1006+
var innerEx = Assert.IsType<HttpProtocolException>(ex.InnerException);
1007+
Assert.Equal(Http3ErrorCode.InternalError, (Http3ErrorCode)innerEx.ErrorCode);
1008+
1009+
// Store the state from the first (aborted) request
1010+
var firstRequestState = persistedState;
1011+
1012+
// Delay to ensure the stream has enough time to return to pool
1013+
await Task.Delay(100);
1014+
1015+
// Send second request (should not reuse state from aborted request)
1016+
var request2 = new HttpRequestMessage(HttpMethod.Get, $"https://127.0.0.1:{host.GetPort()}/");
1017+
request2.Version = HttpVersion.Version30;
1018+
request2.VersionPolicy = HttpVersionPolicy.RequestVersionExact;
1019+
1020+
var response2 = await client.SendAsync(request2, CancellationToken.None);
1021+
response2.EnsureSuccessStatusCode();
1022+
var secondRequestState = persistedState;
1023+
1024+
// Assert
1025+
// First request has no persisted state (it was aborted)
1026+
Assert.Null(firstRequestState);
1027+
1028+
// Second request should also have no persisted state since the first request was aborted
1029+
// and state should not be reused from aborted requests
1030+
Assert.Null(secondRequestState);
1031+
1032+
await host.StopAsync();
1033+
}
1034+
}
1035+
8821036
[ConditionalFact]
8831037
[MsQuicSupported]
8841038
public async Task GET_MultipleRequests_RequestVersionOrHigher_UpgradeToHttp3()

0 commit comments

Comments
 (0)