Skip to content

Commit bfe312a

Browse files
authored
Support WriteAsync cancellation token (#1645)
1 parent 2951655 commit bfe312a

31 files changed

+1302
-124
lines changed

perf/Grpc.AspNetCore.Microbenchmarks/Internal/MessageHelpers.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ public static void WriteMessage<T>(Stream stream, T message, HttpContextServerCa
3030
{
3131
var pipeWriter = PipeWriter.Create(stream);
3232

33-
PipeExtensions.WriteMessageAsync(pipeWriter, message, callContext ?? HttpContextServerCallContextHelper.CreateServerCallContext(), (r, c) => c.Complete(r.ToByteArray()), canFlush: true).GetAwaiter().GetResult();
33+
PipeExtensions.WriteStreamedMessageAsync(pipeWriter, message, callContext ?? HttpContextServerCallContextHelper.CreateServerCallContext(), (r, c) => c.Complete(r.ToByteArray())).GetAwaiter().GetResult();
3434
}
3535
}
3636
}

src/Grpc.AspNetCore.Server/Internal/CallHandlers/ClientStreamingServerCallHandler.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ protected override async Task HandleCallAsyncCore(HttpContext httpContext, HttpC
7575
}
7676

7777
var responseBodyWriter = httpContext.Response.BodyWriter;
78-
await responseBodyWriter.WriteMessageAsync(response, serverCallContext, MethodInvoker.Method.ResponseMarshaller.ContextualSerializer, canFlush: false);
78+
await responseBodyWriter.WriteSingleMessageAsync(response, serverCallContext, MethodInvoker.Method.ResponseMarshaller.ContextualSerializer);
7979
}
8080
}
8181
}

src/Grpc.AspNetCore.Server/Internal/CallHandlers/UnaryServerCallHandler.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ protected override async Task HandleCallAsyncCore(HttpContext httpContext, HttpC
6464
}
6565

6666
var responseBodyWriter = httpContext.Response.BodyWriter;
67-
await responseBodyWriter.WriteMessageAsync(response, serverCallContext, MethodInvoker.Method.ResponseMarshaller.ContextualSerializer, canFlush: false);
67+
await responseBodyWriter.WriteSingleMessageAsync(response, serverCallContext, MethodInvoker.Method.ResponseMarshaller.ContextualSerializer);
6868
}
6969
}
7070
}

src/Grpc.AspNetCore.Server/Internal/HttpContextServerCallContext.cs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -476,6 +476,11 @@ internal async Task DeadlineExceededAsync()
476476
await completionFeature.CompleteAsync();
477477
}
478478

479+
CancelRequest();
480+
}
481+
482+
internal void CancelRequest()
483+
{
479484
// HttpResetFeature should always be set on context,
480485
// but in case it isn't, fall back to HttpContext.Abort.
481486
// Abort will send error code INTERNAL_ERROR.

src/Grpc.AspNetCore.Server/Internal/HttpContextStreamWriter.cs

Lines changed: 44 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717
#endregion
1818

1919
using Grpc.Core;
20+
using Microsoft.AspNetCore.Http;
21+
using Microsoft.Extensions.Logging;
2022

2123
namespace Grpc.AspNetCore.Server.Internal
2224
{
@@ -43,30 +45,61 @@ public WriteOptions? WriteOptions
4345
}
4446

4547
public Task WriteAsync(TResponse message)
48+
{
49+
return WriteCoreAsync(message, CancellationToken.None);
50+
}
51+
52+
#if NET5_0_OR_GREATER
53+
// Explicit implementation because this WriteAsync has a default interface implementation.
54+
Task IAsyncStreamWriter<TResponse>.WriteAsync(TResponse message, CancellationToken cancellationToken)
55+
{
56+
return WriteCoreAsync(message, cancellationToken);
57+
}
58+
#endif
59+
60+
private async Task WriteCoreAsync(TResponse message, CancellationToken cancellationToken)
4661
{
4762
if (message == null)
4863
{
49-
return Task.FromException(new ArgumentNullException(nameof(message)));
64+
throw new ArgumentNullException(nameof(message));
5065
}
5166

52-
if (_completed || _context.CancellationToken.IsCancellationRequested)
67+
// Register cancellation token early to ensure request is canceled if cancellation is requested.
68+
CancellationTokenRegistration? registration = null;
69+
if (cancellationToken.CanBeCanceled)
5370
{
54-
return Task.FromException(new InvalidOperationException("Can't write the message because the request is complete."));
71+
registration = cancellationToken.Register(
72+
static (state) => ((HttpContextServerCallContext)state!).CancelRequest(),
73+
_context);
5574
}
5675

57-
lock (_writeLock)
76+
try
5877
{
59-
// Pending writes need to be awaited first
60-
if (IsWriteInProgressUnsynchronized)
78+
cancellationToken.ThrowIfCancellationRequested();
79+
80+
if (_completed || _context.CancellationToken.IsCancellationRequested)
6181
{
62-
return Task.FromException(new InvalidOperationException("Can't write the message because the previous write is in progress."));
82+
throw new InvalidOperationException("Can't write the message because the request is complete.");
6383
}
6484

65-
// Save write task to track whether it is complete. Must be set inside lock.
66-
_writeTask = _context.HttpContext.Response.BodyWriter.WriteMessageAsync(message, _context, _serializer, canFlush: true);
67-
}
85+
lock (_writeLock)
86+
{
87+
// Pending writes need to be awaited first
88+
if (IsWriteInProgressUnsynchronized)
89+
{
90+
throw new InvalidOperationException("Can't write the message because the previous write is in progress.");
91+
}
92+
93+
// Save write task to track whether it is complete. Must be set inside lock.
94+
_writeTask = _context.HttpContext.Response.BodyWriter.WriteStreamedMessageAsync(message, _context, _serializer, cancellationToken);
95+
}
6896

69-
return _writeTask;
97+
await _writeTask;
98+
}
99+
finally
100+
{
101+
registration?.Dispose();
102+
}
70103
}
71104

