Skip to content

Commit 2bbd480

Browse files
committed
dont re-read the same data
1 parent 0581163 commit 2bbd480

File tree

3 files changed

+71
-5
lines changed

3 files changed

+71
-5
lines changed

src/Servers/Kestrel/Core/src/Middleware/TlsListenerMiddleware.cs

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ public TlsListenerMiddleware(ConnectionDelegate next, Action<ConnectionContext,
2424
internal async Task OnTlsClientHelloAsync(ConnectionContext connection)
2525
{
2626
var input = connection.Transport.Input;
27+
ClientHelloParseState parseState = ClientHelloParseState.NotEnoughData;
2728

2829
while (true)
2930
{
@@ -39,8 +40,7 @@ internal async Task OnTlsClientHelloAsync(ConnectionContext connection)
3940
break;
4041
}
4142

42-
var parseState = TryParseClientHello(buffer, out var clientHelloBytes);
43-
43+
parseState = TryParseClientHello(buffer, out var clientHelloBytes);
4444
if (parseState == ClientHelloParseState.NotEnoughData)
4545
{
4646
// if no data will be added, and we still lack enough bytes
@@ -63,7 +63,15 @@ internal async Task OnTlsClientHelloAsync(ConnectionContext connection)
6363
}
6464
finally
6565
{
66-
input.AdvanceTo(buffer.Start);
66+
if (parseState is ClientHelloParseState.NotEnoughData)
67+
{
68+
input.AdvanceTo(buffer.Start, buffer.End);
69+
}
70+
else
71+
{
72+
// ready to continue middleware pipeline, reset the buffer to initial state
73+
input.AdvanceTo(buffer.Start);
74+
}
6775
}
6876
}
6977

src/Servers/Kestrel/test/InMemory.FunctionalTests/TestTransport/InMemoryTransportConnection.cs

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ public override async ValueTask DisposeAsync()
9191
// This piece of code allows us to wait until the PipeReader has been awaited on.
9292
// We need to wrap lots of layers (including the ValueTask) to gain visiblity into when
9393
// the machinery for the await happens
94-
private class ObservableDuplexPipe : IDuplexPipe
94+
internal class ObservableDuplexPipe : IDuplexPipe
9595
{
9696
private readonly ObservablePipeReader _reader;
9797

@@ -110,11 +110,14 @@ public ObservableDuplexPipe(IDuplexPipe duplexPipe)
110110

111111
public PipeWriter Output { get; }
112112

113+
public int ReadAsyncCounter => _reader.ReadAsyncCounter;
114+
113115
private class ObservablePipeReader : PipeReader
114116
{
115117
private readonly PipeReader _reader;
116118
private readonly TaskCompletionSource _tcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously);
117119

120+
public int ReadAsyncCounter { get; private set; } = 0;
118121
public Task WaitForReadTask => _tcs.Task;
119122

120123
public ObservablePipeReader(PipeReader reader)
@@ -144,6 +147,7 @@ public override void Complete(Exception exception = null)
144147

145148
public override ValueTask<ReadResult> ReadAsync(CancellationToken cancellationToken = default)
146149
{
150+
ReadAsyncCounter++;
147151
var task = _reader.ReadAsync(cancellationToken);
148152

149153
if (_tcs.Task.IsCompleted)
@@ -152,7 +156,7 @@ public override ValueTask<ReadResult> ReadAsync(CancellationToken cancellationTo
152156
}
153157

154158
return new ValueTask<ReadResult>(new ObservableValueTask<ReadResult>(task, _tcs), 0);
155-
}
159+
}
156160

157161
public override bool TryRead(out ReadResult result)
158162
{

src/Servers/Kestrel/test/InMemory.FunctionalTests/TlsListenerMiddlewareTests.Units.cs

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
using Microsoft.Extensions.Hosting;
2626
using Microsoft.Extensions.Logging;
2727
using Moq;
28+
using static Microsoft.AspNetCore.Server.Kestrel.InMemory.FunctionalTests.TestTransport.InMemoryTransportConnection;
2829

2930
namespace InMemory.FunctionalTests;
3031

@@ -50,6 +51,59 @@ public Task OnTlsClientHelloAsync_ValidData_MultipleSegments(int id, List<byte[]
5051
public Task OnTlsClientHelloAsync_InvalidData_MultipleSegments(int id, List<byte[]> packets, bool nextMiddlewareInvoked)
5152
=> RunTlsClientHelloCallbackTest_WithMultipleSegments(id, packets, nextMiddlewareInvoked, tlsClientHelloCallbackExpected: false);
5253

54+
[Fact]
55+
public async Task RunTlsClientHelloCallbackTest_DeterministinglyReads()
56+
{
57+
var serviceContext = new TestServiceContext(LoggerFactory);
58+
var logger = LoggerFactory.CreateLogger<InMemoryTransportConnection>();
59+
var memoryPool = serviceContext.MemoryPoolFactory();
60+
var transportConnection = new InMemoryTransportConnection(memoryPool, logger);
61+
62+
var nextMiddlewareInvoked = false;
63+
var tlsClientHelloCallbackInvoked = false;
64+
65+
var middleware = new TlsListenerMiddleware(
66+
next: ctx =>
67+
{
68+
nextMiddlewareInvoked = true;
69+
var readResult = ctx.Transport.Input.ReadAsync();
70+
Assert.Equal(6, readResult.Result.Buffer.Length);
71+
72+
return Task.CompletedTask;
73+
},
74+
tlsClientHelloBytesCallback: (ctx, data) =>
75+
{
76+
tlsClientHelloCallbackInvoked = true;
77+
}
78+
);
79+
80+
await transportConnection.Input.WriteAsync(new byte[1] { 0x16 });
81+
var middlewareTask = Task.Run(() => middleware.OnTlsClientHelloAsync(transportConnection));
82+
await Task.Delay(TimeSpan.FromMilliseconds(25));
83+
84+
await transportConnection.Input.WriteAsync(new byte[2] { 0x03, 0x01 });
85+
await Task.Delay(TimeSpan.FromMilliseconds(25));
86+
87+
await transportConnection.Input.WriteAsync(new byte[2] { 0x00, 0x20 });
88+
await Task.Delay(TimeSpan.FromMilliseconds(25));
89+
90+
// not correct TLS client hello byte;
91+
// meaning we will not invoke the callback and advance request processing
92+
await transportConnection.Input.WriteAsync(new byte[1] { 0x15 });
93+
await Task.Delay(TimeSpan.FromMilliseconds(25));
94+
95+
await transportConnection.Input.CompleteAsync();
96+
97+
// ensuring that we have read only 5 times (ReadAsync() is called 5 times)
98+
var observableTransport = transportConnection.Transport as ObservableDuplexPipe;
99+
Assert.NotNull(observableTransport);
100+
Assert.Equal(5, observableTransport.ReadAsyncCounter);
101+
102+
await middlewareTask;
103+
Assert.True(nextMiddlewareInvoked);
104+
Assert.False(tlsClientHelloCallbackInvoked);
105+
}
106+
53107
private async Task RunTlsClientHelloCallbackTest_WithMultipleSegments(
54108
int id,
55109
List<byte[]> packets,

0 commit comments

Comments
 (0)