Skip to content

Commit 91be392

Browse files
authored
Fix HttpContext race by copying values to reader and writer (#2294)
1 parent 683fbdf commit 91be392

File tree

4 files changed

+105
-5
lines changed

4 files changed

+105
-5
lines changed

src/Grpc.AspNetCore.Server/Internal/GrpcProtocolHelpers.cs

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
#region Copyright notice and license
1+
#region Copyright notice and license
22

33
// Copyright 2019 The gRPC Authors
44
//
@@ -234,4 +234,17 @@ internal static bool ShouldSkipHeader(string name)
234234
{
235235
return name.StartsWith(':') || GrpcProtocolConstants.FilteredHeaders.Contains(name);
236236
}
237+
238+
internal static IHttpRequestLifetimeFeature GetRequestLifetimeFeature(HttpContext httpContext)
239+
{
240+
var lifetimeFeature = httpContext.Features.Get<IHttpRequestLifetimeFeature>();
241+
if (lifetimeFeature is null)
242+
{
243+
// This should only run in tests where the HttpContext is manually created.
244+
lifetimeFeature = new HttpRequestLifetimeFeature();
245+
httpContext.Features.Set(lifetimeFeature);
246+
}
247+
248+
return lifetimeFeature;
249+
}
237250
}

src/Grpc.AspNetCore.Server/Internal/HttpContextStreamReader.cs

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,10 @@
1717
#endregion
1818

1919
using System.Diagnostics;
20+
using System.IO.Pipelines;
2021
using Grpc.Core;
2122
using Grpc.Shared;
23+
using Microsoft.AspNetCore.Http.Features;
2224

2325
namespace Grpc.AspNetCore.Server.Internal;
2426

