Skip to content

Commit a236023

Browse files
authored
Fix race between deadline and Stream.ReadAsync (#1550)
1 parent 1ac3a3f commit a236023

File tree

2 files changed

+163
-2
lines changed

2 files changed

+163
-2
lines changed

src/Grpc.Net.Client/Internal/StreamExtensions.cs

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,7 @@ private static Status CreateUnknownMessageEncodingMessageStatus(string unsupport
109109
buffer = ArrayPool<byte>.Shared.Rent(length);
110110
}
111111

112-
await ReadMessageContent(responseStream, buffer, length, cancellationToken).ConfigureAwait(false);
112+
await ReadMessageContentAsync(responseStream, buffer, length, cancellationToken).ConfigureAwait(false);
113113
}
114114

115115
cancellationToken.ThrowIfCancellationRequested();
@@ -161,6 +161,16 @@ private static Status CreateUnknownMessageEncodingMessageStatus(string unsupport
161161
GrpcCallLog.ReceivedMessage(call.Logger);
162162
return message;
163163
}
164+
catch (ObjectDisposedException) when (cancellationToken.IsCancellationRequested)
165+
{
166+
// When a deadline expires there can be a race between cancellation and Stream.ReadAsync.
167+
// If ReadAsync is called after the response is disposed then ReadAsync throws ObjectDisposedException.
168+
// https://github.com/dotnet/runtime/blob/dfbae37e91c4744822018dde10cbd414c661c0b8/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/Http2Stream.cs#L1479-L1482
169+
//
170+
// If ObjectDisposedException is caught and cancellation has happened then rethrow as an OCE.
171+
// This makes gRPC client correctly report a DeadlineExceeded status.
172+
throw new OperationCanceledException();
173+
}
164174
catch (Exception ex) when (!(ex is OperationCanceledException && cancellationToken.IsCancellationRequested))
165175
{
166176
// Don't write error when user cancels read
@@ -216,7 +226,7 @@ private static int ReadMessageLength(Span<byte> header)
216226
return (int)length;
217227
}
218228

