Skip to content

Commit 507f62e

Browse files
committed
refactor to be used as part of HttpsConnectionMiddleware
1 parent 8d4b843 commit 507f62e

File tree

4 files changed

+35
-64
lines changed

4 files changed

+35
-64
lines changed

src/Servers/Kestrel/Core/src/ListenOptionsHttpsExtensions.cs

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
using System.Security.Cryptography.X509Certificates;
66
using Microsoft.AspNetCore.Server.Kestrel.Core;
77
using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Infrastructure;
8-
using Microsoft.AspNetCore.Server.Kestrel.Core.Middleware;
98
using Microsoft.AspNetCore.Server.Kestrel.Https;
109
using Microsoft.AspNetCore.Server.Kestrel.Https.Internal;
1110
using Microsoft.Extensions.DependencyInjection;
@@ -198,15 +197,6 @@ public static ListenOptions UseHttps(this ListenOptions listenOptions, HttpsConn
198197
listenOptions.IsTls = true;
199198
listenOptions.HttpsOptions = httpsOptions;
200199

201-
if (httpsOptions.TlsClientHelloBytesCallback is not null)
202-
{
203-
listenOptions.Use(next =>
204-
{
205-
var middleware = new TlsListenerMiddleware(next, httpsOptions.TlsClientHelloBytesCallback);
206-
return middleware.OnTlsClientHelloAsync;
207-
});
208-
}
209-
210200
listenOptions.Use(next =>
211201
{
212202
var middleware = new HttpsConnectionMiddleware(next, httpsOptions, listenOptions.Protocols, loggerFactory, metrics);

src/Servers/Kestrel/Core/src/Middleware/HttpsConnectionMiddleware.cs

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
using Microsoft.AspNetCore.Server.Kestrel.Core.Features;
1818
using Microsoft.AspNetCore.Server.Kestrel.Core.Internal;
1919
using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Infrastructure;
20+
using Microsoft.AspNetCore.Server.Kestrel.Core.Middleware;
2021
using Microsoft.Extensions.Logging;
2122
using Microsoft.Extensions.Logging.Abstractions;
2223

@@ -44,6 +45,9 @@ internal sealed class HttpsConnectionMiddleware
4445
private readonly Func<TlsHandshakeCallbackContext, ValueTask<SslServerAuthenticationOptions>>? _tlsCallbackOptions;
4546
private readonly object? _tlsCallbackOptionsState;
4647

48+
// Captures raw TLS client hello and invokes a user callback if any
49+
private readonly TlsListener? _tlsListener;
50+
4751
// Internal for testing
4852
internal readonly HttpProtocols _httpProtocols;
4953

@@ -112,6 +116,11 @@ public HttpsConnectionMiddleware(ConnectionDelegate next, HttpsConnectionAdapter
112116
(RemoteCertificateValidationCallback?)null : RemoteCertificateValidationCallback;
113117

114118
_sslStreamFactory = s => new SslStream(s, leaveInnerStreamOpen: false, userCertificateValidationCallback: remoteCertificateValidationCallback);
119+
120+
if (options.TlsClientHelloBytesCallback is not null)
121+
{
122+
_tlsListener = new TlsListener(options.TlsClientHelloBytesCallback);
123+
}
115124
}
116125

117126
internal HttpsConnectionMiddleware(
@@ -162,6 +171,10 @@ public async Task OnConnectionAsync(ConnectionContext context)
162171
using var cancellationTokenSource = _ctsPool.Rent();
163172
cancellationTokenSource.CancelAfter(_handshakeTimeout);
164173

174+
if (_tlsListener is not null)
175+
{
176+
await _tlsListener.OnTlsClientHelloAsync(context, cancellationTokenSource.Token);
177+
}
165178
if (_tlsCallbackOptions is null)
166179
{
167180
await DoOptionsBasedHandshakeAsync(context, sslStream, feature, cancellationTokenSource.Token);

src/Servers/Kestrel/Core/src/Middleware/TlsListenerMiddleware.cs renamed to src/Servers/Kestrel/Core/src/Middleware/TlsListener.cs

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -7,28 +7,26 @@
77

88
namespace Microsoft.AspNetCore.Server.Kestrel.Core.Middleware;
99

10-
internal sealed class TlsListenerMiddleware
10+
internal sealed class TlsListener
1111
{
12-
private readonly ConnectionDelegate _next;
1312
private readonly Action<ConnectionContext, ReadOnlySequence<byte>> _tlsClientHelloBytesCallback;
1413

15-
public TlsListenerMiddleware(ConnectionDelegate next, Action<ConnectionContext, ReadOnlySequence<byte>> tlsClientHelloBytesCallback)
14+
public TlsListener(Action<ConnectionContext, ReadOnlySequence<byte>> tlsClientHelloBytesCallback)
1615
{
17-
_next = next;
1816
_tlsClientHelloBytesCallback = tlsClientHelloBytesCallback;
1917
}
2018

2119
/// <summary>
2220
/// Sniffs the TLS Client Hello message, and invokes a callback if found.
2321
/// </summary>
24-
internal async Task OnTlsClientHelloAsync(ConnectionContext connection)
22+
internal async Task OnTlsClientHelloAsync(ConnectionContext connection, CancellationToken cancellationToken)
2523
{
2624
var input = connection.Transport.Input;
2725
ClientHelloParseState parseState = ClientHelloParseState.NotEnoughData;
2826

2927
while (true)
3028
{
31-
var result = await input.ReadAsync();
29+
var result = await input.ReadAsync(cancellationToken);
3230
var buffer = result.Buffer;
3331

3432
try
@@ -74,8 +72,6 @@ internal async Task OnTlsClientHelloAsync(ConnectionContext connection)
7472
}
7573
}
7674
}
77-
78-
await _next(connection);
7975
}
8076

8177
/// <summary>

src/Servers/Kestrel/Core/test/TlsListenerMiddlewareTests.cs

Lines changed: 18 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -66,34 +66,21 @@ public async Task RunTlsClientHelloCallbackTest_DeterministicallyReads()
6666
var transport = new DuplexPipe(reader, writer);
6767
var transportConnection = new DefaultConnectionContext("test", transport, transport);
6868

69-
var nextMiddlewareInvoked = false;
7069
var tlsClientHelloCallbackInvoked = false;
71-
72-
var middleware = new TlsListenerMiddleware(
73-
next: ctx =>
74-
{
75-
nextMiddlewareInvoked = true;
76-
var readResult = ctx.Transport.Input.ReadAsync();
77-
Assert.Equal(5, readResult.Result.Buffer.Length);
78-
79-
return Task.CompletedTask;
80-
},
81-
tlsClientHelloBytesCallback: (ctx, data) =>
82-
{
83-
tlsClientHelloCallbackInvoked = true;
84-
}
85-
);
70+
var middleware = new TlsListener((ctx, data) => { tlsClientHelloCallbackInvoked = true; });
8671

8772
await writer.WriteAsync(new byte[1] { 0x16 });
88-
var middlewareTask = middleware.OnTlsClientHelloAsync(transportConnection);
73+
var middlewareTask = middleware.OnTlsClientHelloAsync(transportConnection, CancellationToken.None);
8974
await writer.WriteAsync(new byte[2] { 0x03, 0x01 });
9075
await writer.WriteAsync(new byte[2] { 0x00, 0x20 });
9176
await writer.CompleteAsync();
9277

9378
await middlewareTask;
94-
Assert.True(nextMiddlewareInvoked);
9579
Assert.False(tlsClientHelloCallbackInvoked);
9680

81+
var readResult = await reader.ReadAsync();
82+
Assert.Equal(5, readResult.Buffer.Length);
83+
9784
// ensuring that we have read limited number of times
9885
Assert.True(reader.ReadAsyncCounter is >= 2 && reader.ReadAsyncCounter is <= 4,
9986
$"Expected ReadAsync() to happen about 2-4 times. Actually happened {reader.ReadAsyncCounter} times.");
@@ -110,23 +97,11 @@ private async Task RunTlsClientHelloCallbackTest_WithMultipleSegments(
11097
var transport = new DuplexPipe(pipe.Reader, writer);
11198
var transportConnection = new DefaultConnectionContext("test", transport, transport);
11299

113-
var nextMiddlewareInvokedActual = false;
114100
var tlsClientHelloCallbackActual = false;
115101

116102
var fullLength = packets.Sum(p => p.Length);
117103

118-
var middleware = new TlsListenerMiddleware(
119-
next: ctx =>
120-
{
121-
nextMiddlewareInvokedActual = true;
122-
if (tlsClientHelloCallbackActual)
123-
{
124-
var readResult = ctx.Transport.Input.ReadAsync();
125-
Assert.Equal(fullLength, readResult.Result.Buffer.Length);
126-
}
127-
128-
return Task.CompletedTask;
129-
},
104+
var middleware = new TlsListener(
130105
tlsClientHelloBytesCallback: (ctx, data) =>
131106
{
132107
tlsClientHelloCallbackActual = true;
@@ -139,9 +114,8 @@ private async Task RunTlsClientHelloCallbackTest_WithMultipleSegments(
139114

140115
// write first packet
141116
await writer.WriteAsync(packets[0]);
142-
var middlewareTask = middleware.OnTlsClientHelloAsync(transportConnection);
117+
var middlewareTask = middleware.OnTlsClientHelloAsync(transportConnection, CancellationToken.None);
143118

144-
145119
/* It is a race condition (middleware's loop and writes here).
146120
* We don't know specifically how many packets will be read by middleware's loop
147121
* (possibly there are even 2 packets - the first and all others combined).
@@ -156,8 +130,13 @@ private async Task RunTlsClientHelloCallbackTest_WithMultipleSegments(
156130
await writer.CompleteAsync();
157131
await middlewareTask;
158132

159-
Assert.True(nextMiddlewareInvokedActual);
160133
Assert.Equal(tlsClientHelloCallbackExpected, tlsClientHelloCallbackActual);
134+
135+
if (tlsClientHelloCallbackActual)
136+
{
137+
var readResult = await pipe.Reader.ReadAsync();
138+
Assert.Equal(fullLength, readResult.Buffer.Length);
139+
}
161140
}
162141

163142
private async Task RunTlsClientHelloCallbackTest(
@@ -171,18 +150,9 @@ private async Task RunTlsClientHelloCallbackTest(
171150
var transport = new DuplexPipe(pipe.Reader, writer);
172151
var transportConnection = new DefaultConnectionContext("test", transport, transport);
173152

174-
var nextMiddlewareInvokedActual = false;
175153
var tlsClientHelloCallbackActual = false;
176154

177-
var middleware = new TlsListenerMiddleware(
178-
next: ctx =>
179-
{
180-
nextMiddlewareInvokedActual = true;
181-
var readResult = ctx.Transport.Input.ReadAsync();
182-
Assert.Equal(packetBytes.Length, readResult.Result.Buffer.Length);
183-
184-
return Task.CompletedTask;
185-
},
155+
var middleware = new TlsListener(
186156
tlsClientHelloBytesCallback: (ctx, data) =>
187157
{
188158
tlsClientHelloCallbackActual = true;
@@ -197,10 +167,12 @@ private async Task RunTlsClientHelloCallbackTest(
197167
await writer.CompleteAsync();
198168

199169
// call middleware and expect a callback
200-
await middleware.OnTlsClientHelloAsync(transportConnection);
170+
await middleware.OnTlsClientHelloAsync(transportConnection, CancellationToken.None);
201171

202-
Assert.True(nextMiddlewareInvokedActual);
203172
Assert.Equal(tlsClientHelloCallbackExpected, tlsClientHelloCallbackActual);
173+
174+
var readResult = await pipe.Reader.ReadAsync();
175+
Assert.Equal(packetBytes.Length, readResult.Buffer.Length);
204176
}
205177

206178
public static IEnumerable<object[]> ValidClientHelloData()

0 commit comments

Comments
 (0)