Skip to content

Commit 1a8f5a8

Browse files
authored
More request validation out of the AsyncAcceptContext (#28452)
* More request validation out of the AsyncAcceptContext - Moved request validation logic to the accept loop. This simplifies the logic inside of the accept context. - This also moves feature initializaton to after the threadpool dispatch. - This does make make the async state machine yield when requests exceeed the header limit or if auth fails. * PR feedback
1 parent 1173fbd commit 1a8f5a8

File tree

6 files changed

+34
-45
lines changed

6 files changed

+34
-45
lines changed

src/Servers/HttpSys/src/AsyncAcceptContext.cs

Lines changed: 7 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
using System.Threading;
66
using System.Threading.Tasks;
77
using System.Threading.Tasks.Sources;
8-
using Microsoft.AspNetCore.Http;
98
using Microsoft.AspNetCore.HttpSys.Internal;
109

1110
namespace Microsoft.AspNetCore.Server.HttpSys
@@ -14,6 +13,8 @@ internal unsafe class AsyncAcceptContext : IValueTaskSource<RequestContext>, IDi
1413
{
1514
private static readonly IOCompletionCallback IOCallback = IOWaitCallback;
1615
private readonly PreAllocatedOverlapped _preallocatedOverlapped;
16+
private readonly IRequestContextFactory _requestContextFactory;
17+
1718
private NativeOverlapped* _overlapped;
1819

1920
// mutable struct; do not make this readonly
@@ -24,7 +25,6 @@ internal unsafe class AsyncAcceptContext : IValueTaskSource<RequestContext>, IDi
2425
};
2526

2627
private RequestContext _requestContext;
27-
private readonly IRequestContextFactory _requestContextFactory;
2828

2929
internal AsyncAcceptContext(HttpSysListener server, IRequestContextFactory requestContextFactory)
3030
{
@@ -55,7 +55,6 @@ internal ValueTask<RequestContext> AcceptAsync()
5555

5656
private static void IOCompleted(AsyncAcceptContext asyncContext, uint errorCode, uint numBytes)
5757
{
58-
var complete = false;
5958
// This is important to stash a ref to as it's a mutable struct
6059
ref var mrvts = ref asyncContext._mrvts;
6160
var requestContext = asyncContext._requestContext;
@@ -73,48 +72,18 @@ private static void IOCompleted(AsyncAcceptContext asyncContext, uint errorCode,
7372
HttpSysListener server = asyncContext.Server;
7473
if (errorCode == UnsafeNclNativeMethods.ErrorCodes.ERROR_SUCCESS)
7574
{
76-
// at this point we have received an unmanaged HTTP_REQUEST and memoryBlob
77-
// points to it we need to hook up our authentication handling code here.
78-
try
79-
{
80-
if (server.ValidateRequest(requestContext) && server.ValidateAuth(requestContext))
81-
{
82-
// It's important that we clear the request context before we set the result
83-
// we want to reuse the acceptContext object for future accepts.
84-
asyncContext._requestContext = null;
75+
// It's important that we clear the request context before we set the result
76+
// we want to reuse the acceptContext object for future accepts.
77+
asyncContext._requestContext = null;
8578

86-
// Initialize features here once we're successfully validated the request
87-
// TODO: In the future defer this work to the thread pool so we can get off the IO thread
88-
// as quickly as possible
89-
requestContext.InitializeFeatures();
90-
91-
mrvts.SetResult(requestContext);
92-
93-
complete = true;
94-
}
95-
}
96-
catch (Exception ex)
97-
{
98-
server.SendError(requestId, StatusCodes.Status400BadRequest);
99-
mrvts.SetException(ex);
100-
}
101-
finally
102-
{
103-
if (!complete)
104-
{
105-
asyncContext.AllocateNativeRequest(size: requestContext.Size);
106-
}
107-
}
79+
mrvts.SetResult(requestContext);
10880
}
10981
else
11082
{
11183
// (uint)backingBuffer.Length - AlignmentPadding
11284
asyncContext.AllocateNativeRequest(numBytes, requestId);
113-
}
11485

115-
// We need to issue a new request, either because auth failed, or because our buffer was too small the first time.
116-
if (!complete)
117-
{
86+
// We need to issue a new request, either because auth failed, or because our buffer was too small the first time.
11887
uint statusCode = asyncContext.QueueBeginGetContext();
11988

12089
if (statusCode != UnsafeNclNativeMethods.ErrorCodes.ERROR_SUCCESS &&

src/Servers/HttpSys/src/HttpSysListener.cs

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -293,17 +293,14 @@ internal unsafe bool ValidateRequest(NativeRequestContext requestMemory)
293293
SendError(requestMemory.RequestId, StatusCodes.Status400BadRequest, authChallenges: null);
294294
return false;
295295
}
296-
return true;
297-
}
298296

299-
internal unsafe bool ValidateAuth(NativeRequestContext requestMemory)
300-
{
301297
if (!Options.Authentication.AllowAnonymous && !requestMemory.CheckAuthenticated())
302298
{
303299
SendError(requestMemory.RequestId, StatusCodes.Status401Unauthorized,
304300
AuthenticationManager.GenerateChallenges(Options.Authentication.Schemes));
305301
return false;
306302
}
303+
307304
return true;
308305
}
309306

src/Servers/HttpSys/src/MessagePump.cs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -189,6 +189,12 @@ private async Task ProcessRequestsWorker()
189189
try
190190
{
191191
requestContext = await Listener.AcceptAsync(acceptContext);
192+
193+
if (!Listener.ValidateRequest(requestContext))
194+
{
195+
// If either of these is false then a response has already been sent to the client, so we can accept the next request
196+
continue;
197+
}
192198
}
193199
catch (Exception exception)
194200
{

src/Servers/HttpSys/src/RequestProcessing/RequestContext.FeatureCollection.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ private enum Fields
9393
TraceIdentifier = 0x200,
9494
}
9595

96-
public void InitializeFeatures()
96+
protected internal void InitializeFeatures()
9797
{
9898
_initialized = true;
9999

src/Servers/HttpSys/src/RequestProcessing/RequestContextOfT.cs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,8 @@ protected override async Task ExecuteAsync()
2525

2626
try
2727
{
28+
InitializeFeatures();
29+
2830
if (messagePump.Stopping)
2931
{
3032
SetFatalResponse(503);

src/Servers/HttpSys/test/FunctionalTests/Listener/Utilities.cs

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -114,8 +114,23 @@ internal static HttpSysListener CreateServerOnExistingQueue(AuthenticationScheme
114114
internal static async Task<RequestContext> AcceptAsync(this HttpSysListener server, TimeSpan timeout)
115115
{
116116
var factory = new TestRequestContextFactory(server);
117-
var acceptContext = new AsyncAcceptContext(server, factory);
118-
var acceptTask = server.AcceptAsync(acceptContext).AsTask();
117+
using var acceptContext = new AsyncAcceptContext(server, factory);
118+
119+
async Task<RequestContext> AcceptAsync()
120+
{
121+
while (true)
122+
{
123+
var requestContext = await server.AcceptAsync(acceptContext);
124+
125+
if (server.ValidateRequest(requestContext))
126+
{
127+
requestContext.InitializeFeatures();
128+
return requestContext;
129+
}
130+
}
131+
}
132+
133+
var acceptTask = AcceptAsync();
119134
var completedTask = await Task.WhenAny(acceptTask, Task.Delay(timeout));
120135

121136
if (completedTask == acceptTask)

0 commit comments

Comments
 (0)