72105
public void Complete()

src/Grpc.AspNetCore.Server/Internal/PipeExtensions.cs

Lines changed: 42 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,37 @@ private static Status CreateUnknownMessageEncodingMessageStatus(string unsupport
4343
return new Status(StatusCode.Unimplemented, $"Unsupported grpc-encoding value '{unsupportedEncoding}'. Supported encodings: {string.Join(", ", supportedEncodings)}");
4444
}
4545

46-
public static async Task WriteMessageAsync<TResponse>(this PipeWriter pipeWriter, TResponse response, HttpContextServerCallContext serverCallContext, Action<TResponse, SerializationContext> serializer, bool canFlush)
46+
public static async Task WriteSingleMessageAsync<TResponse>(this PipeWriter pipeWriter, TResponse response, HttpContextServerCallContext serverCallContext, Action<TResponse, SerializationContext> serializer)
47+
where TResponse : class
48+
{
49+
var logger = serverCallContext.Logger;
50+
try
51+
{
52+
// Must call StartAsync before the first pipeWriter.GetSpan() in WriteHeader
53+
var httpResponse = serverCallContext.HttpContext.Response;
54+
if (!httpResponse.HasStarted)
55+
{
56+
await httpResponse.StartAsync();
57+
}
58+
59+
GrpcServerLog.SendingMessage(logger);
60+
61+
var serializationContext = serverCallContext.SerializationContext;
62+
serializationContext.Reset();
63+
serializationContext.ResponseBufferWriter = pipeWriter;
64+
serializer(response, serializationContext);
65+
66+
GrpcServerLog.MessageSent(serverCallContext.Logger);
67+
GrpcEventSource.Log.MessageSent();
68+
}
69+
catch (Exception ex)
70+
{
71+
GrpcServerLog.ErrorSendingMessage(logger, ex);
72+
throw;
73+
}
74+
}
75+
76+
public static async Task WriteStreamedMessageAsync<TResponse>(this PipeWriter pipeWriter, TResponse response, HttpContextServerCallContext serverCallContext, Action<TResponse, SerializationContext> serializer, CancellationToken cancellationToken = default)
4777
where TResponse : class
4878
{
4979
var logger = serverCallContext.Logger;
@@ -64,11 +94,20 @@ public static async Task WriteMessageAsync<TResponse>(this PipeWriter pipeWriter
6494
serializer(response, serializationContext);
6595

6696
// Flush messages unless WriteOptions.Flags has BufferHint set
67-
var flush = canFlush && ((serverCallContext.WriteOptions?.Flags ?? default) & WriteFlags.BufferHint) != WriteFlags.BufferHint;
97+
var flush = ((serverCallContext.WriteOptions?.Flags ?? default) & WriteFlags.BufferHint) != WriteFlags.BufferHint;
6898

6999
if (flush)
70100
{
71-
await pipeWriter.FlushAsync();
101+
var flushResult = await pipeWriter.FlushAsync();
102+
103+
// Workaround bug where FlushAsync doesn't return IsCanceled = true on request abort.
104+
// https://github.com/dotnet/aspnetcore/issues/40788
105+
// Also, sometimes the request CT isn't triggered. Also check CT passed into method.
106+
if (!flushResult.IsCompleted &&
107+
(serverCallContext.CancellationToken.IsCancellationRequested || cancellationToken.IsCancellationRequested))
108+
{
109+
throw new OperationCanceledException("Request aborted while sending the message.");
110+
}
72111
}
73112

74113
GrpcServerLog.MessageSent(serverCallContext.Logger);

src/Grpc.Net.Client/Internal/ClientStreamWriterBase.cs

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,17 @@ protected ClientStreamWriterBase(ILogger logger)
4040

4141
public abstract Task CompleteAsync();
4242

43-
public abstract Task WriteAsync(TRequest message);
43+
public Task WriteAsync(TRequest message) => WriteCoreAsync(message, CancellationToken.None);
44+
45+
#if NETSTANDARD2_1_OR_GREATER || NET5_0_OR_GREATER
46+
// Explicit implementation because this WriteAsync has a default interface implementation.
47+
Task IAsyncStreamWriter<TRequest>.WriteAsync(TRequest message, CancellationToken cancellationToken)
48+
{
49+
return WriteCoreAsync(message, cancellationToken);
50+
}
51+
#endif
52+
53+
public abstract Task WriteCoreAsync(TRequest message, CancellationToken cancellationToken);
4454

4555
protected Task CreateErrorTask(string message)
4656
{

src/Grpc.Net.Client/Internal/GrpcCall.NonGeneric.cs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ public DefaultDeserializationContext DeserializationContext
5353

5454
public string? RequestGrpcEncoding { get; internal set; }
5555

56+
public abstract CancellationToken CancellationToken { get; }
5657
public abstract Type RequestType { get; }
5758
public abstract Type ResponseType { get; }
5859

src/Grpc.Net.Client/Internal/GrpcCall.cs

Lines changed: 26 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -92,10 +92,7 @@ private void ValidateDeadline(DateTime? deadline)
9292

9393
public Task<Status> CallTask => _callTcs.Task;
9494

95-
public CancellationToken CancellationToken
96-
{
97-
get { return _callCts.Token; }
98-
}
95+
public override CancellationToken CancellationToken => _callCts.Token;
9996

10097
public override Type RequestType => typeof(TRequest);
10198
public override Type ResponseType => typeof(TResponse);
@@ -379,7 +376,26 @@ private void SetMessageContent(HttpContent content, HttpRequestMessage message)
379376
message.Content = content;
380377
}
381378

382-
public void CancelCallFromCancellationToken()
379+
public bool TryRegisterCancellation(
380+
CancellationToken cancellationToken,
381+
[NotNullWhen(true)] out CancellationTokenRegistration? cancellationTokenRegistration)
382+
{
383+
// Only register if the token:
384+
// 1. Can be canceled.
385+
// 2. The token isn't the same one used in CallOptions. Already listening for its cancellation.
386+
if (cancellationToken.CanBeCanceled && cancellationToken != Options.CancellationToken)
387+
{
388+
cancellationTokenRegistration = cancellationToken.Register(
389+
static (state) => ((GrpcCall<TRequest, TResponse>)state!).CancelCallFromCancellationToken(),
390+
this);
391+
return true;
392+
}
393+
394+
cancellationTokenRegistration = null;
395+
return false;
396+
}
397+
398+
private void CancelCallFromCancellationToken()
383399
{
384400
using (StartScope())
385401
{
@@ -726,7 +742,8 @@ private void SetFailedResult(Status status)
726742

727743
public Exception CreateFailureStatusException(Status status)
728744
{
729-
if (Channel.ThrowOperationCanceledOnCancellation && status.StatusCode == StatusCode.DeadlineExceeded)
745+
if (Channel.ThrowOperationCanceledOnCancellation &&
746+
(status.StatusCode == StatusCode.DeadlineExceeded || status.StatusCode == StatusCode.Cancelled))
730747
{
731748
// Convert status response of DeadlineExceeded to OperationCanceledException when
732749
// ThrowOperationCanceledOnCancellation is true.
@@ -792,7 +809,9 @@ public Exception CreateFailureStatusException(Status status)
792809
// The cancellation token will cancel the call CTS.
793810
// This must be registered after the client writer has been created
794811
// so that cancellation will always complete the writer.
795-
_ctsRegistration = Options.CancellationToken.Register(CancelCallFromCancellationToken);
812+
_ctsRegistration = Options.CancellationToken.Register(
813+
static (state) => ((GrpcCall<TRequest, TResponse>)state!).CancelCallFromCancellationToken(),
814+
this);
796815
}
797816

798817
return (diagnosticSourceEnabled, activity);

src/Grpc.Net.Client/Internal/HttpContentClientStreamReader.cs

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -109,15 +109,9 @@ public Task<bool> MoveNext(CancellationToken cancellationToken)
109109

110110
private async Task<bool> MoveNextCore(CancellationToken cancellationToken)
111111
{
112-
CancellationTokenRegistration? ctsRegistration = null;
112+
_call.TryRegisterCancellation(cancellationToken, out var ctsRegistration);
113113
try
114114
{
115-
if (cancellationToken.CanBeCanceled)
116-
{
117-
// The cancellation token will cancel the call CTS.
118-
ctsRegistration = cancellationToken.Register(_call.CancelCallFromCancellationToken);
119-
}
120-
121115
_call.CancellationToken.ThrowIfCancellationRequested();
122116

123117
if (_httpResponse == null)

0 commit comments

Comments
 (0)