@@ -37,9 +37,6 @@ namespace Grpc.Net.Client
3737{
3838 internal static partial class StreamExtensions
3939 {
40- private const int MessageDelimiterSize = 4 ; // how many bytes it takes to encode "Message-Length"
41- private const int HeaderSize = MessageDelimiterSize + 1 ; // message length + compression flag
42-
4340 private static readonly Status SendingMessageExceedsLimitStatus = new Status ( StatusCode . ResourceExhausted , "Sending message exceeds the maximum configured message size." ) ;
4441 private static readonly Status ReceivedMessageExceedsLimitStatus = new Status ( StatusCode . ResourceExhausted , "Received message exceeds the maximum configured message size." ) ;
4542 private static readonly Status NoMessageEncodingMessageStatus = new Status ( StatusCode . Internal , "Request did not include grpc-encoding value with compressed message." ) ;
@@ -49,21 +46,21 @@ private static Status CreateUnknownMessageEncodingMessageStatus(string unsupport
4946 return new Status ( StatusCode . Unimplemented , $ "Unsupported grpc-encoding value '{ unsupportedEncoding } '. Supported encodings: { string . Join ( ", " , supportedEncodings ) } ") ;
5047 }
5148
52- private static async Task < ( uint length , bool compressed ) ? > ReadHeaderAsync ( Stream responseStream , Memory < byte > header , CancellationToken cancellationToken )
49+ private static async Task < ( int length , bool compressed ) ? > ReadHeaderAsync ( Stream responseStream , Memory < byte > header , CancellationToken cancellationToken )
5350 {
5451 int read ;
5552 var received = 0 ;
56- while ( ( read = await responseStream . ReadAsync ( header . Slice ( received , header . Length - received ) , cancellationToken ) . ConfigureAwait ( false ) ) > 0 )
53+ while ( ( read = await responseStream . ReadAsync ( header . Slice ( received , GrpcProtocolConstants . HeaderSize - received ) , cancellationToken ) . ConfigureAwait ( false ) ) > 0 )
5754 {
5855 received += read ;
5956
60- if ( received == header . Length )
57+ if ( received == GrpcProtocolConstants . HeaderSize )
6158 {
6259 break ;
6360 }
6461 }
6562
66- if ( received < header . Length )
63+ if ( received < GrpcProtocolConstants . HeaderSize )
6764 {
6865 if ( received == 0 )
6966 {
@@ -73,10 +70,18 @@ private static Status CreateUnknownMessageEncodingMessageStatus(string unsupport
7370 throw new InvalidDataException ( "Unexpected end of content while reading the message header." ) ;
7471 }
7572
73+ // Read the header first
74+ // - 1 byte flag for compression
75+ // - 4 bytes for the content length
7676 var compressed = ReadCompressedFlag ( header . Span [ 0 ] ) ;
77- var length = BinaryPrimitives . ReadUInt32BigEndian ( header . Span . Slice ( 1 ) ) ;
77+ var length = BinaryPrimitives . ReadUInt32BigEndian ( header . Span . Slice ( 1 , 4 ) ) ;
7878
79- return ( length , compressed ) ;
79+ if ( length > int . MaxValue )
80+ {
81+ throw new InvalidDataException ( "Message too large." ) ;
82+ }
83+
84+ return ( ( int ) length , compressed ) ;
8085 }
8186
8287 public static async ValueTask < TResponse ? > ReadMessageAsync < TResponse > (
@@ -90,17 +95,19 @@ private static Status CreateUnknownMessageEncodingMessageStatus(string unsupport
9095 CancellationToken cancellationToken )
9196 where TResponse : class
9297 {
98+ byte [ ] ? buffer = null ;
99+
93100 try
94101 {
95102 GrpcCallLog . ReadingMessage ( logger ) ;
96103 cancellationToken . ThrowIfCancellationRequested ( ) ;
97104
98- // Read the header first
99- // - 1 byte flag for compression
100- // - 4 bytes for the content length
101- var header = new byte [ HeaderSize ] ;
105+ // Buffer is used to read header, then message content.
106+ // This size was randomly chosen to hopefully be big enough for many small messages.
107+ // If the message is larger then the array will be replaced when the message size is known.
108+ buffer = ArrayPool < byte > . Shared . Rent ( minimumLength : 4096 ) ;
102109
103- var headerDetails = await ReadHeaderAsync ( responseStream , header , cancellationToken ) . ConfigureAwait ( false ) ;
110+ var headerDetails = await ReadHeaderAsync ( responseStream , buffer , cancellationToken ) . ConfigureAwait ( false ) ;
104111
105112 if ( headerDetails == null )
106113 {
@@ -111,17 +118,22 @@ private static Status CreateUnknownMessageEncodingMessageStatus(string unsupport
111118 var length = headerDetails . Value . length ;
112119 var compressed = headerDetails . Value . compressed ;
113120
114- if ( length > int . MaxValue )
121+ if ( length > 0 )
115122 {
116- throw new InvalidDataException ( "Message too large." ) ;
117- }
123+ if ( length > maximumMessageSize )
124+ {
125+ throw new RpcException ( ReceivedMessageExceedsLimitStatus ) ;
126+ }
118127
119- if ( length > maximumMessageSize )
120- {
121- throw new RpcException ( ReceivedMessageExceedsLimitStatus ) ;
122- }
128+ // Replace buffer if the message doesn't fit
129+ if ( buffer . Length < length )
130+ {
131+ ArrayPool < byte > . Shared . Return ( buffer ) ;
132+ buffer = ArrayPool < byte > . Shared . Rent ( length ) ;
133+ }
123134
124- var messageData = await ReadMessageContent ( responseStream , length , cancellationToken ) . ConfigureAwait ( false ) ;
135+ await ReadMessageContent ( responseStream , buffer , length , cancellationToken ) . ConfigureAwait ( false ) ;
136+ }
125137
126138 cancellationToken . ThrowIfCancellationRequested ( ) ;
127139
@@ -138,7 +150,7 @@ private static Status CreateUnknownMessageEncodingMessageStatus(string unsupport
138150 }
139151
140152 // Performance improvement would be to decompress without converting to an intermediary byte array
141- if ( ! TryDecompressMessage ( logger , grpcEncoding , compressionProviders , messageData , out var decompressedMessage ) )
153+ if ( ! TryDecompressMessage ( logger , grpcEncoding , compressionProviders , buffer , length , out var decompressedMessage ) )
142154 {
143155 var supportedEncodings = new List < string > ( ) ;
144156 supportedEncodings . Add ( GrpcProtocolConstants . IdentityGrpcEncoding ) ;
@@ -150,10 +162,10 @@ private static Status CreateUnknownMessageEncodingMessageStatus(string unsupport
150162 }
151163 else
152164 {
153- payload = new ReadOnlySequence < byte > ( messageData ) ;
165+ payload = new ReadOnlySequence < byte > ( buffer , 0 , length ) ;
154166 }
155167
156- GrpcCallLog . DeserializingMessage ( logger , messageData . Length , typeof ( TResponse ) ) ;
168+ GrpcCallLog . DeserializingMessage ( logger , length , typeof ( TResponse ) ) ;
157169
158170 var deserializationContext = new DefaultDeserializationContext ( ) ;
159171 deserializationContext . SetPayload ( payload ) ;
@@ -164,7 +176,7 @@ private static Status CreateUnknownMessageEncodingMessageStatus(string unsupport
164176 {
165177 // Check that there is no additional content in the stream for a single message
166178 // There is no ReadByteAsync on stream. Reuse header array with ReadAsync, we don't need it anymore
167- if ( await responseStream . ReadAsync ( header ) . ConfigureAwait ( false ) > 0 )
179+ if ( await responseStream . ReadAsync ( buffer ) . ConfigureAwait ( false ) > 0 )
168180 {
169181 throw new InvalidDataException ( "Unexpected data after finished reading message." ) ;
170182 }
@@ -179,43 +191,44 @@ private static Status CreateUnknownMessageEncodingMessageStatus(string unsupport
179191 GrpcCallLog . ErrorReadingMessage ( logger , ex ) ;
180192 throw ;
181193 }
194+ finally
195+ {
196+ if ( buffer != null )
197+ {
198+ ArrayPool < byte > . Shared . Return ( buffer ) ;
199+ }
200+ }
182201 }
183202
184- private static async Task < byte [ ] > ReadMessageContent ( Stream responseStream , uint length , CancellationToken cancellationToken )
203+ private static async Task ReadMessageContent ( Stream responseStream , Memory < byte > messageData , int length , CancellationToken cancellationToken )
185204 {
186205 // Read message content until content length is reached
187- byte [ ] messageData ;
188- if ( length > 0 )
206+ var received = 0 ;
207+ int read ;
208+ while ( ( read = await responseStream . ReadAsync ( messageData . Slice ( received , length - received ) , cancellationToken ) . ConfigureAwait ( false ) ) > 0 )
189209 {
190- var received = 0 ;
191- var read = 0 ;
192- messageData = new byte [ length ] ;
193- while ( ( read = await responseStream . ReadAsync ( messageData . AsMemory ( received , messageData . Length - received ) , cancellationToken ) . ConfigureAwait ( false ) ) > 0 )
194- {
195- received += read ;
210+ received += read ;
196211
197- if ( received == messageData . Length )
198- {
199- break ;
200- }
212+ if ( received == length )
213+ {
214+ break ;
201215 }
202216 }
203- else
217+
218+ if ( received < length )
204219 {
205- messageData = Array . Empty < byte > ( ) ;
220+ throw new InvalidDataException ( "Unexpected end of content while reading the message content." ) ;
206221 }
207-
208- return messageData ;
209222 }
210223
211- private static bool TryDecompressMessage ( ILogger logger , string compressionEncoding , Dictionary < string , ICompressionProvider > compressionProviders , byte [ ] messageData , [ NotNullWhen ( true ) ] out ReadOnlySequence < byte > ? result )
224+ private static bool TryDecompressMessage ( ILogger logger , string compressionEncoding , Dictionary < string , ICompressionProvider > compressionProviders , byte [ ] messageData , int length , [ NotNullWhen ( true ) ] out ReadOnlySequence < byte > ? result )
212225 {
213226 if ( compressionProviders . TryGetValue ( compressionEncoding , out var compressionProvider ) )
214227 {
215228 GrpcCallLog . DecompressingMessage ( logger , compressionProvider . EncodingName ) ;
216229
217230 var output = new MemoryStream ( ) ;
218- using ( var compressionStream = compressionProvider . CreateDecompressionStream ( new MemoryStream ( messageData ) ) )
231+ using ( var compressionStream = compressionProvider . CreateDecompressionStream ( new MemoryStream ( messageData , 0 , length , writable : true , publiclyVisible : true ) ) )
219232 {
220233 compressionStream . CopyTo ( output ) ;
221234 }
@@ -244,6 +257,7 @@ private static bool ReadCompressedFlag(byte flag)
244257 }
245258 }
246259
260+ // TODO(JamesNK): Reuse serialization content between message writes. Improve client/duplex streaming allocations.
247261 public static async ValueTask WriteMessageAsync < TMessage > (
248262 this Stream stream ,
249263 ILogger logger ,
@@ -275,8 +289,8 @@ public static async ValueTask WriteMessageAsync<TMessage>(
275289 }
276290
277291 var isCompressed =
278- GrpcProtocolHelpers . CanWriteCompressed ( callOptions . WriteOptions ) &&
279- ! string . Equals ( grpcEncoding , GrpcProtocolConstants . IdentityGrpcEncoding , StringComparison . Ordinal ) ;
292+ GrpcProtocolHelpers . CanWriteCompressed ( callOptions . WriteOptions ) &&
293+ ! string . Equals ( grpcEncoding , GrpcProtocolConstants . IdentityGrpcEncoding , StringComparison . Ordinal ) ;
280294
281295 if ( isCompressed )
282296 {
@@ -288,7 +302,7 @@ public static async ValueTask WriteMessageAsync<TMessage>(
288302 data ) ;
289303 }
290304
291- await WriteHeaderAsync ( stream , data . Length , isCompressed , callOptions . CancellationToken ) . ConfigureAwait ( false ) ;
305+ await stream . WriteAsync ( serializationContext . GetHeader ( isCompressed , data . Length ) , callOptions . CancellationToken ) . ConfigureAwait ( false ) ;
292306 await stream . WriteAsync ( data , callOptions . CancellationToken ) . ConfigureAwait ( false ) ;
293307 await stream . FlushAsync ( callOptions . CancellationToken ) . ConfigureAwait ( false ) ;
294308
@@ -322,25 +336,5 @@ private static ReadOnlyMemory<byte> CompressMessage(ILogger logger, string compr
322336 // Should never reach here
323337 throw new InvalidOperationException ( $ "Could not find compression provider for '{ compressionEncoding } '.") ;
324338 }
325-
326- private static ValueTask WriteHeaderAsync ( Stream stream , int length , bool compress , CancellationToken cancellationToken )
327- {
328- var headerData = new byte [ HeaderSize ] ;
329-
330- // Compression flag
331- headerData [ 0 ] = compress ? ( byte ) 1 : ( byte ) 0 ;
332-
333- // Message length
334- EncodeMessageLength ( length , headerData . AsSpan ( 1 ) ) ;
335-
336- return stream . WriteAsync ( headerData . AsMemory ( 0 , headerData . Length ) , cancellationToken ) ;
337- }
338-
339- private static void EncodeMessageLength ( int messageLength , Span < byte > destination )
340- {
341- Debug . Assert ( destination . Length >= MessageDelimiterSize , "Buffer too small to encode message length." ) ;
342-
343- BinaryPrimitives . WriteUInt32BigEndian ( destination , ( uint ) messageLength ) ;
344- }
345339 }
346340}
0 commit comments