Skip to content

Commit a2918b0

Browse files
authored
Client response MoveNext cancellation cancels call (#1062)
2 parents 5395000 + 86f4600 commit a2918b0

File tree

4 files changed

+76
-35
lines changed

4 files changed

+76
-35
lines changed

src/Grpc.Net.Client/Internal/GrpcCall.cs

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -388,6 +388,14 @@ private void SetMessageContent(TRequest request, HttpRequestMessage message)
388388
GrpcProtocolConstants.GrpcContentTypeHeaderValue);
389389
}
390390

391+
public void CancelCallFromCancellationToken()
392+
{
393+
using (StartScope())
394+
{
395+
CancelCall(new Status(StatusCode.Cancelled, "Call canceled by the client."));
396+
}
397+
}
398+
391399
private void CancelCall(Status status)
392400
{
393401
// Set overall call status first. Status can be used in throw RpcException from cancellation.
@@ -690,13 +698,7 @@ public Exception CreateFailureStatusException(Status status)
690698
// The cancellation token will cancel the call CTS.
691699
// This must be registered after the client writer has been created
692700
// so that cancellation will always complete the writer.
693-
_ctsRegistration = Options.CancellationToken.Register(() =>
694-
{
695-
using (StartScope())
696-
{
697-
CancelCall(new Status(StatusCode.Cancelled, "Call canceled by the client."));
698-
}
699-
});
701+
_ctsRegistration = Options.CancellationToken.Register(CancelCallFromCancellationToken);
700702
}
701703

702704
return (diagnosticSourceEnabled, activity);

src/Grpc.Net.Client/Internal/HttpContentClientStreamReader.cs

Lines changed: 6 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -119,21 +119,16 @@ public Task<bool> MoveNext(CancellationToken cancellationToken)
119119

120120
private async Task<bool> MoveNextCore(CancellationToken cancellationToken)
121121
{
122-
CancellationTokenSource? cts = null;
122+
CancellationTokenRegistration? ctsRegistration = null;
123123
try
124124
{
125-
// Linking tokens is expensive. Only create a linked token if the token passed in requires it
126125
if (cancellationToken.CanBeCanceled)
127126
{
128-
cts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken, _call.CancellationToken);
129-
cancellationToken = cts.Token;
130-
}
131-
else
132-
{
133-
cancellationToken = _call.CancellationToken;
127+
// The cancellation token will cancel the call CTS.
128+
ctsRegistration = cancellationToken.Register(_call.CancelCallFromCancellationToken);
134129
}
135130

136-
cancellationToken.ThrowIfCancellationRequested();
131+
_call.CancellationToken.ThrowIfCancellationRequested();
137132

138133
if (_httpResponse == null)
139134
{
@@ -167,7 +162,7 @@ private async Task<bool> MoveNextCore(CancellationToken cancellationToken)
167162
_responseStream,
168163
_grpcEncoding,
169164
singleMessage: false,
170-
cancellationToken).ConfigureAwait(false);
165+
_call.CancellationToken).ConfigureAwait(false);
171166
if (Current == null)
172167
{
173168
// No more content in response so report status to call.
@@ -202,7 +197,7 @@ private async Task<bool> MoveNextCore(CancellationToken cancellationToken)
202197
}
203198
finally
204199
{
205-
cts?.Dispose();
200+
ctsRegistration?.Dispose();
206201
}
207202
}
208203

test/FunctionalTests/Client/StreamingTests.cs

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
using System.Diagnostics;
2222
using System.IO;
2323
using System.Linq;
24+
using System.Text;
2425
using System.Threading;
2526
using System.Threading.Tasks;
2627
using Google.Protobuf;
@@ -515,6 +516,53 @@ await context.WriteResponseHeadersAsync(new Metadata
515516
Assert.AreEqual("Message", call.GetStatus().Detail);
516517
}
517518

