Skip to content

Commit 9bc4d27

Browse files
authored
Check for in-progress write in ServerStreamWriter (#903)
1 parent 41618a4 commit 9bc4d27

File tree

4 files changed

+68
-6
lines changed

4 files changed

+68
-6
lines changed

src/Grpc.AspNetCore.Server/Internal/HttpContextStreamWriter.cs

Lines changed: 31 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,11 +27,14 @@ internal class HttpContextStreamWriter<TResponse> : IServerStreamWriter<TRespons
2727
{
2828
private readonly HttpContextServerCallContext _context;
2929
private readonly Action<TResponse, SerializationContext> _serializer;
30+
private readonly object _writeLock;
31+
private Task? _writeTask;
3032

3133
public HttpContextStreamWriter(HttpContextServerCallContext context, Action<TResponse, SerializationContext> serializer)
3234
{
3335
_context = context;
3436
_serializer = serializer;
37+
_writeLock = new object();
3538
}
3639

3740
public WriteOptions WriteOptions
@@ -44,15 +47,40 @@ public Task WriteAsync(TResponse message)
4447
{
4548
if (message == null)
4649
{
47-
throw new ArgumentNullException(nameof(message));
50+
return Task.FromException(new ArgumentNullException(nameof(message)));
4851
}
4952

5053
if (_context.CancellationToken.IsCancellationRequested)
5154
{
52-
throw new InvalidOperationException("Cannot write message after request is complete.");
55+
return Task.FromException(new InvalidOperationException("Cannot write message after request is complete."));
5356
}
5457

55-
return _context.HttpContext.Response.BodyWriter.WriteMessageAsync(message, _context, _serializer, canFlush: true);
58+
lock (_writeLock)
59+
{
60+
// Pending writes need to be awaited first
61+
if (IsWriteInProgressUnsynchronized)
62+
{
63+
return Task.FromException(new InvalidOperationException("Can't write the message because the previous write is in progress."));
64+
}
65+
66+
// Save write task to track whether it is complete. Must be set inside lock.
67+
_writeTask = _context.HttpContext.Response.BodyWriter.WriteMessageAsync(message, _context, _serializer, canFlush: true);
68+
}
69+
70+
return _writeTask;
71+
}
72+
73+
/// <summary>
74+
/// A value indicating whether there is an async write already in progress.
75+
/// Should only check this property when holding the write lock.
76+
/// </summary>
77+
private bool IsWriteInProgressUnsynchronized
78+
{
79+
get
80+
{
81+
var writeTask = _writeTask;
82+
return writeTask != null && !writeTask.IsCompleted;
83+
}
5684
}
5785
}
5886
}

src/Grpc.Net.Client/Internal/HttpContentClientStreamWriter.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -129,7 +129,7 @@ public Task WriteAsync(TRequest message)
129129
return CreateErrorTask("Can't write the message because the previous write is in progress.");
130130
}
131131

132-
// Save write task to track whether it is complete
132+
// Save write task to track whether it is complete. Must be set inside lock.
133133
_writeTask = WriteAsyncCore(message);
134134
}
135135
}

test/Grpc.AspNetCore.Server.Tests/HttpContextStreamWriterTests.cs

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
#endregion
1818

19+
using System;
1920
using System.IO;
2021
using System.IO.Pipelines;
2122
using System.Threading.Tasks;
@@ -110,5 +111,35 @@ await writer.WriteAsync(new HelloReply
110111
var writtenMessage2 = await MessageHelpers.AssertReadStreamMessageAsync<HelloReply>(pipeReader);
111112
Assert.AreEqual("Hello world 2", writtenMessage2!.Message);
112113
}
114+
115+
[Test]
116+
public async Task WriteAsync_WriteInProgress_Error()
117+
{
118+
// Arrange
119+
var tcs = new TaskCompletionSource<object?>(TaskCreationOptions.RunContinuationsAsynchronously);
120+
121+
var httpContext = new DefaultHttpContext();
122+
httpContext.Features.Set<IHttpResponseBodyFeature>(new TestResponseBodyFeature(PipeWriter.Create(new MemoryStream()), startAsyncTask: tcs.Task));
123+
var serverCallContext = HttpContextServerCallContextHelper.CreateServerCallContext(httpContext);
124+
var writer = new HttpContextStreamWriter<HelloReply>(serverCallContext, MessageHelpers.ServiceMethod.ResponseMarshaller.ContextualSerializer);
125+
126+
// Act
127+
_ = writer.WriteAsync(new HelloReply
128+
{
129+
Message = "Hello world 1"
130+
});
131+
132+
var ex = await ExceptionAssert.ThrowsAsync<InvalidOperationException>(() =>
133+
{
134+
return writer.WriteAsync(new HelloReply
135+
{
136+
Message = "Hello world 2"
137+
});
138+
});
139+
140+
// Assert
141+
Assert.AreEqual("Can't write the message because the previous write is in progress.", ex.Message);
142+
}
143+
113144
}
114145
}

test/Shared/TestResponseBodyFeature.cs

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,9 +27,12 @@ namespace Grpc.Tests.Shared
2727
{
2828
public class TestResponseBodyFeature : IHttpResponseBodyFeature
2929
{
30-
public TestResponseBodyFeature(PipeWriter writer)
30+
private readonly Task _startAsyncTask;
31+
32+
public TestResponseBodyFeature(PipeWriter writer, Task? startAsyncTask = null)
3133
{
3234
Writer = writer;
35+
_startAsyncTask = startAsyncTask ?? Task.CompletedTask;
3336
}
3437

3538
public PipeWriter Writer { get; }
@@ -52,7 +55,7 @@ public Task SendFileAsync(string path, long offset, long? count, CancellationTok
5255

5356
public Task StartAsync(CancellationToken cancellationToken = default)
5457
{
55-
return Task.CompletedTask;
58+
return _startAsyncTask;
5659
}
5760
}
5861
}

0 commit comments

Comments
 (0)