Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 0 additions & 10 deletions src/Servers/Kestrel/Core/src/ListenOptionsHttpsExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
using System.Security.Cryptography.X509Certificates;
using Microsoft.AspNetCore.Server.Kestrel.Core;
using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Infrastructure;
using Microsoft.AspNetCore.Server.Kestrel.Core.Middleware;
using Microsoft.AspNetCore.Server.Kestrel.Https;
using Microsoft.AspNetCore.Server.Kestrel.Https.Internal;
using Microsoft.Extensions.DependencyInjection;
Expand Down Expand Up @@ -198,15 +197,6 @@ public static ListenOptions UseHttps(this ListenOptions listenOptions, HttpsConn
listenOptions.IsTls = true;
listenOptions.HttpsOptions = httpsOptions;

if (httpsOptions.TlsClientHelloBytesCallback is not null)
{
listenOptions.Use(next =>
{
var middleware = new TlsListenerMiddleware(next, httpsOptions.TlsClientHelloBytesCallback);
return middleware.OnTlsClientHelloAsync;
});
}

listenOptions.Use(next =>
{
var middleware = new HttpsConnectionMiddleware(next, httpsOptions, listenOptions.Protocols, loggerFactory, metrics);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
using Microsoft.AspNetCore.Server.Kestrel.Core.Features;
using Microsoft.AspNetCore.Server.Kestrel.Core.Internal;
using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Infrastructure;
using Microsoft.AspNetCore.Server.Kestrel.Core.Middleware;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Logging.Abstractions;

Expand Down Expand Up @@ -44,6 +45,9 @@ internal sealed class HttpsConnectionMiddleware
private readonly Func<TlsHandshakeCallbackContext, ValueTask<SslServerAuthenticationOptions>>? _tlsCallbackOptions;
private readonly object? _tlsCallbackOptionsState;

// Captures raw TLS client hello and invokes a user callback if any
private readonly TlsListener? _tlsListener;

// Internal for testing
internal readonly HttpProtocols _httpProtocols;

Expand Down Expand Up @@ -112,6 +116,11 @@ public HttpsConnectionMiddleware(ConnectionDelegate next, HttpsConnectionAdapter
(RemoteCertificateValidationCallback?)null : RemoteCertificateValidationCallback;

_sslStreamFactory = s => new SslStream(s, leaveInnerStreamOpen: false, userCertificateValidationCallback: remoteCertificateValidationCallback);

if (options.TlsClientHelloBytesCallback is not null)
{
_tlsListener = new TlsListener(options.TlsClientHelloBytesCallback);
}
}

internal HttpsConnectionMiddleware(
Expand Down Expand Up @@ -162,6 +171,10 @@ public async Task OnConnectionAsync(ConnectionContext context)
using var cancellationTokenSource = _ctsPool.Rent();
cancellationTokenSource.CancelAfter(_handshakeTimeout);

if (_tlsListener is not null)
{
await _tlsListener.OnTlsClientHelloAsync(context, cancellationTokenSource.Token);
}
if (_tlsCallbackOptions is null)
{
await DoOptionsBasedHandshakeAsync(context, sslStream, feature, cancellationTokenSource.Token);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,28 +7,27 @@

namespace Microsoft.AspNetCore.Server.Kestrel.Core.Middleware;

internal sealed class TlsListenerMiddleware
internal sealed class TlsListener
{
private readonly ConnectionDelegate _next;
private readonly Action<ConnectionContext, ReadOnlySequence<byte>> _tlsClientHelloBytesCallback;

public TlsListenerMiddleware(ConnectionDelegate next, Action<ConnectionContext, ReadOnlySequence<byte>> tlsClientHelloBytesCallback)
public TlsListener(Action<ConnectionContext, ReadOnlySequence<byte>> tlsClientHelloBytesCallback)
{
_next = next;
_tlsClientHelloBytesCallback = tlsClientHelloBytesCallback;
}

/// <summary>
/// Sniffs the TLS Client Hello message, and invokes a callback if found.
/// </summary>
internal async Task OnTlsClientHelloAsync(ConnectionContext connection)
internal async Task OnTlsClientHelloAsync(ConnectionContext connection, CancellationToken cancellationToken)
{
var input = connection.Transport.Input;
ClientHelloParseState parseState = ClientHelloParseState.NotEnoughData;
short recordLength = -1; // remembers the length of TLS record to not re-parse header on every iteration

while (true)
{
var result = await input.ReadAsync();
var result = await input.ReadAsync(cancellationToken);
var buffer = result.Buffer;

try
Expand All @@ -40,7 +39,7 @@ internal async Task OnTlsClientHelloAsync(ConnectionContext connection)
break;
}

parseState = TryParseClientHello(buffer, out var clientHelloBytes);
parseState = TryParseClientHello(buffer, ref recordLength, out var clientHelloBytes);
if (parseState == ClientHelloParseState.NotEnoughData)
{
// if no data will be added, and we still lack enough bytes
Expand Down Expand Up @@ -74,8 +73,6 @@ internal async Task OnTlsClientHelloAsync(ConnectionContext connection)
}
}
}

await _next(connection);
}

/// <summary>
Expand All @@ -85,10 +82,25 @@ internal async Task OnTlsClientHelloAsync(ConnectionContext connection)
/// TLS 1.2: https://datatracker.ietf.org/doc/html/rfc5246#section-6.2
/// TLS 1.3: https://datatracker.ietf.org/doc/html/rfc8446#section-5.1
/// </summary>
private static ClientHelloParseState TryParseClientHello(ReadOnlySequence<byte> buffer, out ReadOnlySequence<byte> clientHelloBytes)
private static ClientHelloParseState TryParseClientHello(ReadOnlySequence<byte> buffer, ref short recordLength, out ReadOnlySequence<byte> clientHelloBytes)
{
clientHelloBytes = default;

// in case bad actor will be sending a TLS client hello one byte at a time
// and we know the expected length of TLS client hello,
// we can check and fail fastly here instead of re-parsing the TLS client hello "header" on each iteration
if (recordLength != -1 && buffer.Length < 5 + recordLength)
{
return ClientHelloParseState.NotEnoughData;
}

// this means we finally got a full tls record, so we can return without parsing again
if (recordLength != -1)
{
clientHelloBytes = buffer.Slice(0, 5 + recordLength);
return ClientHelloParseState.ValidTlsClientHello;
}

if (buffer.Length < 6)
{
return ClientHelloParseState.NotEnoughData;
Expand All @@ -109,7 +121,7 @@ private static ClientHelloParseState TryParseClientHello(ReadOnlySequence<byte>
}

// Record length
if (!reader.TryReadBigEndian(out short recordLength))
if (!reader.TryReadBigEndian(out recordLength))
{
return ClientHelloParseState.NotTlsClientHello;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@

namespace Microsoft.AspNetCore.Server.Kestrel.Core.Tests;

public class TlsListenerMiddlewareTests
public class TlsListenerTests
{
[Theory]
[MemberData(nameof(ValidClientHelloData))]
Expand All @@ -50,6 +50,57 @@ public Task OnTlsClientHelloAsync_ValidData_MultipleSegments(int id, List<byte[]
public Task OnTlsClientHelloAsync_InvalidData_MultipleSegments(int id, List<byte[]> packets)
=> RunTlsClientHelloCallbackTest_WithMultipleSegments(id, packets, tlsClientHelloCallbackExpected: false);

[Fact]
public async Task RunTlsClientHelloCallbackTest_WithPreCancelledToken()
{
var serviceContext = new TestServiceContext();

var pipe = new Pipe();
var writer = pipe.Writer;
var reader = new ObservablePipeReader(pipe.Reader);

var transport = new DuplexPipe(reader, writer);
var transportConnection = new DefaultConnectionContext("test", transport, transport);

var tlsClientHelloCallbackInvoked = false;
var listener = new TlsListener((ctx, data) => { tlsClientHelloCallbackInvoked = true; });

var cts = new CancellationTokenSource();
cts.Cancel();

await writer.WriteAsync(new byte[1] { 0x16 });
await Assert.ThrowsAsync<OperationCanceledException>(async () =>
{
await listener.OnTlsClientHelloAsync(transportConnection, cts.Token);
});
Assert.False(tlsClientHelloCallbackInvoked);
}

[Fact]
public async Task RunTlsClientHelloCallbackTest_WithPendingCancellation()
{
var serviceContext = new TestServiceContext();

var pipe = new Pipe();
var writer = pipe.Writer;
var reader = new ObservablePipeReader(pipe.Reader);

var transport = new DuplexPipe(reader, writer);
var transportConnection = new DefaultConnectionContext("test", transport, transport);

var tlsClientHelloCallbackInvoked = false;
var listener = new TlsListener((ctx, data) => { tlsClientHelloCallbackInvoked = true; });

var cts = new CancellationTokenSource();
await writer.WriteAsync(new byte[1] { 0x16 });
var listenerTask = listener.OnTlsClientHelloAsync(transportConnection, cts.Token);
await writer.WriteAsync(new byte[2] { 0x03, 0x01 });
cts.Cancel();

await Assert.ThrowsAsync<OperationCanceledException>(async () => await listenerTask);
Assert.False(tlsClientHelloCallbackInvoked);
}

[Fact]
public async Task RunTlsClientHelloCallbackTest_DeterministicallyReads()
{
Expand All @@ -66,34 +117,21 @@ public async Task RunTlsClientHelloCallbackTest_DeterministicallyReads()
var transport = new DuplexPipe(reader, writer);
var transportConnection = new DefaultConnectionContext("test", transport, transport);

var nextMiddlewareInvoked = false;
var tlsClientHelloCallbackInvoked = false;

var middleware = new TlsListenerMiddleware(
next: ctx =>
{
nextMiddlewareInvoked = true;
var readResult = ctx.Transport.Input.ReadAsync();
Assert.Equal(5, readResult.Result.Buffer.Length);

return Task.CompletedTask;
},
tlsClientHelloBytesCallback: (ctx, data) =>
{
tlsClientHelloCallbackInvoked = true;
}
);
var listener = new TlsListener((ctx, data) => { tlsClientHelloCallbackInvoked = true; });

await writer.WriteAsync(new byte[1] { 0x16 });
var middlewareTask = middleware.OnTlsClientHelloAsync(transportConnection);
var listenerTask = listener.OnTlsClientHelloAsync(transportConnection, CancellationToken.None);
await writer.WriteAsync(new byte[2] { 0x03, 0x01 });
await writer.WriteAsync(new byte[2] { 0x00, 0x20 });
await writer.CompleteAsync();

await middlewareTask;
Assert.True(nextMiddlewareInvoked);
await listenerTask;
Assert.False(tlsClientHelloCallbackInvoked);

var readResult = await reader.ReadAsync();
Assert.Equal(5, readResult.Buffer.Length);

// ensuring that we have read limited number of times
Assert.True(reader.ReadAsyncCounter is >= 2 && reader.ReadAsyncCounter is <= 4,
$"Expected ReadAsync() to happen about 2-4 times. Actually happened {reader.ReadAsyncCounter} times.");
Expand All @@ -110,23 +148,11 @@ private async Task RunTlsClientHelloCallbackTest_WithMultipleSegments(
var transport = new DuplexPipe(pipe.Reader, writer);
var transportConnection = new DefaultConnectionContext("test", transport, transport);

var nextMiddlewareInvokedActual = false;
var tlsClientHelloCallbackActual = false;

var fullLength = packets.Sum(p => p.Length);

var middleware = new TlsListenerMiddleware(
next: ctx =>
{
nextMiddlewareInvokedActual = true;
if (tlsClientHelloCallbackActual)
{
var readResult = ctx.Transport.Input.ReadAsync();
Assert.Equal(fullLength, readResult.Result.Buffer.Length);
}

return Task.CompletedTask;
},
var listener = new TlsListener(
tlsClientHelloBytesCallback: (ctx, data) =>
{
tlsClientHelloCallbackActual = true;
Expand All @@ -139,9 +165,8 @@ private async Task RunTlsClientHelloCallbackTest_WithMultipleSegments(

// write first packet
await writer.WriteAsync(packets[0]);
var middlewareTask = middleware.OnTlsClientHelloAsync(transportConnection);
var listenerTask = listener.OnTlsClientHelloAsync(transportConnection, CancellationToken.None);


/* It is a race condition (middleware's loop and writes here).
* We don't know specifically how many packets will be read by middleware's loop
* (possibly there are even 2 packets - the first and all others combined).
Expand All @@ -154,10 +179,15 @@ private async Task RunTlsClientHelloCallbackTest_WithMultipleSegments(
await writer.WriteAsync(packet);
}
await writer.CompleteAsync();
await middlewareTask;
await listenerTask;

Assert.True(nextMiddlewareInvokedActual);
Assert.Equal(tlsClientHelloCallbackExpected, tlsClientHelloCallbackActual);

if (tlsClientHelloCallbackActual)
{
var readResult = await pipe.Reader.ReadAsync();
Assert.Equal(fullLength, readResult.Buffer.Length);
}
}

private async Task RunTlsClientHelloCallbackTest(
Expand All @@ -171,18 +201,9 @@ private async Task RunTlsClientHelloCallbackTest(
var transport = new DuplexPipe(pipe.Reader, writer);
var transportConnection = new DefaultConnectionContext("test", transport, transport);

var nextMiddlewareInvokedActual = false;
var tlsClientHelloCallbackActual = false;

var middleware = new TlsListenerMiddleware(
next: ctx =>
{
nextMiddlewareInvokedActual = true;
var readResult = ctx.Transport.Input.ReadAsync();
Assert.Equal(packetBytes.Length, readResult.Result.Buffer.Length);

return Task.CompletedTask;
},
var listener = new TlsListener(
tlsClientHelloBytesCallback: (ctx, data) =>
{
tlsClientHelloCallbackActual = true;
Expand All @@ -197,10 +218,12 @@ private async Task RunTlsClientHelloCallbackTest(
await writer.CompleteAsync();

// call middleware and expect a callback
await middleware.OnTlsClientHelloAsync(transportConnection);
await listener.OnTlsClientHelloAsync(transportConnection, CancellationToken.None);

Assert.True(nextMiddlewareInvokedActual);
Assert.Equal(tlsClientHelloCallbackExpected, tlsClientHelloCallbackActual);

var readResult = await pipe.Reader.ReadAsync();
Assert.Equal(packetBytes.Length, readResult.Buffer.Length);
}

public static IEnumerable<object[]> ValidClientHelloData()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@

namespace InMemory.FunctionalTests;

public class TlsListenerMiddlewareTests : TestApplicationErrorLoggerLoggedTest
public class TlsListenerTests : TestApplicationErrorLoggerLoggedTest
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can probably add a cancellation test here as well. Set the handshake timeout to something small like 1 millisecond and check that the request was canceled.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

added

{
private static readonly X509Certificate2 _x509Certificate2 = TestResources.GetTestCertificate();

Expand Down
Loading