519+
[Test]
520+
public async Task DuplexStreaming_CancelResponseMoveNext_CancellationSentToServer()
521+
{
522+
var tcs = new TaskCompletionSource<object?>(TaskCreationOptions.RunContinuationsAsynchronously);
523+
524+
async Task DuplexStreamingWithCancellation(IAsyncStreamReader<DataMessage> requestStream, IServerStreamWriter<DataMessage> responseStream, ServerCallContext context)
525+
{
526+
try
527+
{
528+
await foreach (var message in requestStream.ReadAllAsync())
529+
{
530+
await responseStream.WriteAsync(message);
531+
}
532+
}
533+
catch (Exception ex)
534+
{
535+
tcs.TrySetException(ex);
536+
}
537+
}
538+
539+
// Arrange
540+
var method = Fixture.DynamicGrpc.AddDuplexStreamingMethod<DataMessage, DataMessage>(DuplexStreamingWithCancellation);
541+
542+
var channel = CreateChannel();
543+
544+
var client = TestClientFactory.Create(channel, method);
545+
546+
// Act
547+
var call = client.DuplexStreamingCall();
548+
549+
await call.RequestStream.WriteAsync(new DataMessage { Data = ByteString.CopyFrom(Encoding.UTF8.GetBytes("Hello world")) });
550+
551+
await call.ResponseStream.MoveNext();
552+
553+
var cts = new CancellationTokenSource();
554+
var task = call.ResponseStream.MoveNext(cts.Token);
555+
556+
cts.Cancel();
557+
558+
// Assert
559+
var clientEx = await ExceptionAssert.ThrowsAsync<RpcException>(() => task);
560+
Assert.AreEqual(StatusCode.Cancelled, clientEx.StatusCode);
561+
Assert.AreEqual("Call canceled by the client.", clientEx.Status.Detail);
562+
563+
await ExceptionAssert.ThrowsAsync<IOException>(() => tcs.Task);
564+
}
565+
518566
private static byte[] CreateTestData(int size)
519567
{
520568
var data = new byte[size];

test/FunctionalTests/Web/Server/DeadlineTests.cs

Lines changed: 13 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -86,28 +86,24 @@ static async Task<HelloReply> WaitUntilDeadline(HelloRequest request, ServerCall
8686

8787
var grpcWebClient = CreateGrpcWebClient();
8888

89-
// TODO(JamesNK): This test is/was flaky. Remove loop if this test is no longer a problem
90-
for (int i = 0; i < 20; i++)
89+
var requestMessage = new HelloRequest
9190
{
92-
var requestMessage = new HelloRequest
93-
{
94-
Name = "World"
95-
};
91+
Name = "World"
92+
};
9693

97-
var requestStream = new MemoryStream();
98-
MessageHelpers.WriteMessage(requestStream, requestMessage);
94+
var requestStream = new MemoryStream();
95+
MessageHelpers.WriteMessage(requestStream, requestMessage);
9996

100-
var httpRequest = GrpcHttpHelper.Create(method.FullName);
101-
httpRequest.Headers.Add(GrpcProtocolConstants.TimeoutHeader, "50m");
102-
httpRequest.Content = new GrpcStreamContent(requestStream);
97+
var httpRequest = GrpcHttpHelper.Create(method.FullName);
98+
httpRequest.Headers.Add(GrpcProtocolConstants.TimeoutHeader, "50m");
99+
httpRequest.Content = new GrpcStreamContent(requestStream);
103100

104-
// Act
105-
var response = await grpcWebClient.SendAsync(httpRequest, HttpCompletionOption.ResponseHeadersRead).DefaultTimeout();
101+
// Act
102+
var response = await grpcWebClient.SendAsync(httpRequest, HttpCompletionOption.ResponseHeadersRead).DefaultTimeout();
106103

107-
// Assert
108-
response.AssertIsSuccessfulGrpcRequest();
109-
response.AssertTrailerStatus(StatusCode.DeadlineExceeded, "Deadline Exceeded");
110-
}
104+
// Assert
105+
response.AssertIsSuccessfulGrpcRequest();
106+
response.AssertTrailerStatus(StatusCode.DeadlineExceeded, "Deadline Exceeded");
111107
}
112108

113109
[Test]

0 commit comments

Comments
 (0)