@@ -28,6 +30,8 @@ internal class HttpContextStreamReader<TRequest> : IAsyncStreamReader<TRequest>
2830
{
2931
private readonly HttpContextServerCallContext _serverCallContext;
3032
private readonly Func<DeserializationContext, TRequest> _deserializer;
33+
private readonly PipeReader _bodyReader;
34+
private readonly IHttpRequestLifetimeFeature _requestLifetimeFeature;
3135
private bool _completed;
3236
private long _readCount;
3337
private bool _endOfStream;
@@ -36,6 +40,12 @@ public HttpContextStreamReader(HttpContextServerCallContext serverCallContext, F
3640
{
3741
_serverCallContext = serverCallContext;
3842
_deserializer = deserializer;
43+
44+
// Copy HttpContext values.
45+
// This is done to avoid a race condition when reading them from HttpContext later when running in a separate thread.
46+
_bodyReader = _serverCallContext.HttpContext.Request.BodyReader;
47+
// Copy lifetime feature because HttpContext.RequestAborted on .NET 6 doesn't return the real cancellation token.
48+
_requestLifetimeFeature = GrpcProtocolHelpers.GetRequestLifetimeFeature(_serverCallContext.HttpContext);
3949
}
4050

4151
public TRequest Current { get; private set; } = default!;
@@ -54,7 +64,7 @@ async Task<bool> MoveNextAsync(ValueTask<TRequest?> readStreamTask)
5464
return Task.FromCanceled<bool>(cancellationToken);
5565
}
5666

57-
if (_completed || _serverCallContext.CancellationToken.IsCancellationRequested)
67+
if (_completed || _requestLifetimeFeature.RequestAborted.IsCancellationRequested)
5868
{
5969
return Task.FromException<bool>(new InvalidOperationException("Can't read messages after the request is complete."));
6070
}
@@ -63,7 +73,7 @@ async Task<bool> MoveNextAsync(ValueTask<TRequest?> readStreamTask)
6373
// In a long running stream this can allow the previous value to be GCed.
6474
Current = null!;
6575

66-
var request = _serverCallContext.HttpContext.Request.BodyReader.ReadStreamMessageAsync(_serverCallContext, _deserializer, cancellationToken);
76+
var request = _bodyReader.ReadStreamMessageAsync(_serverCallContext, _deserializer, cancellationToken);
6777
if (!request.IsCompletedSuccessfully)
6878
{
6979
return MoveNextAsync(request);

src/Grpc.AspNetCore.Server/Internal/HttpContextStreamWriter.cs

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,10 @@
1717
#endregion
1818

1919
using System.Diagnostics;
20+
using System.IO.Pipelines;
2021
using Grpc.Core;
2122
using Grpc.Shared;
23+
using Microsoft.AspNetCore.Http.Features;
2224

2325
namespace Grpc.AspNetCore.Server.Internal;
2426

@@ -29,6 +31,8 @@ internal class HttpContextStreamWriter<TResponse> : IServerStreamWriter<TRespons
2931
{
3032
private readonly HttpContextServerCallContext _context;
3133
private readonly Action<TResponse, SerializationContext> _serializer;
34+
private readonly PipeWriter _bodyWriter;
35+
private readonly IHttpRequestLifetimeFeature _requestLifetimeFeature;
3236
private readonly object _writeLock;
3337
private Task? _writeTask;
3438
private bool _completed;
@@ -39,6 +43,12 @@ public HttpContextStreamWriter(HttpContextServerCallContext context, Action<TRes
3943
_context = context;
4044
_serializer = serializer;
4145
_writeLock = new object();
46+
47+
// Copy HttpContext values.
48+
// This is done to avoid a race condition when reading them from HttpContext later when running in a separate thread.
49+
_bodyWriter = context.HttpContext.Response.BodyWriter;
50+
// Copy lifetime feature because HttpContext.RequestAborted on .NET 6 doesn't return the real cancellation token.
51+
_requestLifetimeFeature = GrpcProtocolHelpers.GetRequestLifetimeFeature(context.HttpContext);
4252
}
4353

4454
public WriteOptions? WriteOptions
@@ -77,7 +87,7 @@ private async Task WriteCoreAsync(TResponse message, CancellationToken cancellat
7787
{
7888
cancellationToken.ThrowIfCancellationRequested();
7989

80-
if (_completed || _context.CancellationToken.IsCancellationRequested)
90+
if (_completed || _requestLifetimeFeature.RequestAborted.IsCancellationRequested)
8191
{
8292
throw new InvalidOperationException("Can't write the message because the request is complete.");
8393
}
@@ -91,7 +101,7 @@ private async Task WriteCoreAsync(TResponse message, CancellationToken cancellat
91101
}
92102

93103
// Save write task to track whether it is complete. Must be set inside lock.
94-
_writeTask = _context.HttpContext.Response.BodyWriter.WriteStreamedMessageAsync(message, _context, _serializer, cancellationToken);
104+
_writeTask = _bodyWriter.WriteStreamedMessageAsync(message, _context, _serializer, cancellationToken);
95105
}
96106

97107
await _writeTask;
Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
#region Copyright notice and license
2+
3+
// Copyright 2019 The gRPC Authors
4+
//
5+
// Licensed under the Apache License, Version 2.0 (the "License");
6+
// you may not use this file except in compliance with the License.
7+
// You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing, software
12+
// distributed under the License is distributed on an "AS IS" BASIS,
13+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
// See the License for the specific language governing permissions and
15+
// limitations under the License.
16+
17+
#endregion
18+
19+
using Grpc.AspNetCore.Server.Internal.CallHandlers;
20+
using Grpc.AspNetCore.Server.Tests.TestObjects;
21+
using Grpc.Core;
22+
using Grpc.Shared.Server;
23+
using Grpc.Tests.Shared;
24+
using Microsoft.AspNetCore.Http;
25+
using Microsoft.AspNetCore.Http.Features;
26+
using Microsoft.Extensions.Logging.Abstractions;
27+
using NUnit.Framework;
28+
29+
namespace Grpc.AspNetCore.Server.Tests;
30+
31+
[TestFixture]
32+
public class DuplexStreamingServerCallHandlerTests
33+
{
34+
private static readonly Marshaller<TestMessage> _marshaller = new Marshaller<TestMessage>((message, context) => { context.Complete(Array.Empty<byte>()); }, context => new TestMessage());
35+
36+
[Test]
37+
public async Task HandleCallAsync_ConcurrentReadAndWrite_Success()
38+
{
39+
// Arrange
40+
var invoker = new DuplexStreamingServerMethodInvoker<TestService, TestMessage, TestMessage>(
41+
(service, reader, writer, context) =>
42+
{
43+
var message = new TestMessage();
44+
var readTask = Task.Run(() => reader.MoveNext());
45+
var writeTask = Task.Run(() => writer.WriteAsync(message));
46+
return Task.WhenAll(readTask, writeTask);
47+
},
48+
new Method<TestMessage, TestMessage>(MethodType.DuplexStreaming, "test", "test", _marshaller, _marshaller),
49+
HttpContextServerCallContextHelper.CreateMethodOptions(),
50+
new TestGrpcServiceActivator<TestService>());
51+
var handler = new DuplexStreamingServerCallHandler<TestService, TestMessage, TestMessage>(invoker, NullLoggerFactory.Instance);
52+
53+
// Verify there isn't a race condition when reading/writing on seperate threads.
54+
// This test primarily exists to ensure that the stream reader and stream writer aren't accessing non-thread safe APIs on HttpContext.
55+
for (var i = 0; i < 10_000; i++)
56+
{
57+
var httpContext = HttpContextHelpers.CreateContext();
58+
59+
// Act
60+
await handler.HandleCallAsync(httpContext);
61+
62+
// Assert
63+
var trailers = httpContext.Features.Get<IHttpResponseTrailersFeature>()!.Trailers;
64+
Assert.AreEqual("0", trailers["grpc-status"].ToString());
65+
}
66+
}
67+
}

0 commit comments

Comments
 (0)