Skip to content

Commit 84f6a3e

Browse files
authored
Fix repeatedly decoding base64 with large grpc-web-text request (#1045)
1 parent a76308f commit 84f6a3e

File tree

4 files changed

+268
-85
lines changed

4 files changed

+268
-85
lines changed

src/Grpc.AspNetCore.Web/Internal/Base64PipeReader.cs

Lines changed: 70 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,9 @@
1616

1717
#endregion
1818

19-
using Microsoft.AspNetCore.Authorization.Infrastructure;
2019
using System;
2120
using System.Buffers;
2221
using System.Buffers.Text;
23-
using System.Diagnostics;
2422
using System.IO.Pipelines;
2523
using System.Threading;
2624
using System.Threading.Tasks;
@@ -33,9 +31,13 @@ namespace Grpc.AspNetCore.Web.Internal
3331
internal class Base64PipeReader : PipeReader
3432
{
3533
private readonly PipeReader _inner;
36-
private ReadOnlySequence<byte> _currentInnerBuffer;
3734
private ReadOnlySequence<byte> _currentDecodedBuffer;
38-
private byte[]? _rentedBuffer;
35+
private ReadOnlySequence<byte> _currentInnerBuffer;
36+
37+
// Keep track of how much of the inner result has been read in its own field.
38+
// Can't use inner buffer length because an inner buffer could be set but not
39+
// read if it has less than 4 bytes (minimum decode size).
40+
private long _currentInnerBufferRead;
3941

4042
public Base64PipeReader(PipeReader inner)
4143
{
@@ -46,9 +48,9 @@ public override void AdvanceTo(SequencePosition consumed)
4648
{
4749
var consumedPosition = ResolvePosition(consumed);
4850

49-
_inner.AdvanceTo(consumedPosition);
51+
UpdateCurrentBuffers(consumed, consumedPosition);
5052

51-
ReturnBuffer();
53+
_inner.AdvanceTo(consumedPosition);
5254
}
5355

5456
public override void AdvanceTo(SequencePosition consumed, SequencePosition examined)
@@ -61,9 +63,20 @@ public override void AdvanceTo(SequencePosition consumed, SequencePosition exami
6163
? _currentInnerBuffer.End
6264
: ResolvePosition(examined);
6365

66+
UpdateCurrentBuffers(consumed, consumedPosition);
67+
6468
_inner.AdvanceTo(consumedPosition, examinedPosition);
69+
}
70+
71+
private void UpdateCurrentBuffers(SequencePosition consumed, SequencePosition consumedPosition)
72+
{
73+
var lengthBefore = _currentInnerBuffer.Length;
74+
75+
_currentDecodedBuffer = _currentDecodedBuffer.Slice(consumed);
76+
_currentInnerBuffer = _currentInnerBuffer.Slice(consumedPosition);
6577

66-
ReturnBuffer();
78+
// Substract difference in the inner buffer from how much has been decoded.
79+
_currentInnerBufferRead -= lengthBefore - _currentInnerBuffer.Length;
6780
}
6881

6982
private SequencePosition ResolvePosition(SequencePosition base64Position)
@@ -110,20 +123,19 @@ public override void Complete(Exception? exception = null)
110123
{
111124
_inner.Complete(exception);
112125

113-
ReturnBuffer();
114-
}
115-
116-
private void ReturnBuffer()
117-
{
118-
if (_rentedBuffer != null)
119-
{
120-
ArrayPool<byte>.Shared.Return(_rentedBuffer);
121-
_rentedBuffer = null;
122-
}
126+
_currentInnerBuffer = ReadOnlySequence<byte>.Empty;
127+
_currentDecodedBuffer = ReadOnlySequence<byte>.Empty;
123128
}
124129

125130
public async override ValueTask<ReadResult> ReadAsync(CancellationToken cancellationToken = default)
126131
{
132+
// ReadAsync needs to handle some common situations:
133+
// 1. Base64 requires are least 4 bytes to decode content. If less than 4 bytes are returned
134+
// from the inner reader then repeatedly call the inner reader until 4 bytes are available.
135+
// 2. It is possible that ReadAsync is called many times without consuming the data. We don't
136+
// want to decode the same base64 content over and over. ReadAsync only decodes new content
137+
// and appends it to a sequence.
138+
127139
var innerResult = await _inner.ReadAsync(cancellationToken);
128140
if (innerResult.Buffer.IsEmpty)
129141
{
@@ -133,7 +145,7 @@ public async override ValueTask<ReadResult> ReadAsync(CancellationToken cancella
133145
}
134146

135147
// Minimum valid base64 length is 4. Read until we have at least that much content
136-
while (innerResult.Buffer.Length < 4)
148+
while (innerResult.Buffer.Length - _currentInnerBufferRead < 4)
137149
{
138150
if (innerResult.IsCompleted)
139151
{
@@ -158,31 +170,57 @@ public async override ValueTask<ReadResult> ReadAsync(CancellationToken cancella
158170
}
159171

160172
// Limit result to complete base64 segments (multiples of 4)
161-
var buffer = innerResult.Buffer.Slice(0, (innerResult.Buffer.Length / 4) * 4);
173+
var newResultLength = innerResult.Buffer.Length - _currentInnerBufferRead;
174+
var newResultValidLength = (newResultLength / 4) * 4;
175+
176+
var buffer = innerResult.Buffer.Slice(_currentInnerBufferRead, newResultValidLength);
162177

163178
// The content can contain multiple fragments of base64 content
164179
// Check for padding, and limit returned data to one fragment at a time
165180
var paddingIndex = PositionOf(buffer, (byte)'=');
166181
if (paddingIndex != null)
167182
{
168-
_currentInnerBuffer = buffer.Slice(0, ((paddingIndex.Value / 4) + 1) * 4);
169-
}
170-
else
171-
{
172-
_currentInnerBuffer = buffer;
183+
buffer = buffer.Slice(0, ((paddingIndex.Value / 4) + 1) * 4);
173184
}
174185

175-
var length = (int)_currentInnerBuffer.Length;
176-
// Any rented buffer should have been returned
177-
Debug.Assert(_rentedBuffer == null);
178-
_rentedBuffer = ArrayPool<byte>.Shared.Rent(length);
179-
_currentInnerBuffer.CopyTo(_rentedBuffer);
186+
// Copy the buffer data to a new array.
187+
// Need a copy that we own because it will be decoded in place.
188+
var decodedBuffer = buffer.ToArray();
180189

181-
var validLength = (length / 4) * 4;
182-
var status = Base64.DecodeFromUtf8InPlace(_rentedBuffer.AsSpan(0, validLength), out var bytesWritten);
190+
var status = Base64.DecodeFromUtf8InPlace(decodedBuffer, out var bytesWritten);
183191
if (status == OperationStatus.Done || status == OperationStatus.NeedMoreData)
184192
{
185-
_currentDecodedBuffer = new ReadOnlySequence<byte>(_rentedBuffer, 0, bytesWritten);
193+
_currentInnerBuffer = innerResult.Buffer.Slice(0, _currentInnerBufferRead + decodedBuffer.Length);
194+
195+
_currentInnerBufferRead = _currentInnerBuffer.Length;
196+
197+
// Update decoded buffer. If there have been multiple reads with the same content then
198+
// newly decoded content will be appended to the sequence.
199+
if (_currentDecodedBuffer.IsEmpty)
200+
{
201+
// Avoid creating segments for single segment sequence.
202+
_currentDecodedBuffer = new ReadOnlySequence<byte>(decodedBuffer, 0, bytesWritten);
203+
}
204+
else if (_currentDecodedBuffer.IsSingleSegment)
205+
{
206+
var start = new MemorySegment<byte>(_currentDecodedBuffer.First);
207+
208+
// Append new content to end.
209+
var end = start.Append(decodedBuffer.AsMemory(0, bytesWritten));
210+
211+
_currentDecodedBuffer = new ReadOnlySequence<byte>(start, 0, end, end.Memory.Length);
212+
}
213+
else
214+
{
215+
var start = (MemorySegment<byte>)_currentDecodedBuffer.Start.GetObject()!;
216+
var end = (MemorySegment<byte>)_currentDecodedBuffer.End.GetObject()!;
217+
218+
// Append new content to end.
219+
end = end.Append(decodedBuffer.AsMemory(0, bytesWritten));
220+
221+
_currentDecodedBuffer = new ReadOnlySequence<byte>(start, 0, end, end.Memory.Length);
222+
}
223+
186224
return new ReadResult(
187225
_currentDecodedBuffer,
188226
innerResult.IsCanceled,
@@ -236,50 +274,5 @@ public override bool TryRead(out ReadResult result)
236274

237275
return null;
238276
}
239-
240-
private static bool ValidatePadding(in ReadOnlySequence<byte> source)
241-
{
242-
if (source.IsSingleSegment)
243-
{
244-
return ValidatePaddingCore(source.First.Span);
245-
}
246-
else
247-
{
248-
return ValidatePaddingMultiSegment(source);
249-
}
250-
}
251-
252-
private static bool ValidatePaddingMultiSegment(in ReadOnlySequence<byte> source)
253-
{
254-
var position = source.Start;
255-
256-
while (source.TryGet(ref position, out ReadOnlyMemory<byte> memory))
257-
{
258-
if (!ValidatePaddingCore(memory.Span))
259-
{
260-
return false;
261-
}
262-
263-
if (position.GetObject() == null)
264-
{
265-
break;
266-
}
267-
}
268-
269-
return true;
270-
}
271-
272-
private static bool ValidatePaddingCore(ReadOnlySpan<byte> span)
273-
{
274-
for (var i = 0; i < span.Length; i++)
275-
{
276-
if (span[i] != '=')
277-
{
278-
return false;
279-
}
280-
}
281-
282-
return true;
283-
}
284277
}
285278
}
Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
#region Copyright notice and license
2+
3+
// Copyright 2019 The gRPC Authors
4+
//
5+
// Licensed under the Apache License, Version 2.0 (the "License");
6+
// you may not use this file except in compliance with the License.
7+
// You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing, software
12+
// distributed under the License is distributed on an "AS IS" BASIS,
13+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
// See the License for the specific language governing permissions and
15+
// limitations under the License.
16+
17+
#endregion
18+
19+
using System;
20+
using System.Buffers;
21+
22+
namespace Grpc.AspNetCore.Web.Internal
23+
{
24+
internal class MemorySegment<T> : ReadOnlySequenceSegment<T>
25+
{
26+
public MemorySegment(ReadOnlyMemory<T> memory)
27+
{
28+
Memory = memory;
29+
}
30+
31+
public MemorySegment<T> Append(ReadOnlyMemory<T> memory)
32+
{
33+
var segment = new MemorySegment<T>(memory)
34+
{
35+
RunningIndex = RunningIndex + Memory.Length
36+
};
37+
38+
Next = segment;
39+
40+
return segment;
41+
}
42+
}
43+
}

test/FunctionalTests/TestServer/TesterServiceTests.cs

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -91,20 +91,17 @@ public async Task ClientStreamingTest_Error()
9191
// Act
9292
using var call = client.SayHelloClientStreamingError();
9393

94-
var ex = await ExceptionAssert.ThrowsAsync<InvalidOperationException>(async () =>
94+
var ex = await ExceptionAssert.ThrowsAsync<Exception>(async () =>
9595
{
96-
var names = new[] { "James", "Jo", "Lee" };
97-
96+
while (true)
9897
{
99-
foreach (var name in names)
100-
{
101-
await call.RequestStream.WriteAsync(new HelloRequest { Name = name }).DefaultTimeout();
102-
await Task.Delay(50);
103-
}
98+
await call.RequestStream.WriteAsync(new HelloRequest { Name = "Name!" }).DefaultTimeout();
99+
await Task.Delay(50);
104100
}
105101
}).DefaultTimeout();
106102

107103
// Assert
104+
Assert.IsTrue(ex is InvalidOperationException || ex is RpcException);
108105
Assert.AreEqual(StatusCode.NotFound, call.GetStatus().StatusCode);
109106
}
110107

0 commit comments

Comments
 (0)