219-
private static async Task ReadMessageContent(Stream responseStream, Memory<byte> messageData, int length, CancellationToken cancellationToken)
229+
private static async Task ReadMessageContentAsync(Stream responseStream, Memory<byte> messageData, int length, CancellationToken cancellationToken)
220230
{
221231
// Read message content until content length is reached
222232
var received = 0;

test/FunctionalTests/Client/DeadlineTests.cs

Lines changed: 151 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,10 @@
1616

1717
#endregion
1818

19+
using System.Net;
1920
using Grpc.AspNetCore.FunctionalTests.Infrastructure;
2021
using Grpc.Core;
22+
using Grpc.Net.Client;
2123
using Grpc.Tests.Shared;
2224
using Microsoft.AspNetCore.Http.Features;
2325
using NUnit.Framework;
@@ -156,5 +158,154 @@ async Task ServerStreamingTimeout(DataMessage request, IServerStreamWriter<DataM
156158
Assert.AreEqual(StatusCode.DeadlineExceeded, call.GetStatus().StatusCode);
157159
}
158160
}
161+
162+
[Test]
163+
public async Task Unary_DeadlineInBetweenReadAsyncCalls_DeadlineExceededStatus()
164+
{
165+
Task<DataMessage> Unary(DataMessage request, ServerCallContext context)
166+
{
167+
return Task.FromResult(new DataMessage());
168+
}
169+
170+
// Arrange
171+
var method = Fixture.DynamicGrpc.AddUnaryMethod<DataMessage, DataMessage>(Unary);
172+
173+
var http = Fixture.CreateHandler(TestServerEndpointName.Http2);
174+
175+
var channel = GrpcChannel.ForAddress(http.address, new GrpcChannelOptions
176+
{
177+
LoggerFactory = LoggerFactory,
178+
HttpHandler = new PauseHttpHandler { InnerHandler = http.handler }
179+
});
180+
181+
var client = TestClientFactory.Create(channel, method);
182+
183+
// Act
184+
var call = client.UnaryCall(new DataMessage(), new CallOptions(deadline: DateTime.UtcNow.AddMilliseconds(200)));
185+
186+
// Assert
187+
var ex = await ExceptionAssert.ThrowsAsync<RpcException>(() => call.ResponseAsync).DefaultTimeout();
188+
Assert.AreEqual(StatusCode.DeadlineExceeded, ex.StatusCode);
189+
Assert.AreEqual(StatusCode.DeadlineExceeded, call.GetStatus().StatusCode);
190+
}
191+
192+
private class PauseHttpHandler : DelegatingHandler
193+
{
194+
protected override async Task<HttpResponseMessage> SendAsync(HttpRequestMessage request, CancellationToken cancellationToken)
195+
{
196+
var response = await base.SendAsync(request, cancellationToken);
197+
198+
var newHttpContent = new PauseHttpContent(response.Content);
199+
newHttpContent.Headers.ContentType = response.Content.Headers.ContentType;
200+
201+
response.Content = newHttpContent;
202+
203+
return response;
204+
}
205+
206+
private class PauseHttpContent : HttpContent
207+
{
208+
private readonly HttpContent _inner;
209+
private Stream? _innerStream;
210+
211+
public PauseHttpContent(HttpContent inner)
212+
{
213+
_inner = inner;
214+
}
215+
216+
protected override async Task<Stream> CreateContentReadStreamAsync()
217+
{
218+
var stream = await _inner.ReadAsStreamAsync().ConfigureAwait(false);
219+
220+
return new PauseStream(stream);
221+
}
222+
223+
protected override async Task SerializeToStreamAsync(Stream stream, TransportContext? context)
224+
{
225+
_innerStream = await _inner.ReadAsStreamAsync().ConfigureAwait(false);
226+
227+
_innerStream = new PauseStream(_innerStream);
228+
229+
await _innerStream.CopyToAsync(stream).ConfigureAwait(false);
230+
}
231+
232+
protected override bool TryComputeLength(out long length)
233+
{
234+
length = 0;
235+
return false;
236+
}
237+
238+
protected override void Dispose(bool disposing)
239+
{
240+
if (disposing)
241+
{
242+
// This is important. Disposing original response content will cancel the gRPC call.
243+
_inner.Dispose();
244+
_innerStream?.Dispose();
245+
}
246+
247+
base.Dispose(disposing);
248+
}
249+
250+
private class PauseStream : Stream
251+
{
252+
private Stream _stream;
253+
254+
public PauseStream(Stream stream)
255+
{
256+
_stream = stream;
257+
}
258+
259+
public override bool CanRead => _stream.CanRead;
260+
public override bool CanSeek => _stream.CanSeek;
261+
public override bool CanWrite => _stream.CanWrite;
262+
public override long Length => _stream.Length;
263+
public override long Position
264+
{
265+
get => _stream.Position;
266+
set => _stream.Position = value;
267+
}
268+
269+
public override void Flush()
270+
{
271+
_stream.Flush();
272+
}
273+
274+
public override int Read(byte[] buffer, int offset, int count)
275+
{
276+
return _stream.Read(buffer, offset, count);
277+
}
278+
279+
public override long Seek(long offset, SeekOrigin origin)
280+
{
281+
return _stream.Seek(offset, origin);
282+
}
283+
284+
public override void SetLength(long value)
285+
{
286+
_stream.SetLength(value);
287+
}
288+
289+
public override void Write(byte[] buffer, int offset, int count)
290+
{
291+
_stream.Write(buffer, offset, count);
292+
}
293+
294+
public override async ValueTask<int> ReadAsync(Memory<byte> buffer, CancellationToken cancellationToken = default)
295+
{
296+
// Wait for call to be canceled.
297+
var tcs = new TaskCompletionSource<object?>(TaskCreationOptions.RunContinuationsAsynchronously);
298+
cancellationToken.Register(() => tcs.SetResult(null));
299+
await tcs.Task;
300+
301+
// Wait a little longer to give time for HttpResponseMessage dispose to complete.
302+
await Task.Delay(50);
303+
304+
// Still try to read data from canceled request.
305+
return await _stream.ReadAsync(buffer, cancellationToken);
306+
}
307+
}
308+
}
309+
}
159310
}
160311
}

0 commit comments

Comments
 (0)