Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 5 additions & 33 deletions src/Servers/Kestrel/Core/test/TlsListenerTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -70,9 +70,7 @@ public async Task RunTlsClientHelloCallbackTest_WithExtraShortLastingToken()
var cts = new CancellationTokenSource(TimeSpan.FromMilliseconds(3));

await writer.WriteAsync(new byte[1] { 0x16 });
await VerifyThrowsAnyAsync(
async () => await listener.OnTlsClientHelloAsync(transportConnection, cts.Token),
typeof(OperationCanceledException), typeof(TaskCanceledException));
await Assert.ThrowsAnyAsync<OperationCanceledException>(() => listener.OnTlsClientHelloAsync(transportConnection, cts.Token));
Assert.False(tlsClientHelloCallbackInvoked);
}

Expand All @@ -95,9 +93,7 @@ public async Task RunTlsClientHelloCallbackTest_WithPreCanceledToken()
cts.Cancel();

await writer.WriteAsync(new byte[1] { 0x16 });
await VerifyThrowsAnyAsync(
async () => await listener.OnTlsClientHelloAsync(transportConnection, cts.Token),
typeof(OperationCanceledException), typeof(TaskCanceledException));
await Assert.ThrowsAnyAsync<OperationCanceledException>(() => listener.OnTlsClientHelloAsync(transportConnection, cts.Token));
Assert.False(tlsClientHelloCallbackInvoked);
}

Expand All @@ -122,7 +118,7 @@ public async Task RunTlsClientHelloCallbackTest_WithPendingCancellation()
await writer.WriteAsync(new byte[2] { 0x03, 0x01 });
cts.Cancel();

await Assert.ThrowsAsync<OperationCanceledException>(async () => await listenerTask);
await Assert.ThrowsAnyAsync<OperationCanceledException>(() => listenerTask);
Assert.False(tlsClientHelloCallbackInvoked);
}

Expand Down Expand Up @@ -158,8 +154,8 @@ public async Task RunTlsClientHelloCallbackTest_DeterministicallyReads()
Assert.Equal(5, readResult.Buffer.Length);

// ensuring that we have read limited number of times
Assert.True(reader.ReadAsyncCounter is >= 2 && reader.ReadAsyncCounter is <= 4,
$"Expected ReadAsync() to happen about 2-4 times. Actually happened {reader.ReadAsyncCounter} times.");
Assert.True(reader.ReadAsyncCounter is >= 2 && reader.ReadAsyncCounter is <= 5,
$"Expected ReadAsync() to happen about 2-5 times. Actually happened {reader.ReadAsyncCounter} times.");
}

private async Task RunTlsClientHelloCallbackTest_WithMultipleSegments(
Expand Down Expand Up @@ -623,28 +619,4 @@ public static IEnumerable<object[]> InvalidClientHelloData_Segmented()
_invalidTlsClientHelloHeader, _invalid3BytesMessage, _invalid9BytesMessage,
_invalidUnknownProtocolVersion1, _invalidUnknownProtocolVersion2, _invalidIncorrectHandshakeMessageType
};

static async Task VerifyThrowsAnyAsync(Func<Task> code, params Type[] exceptionTypes)
{
if (exceptionTypes == null || exceptionTypes.Length == 0)
{
throw new ArgumentException("At least one exception type must be provided.", nameof(exceptionTypes));
}

try
{
await code();
}
catch (Exception ex)
{
if (exceptionTypes.Any(type => type.IsInstanceOfType(ex)))
{
return;
}

throw ThrowsException.ForIncorrectExceptionType(exceptionTypes.First(), ex);
}

throw ThrowsException.ForNoException(exceptionTypes.First());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Hosting;
using Microsoft.Extensions.Logging;
using Newtonsoft.Json.Linq;
using Xunit.Sdk;

namespace InMemory.FunctionalTests;

Expand Down Expand Up @@ -66,4 +68,98 @@ await sslStream.AuthenticateAsClientAsync(new SslClientAuthenticationOptions

Assert.True(tlsClientHelloCallbackInvoked);
}

[Fact]
public async Task TlsClientHelloBytesCallback_PreCanceledToken()
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There still isn't a test using the HandshakeTimeout

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

added

{
var tlsClientHelloCallbackInvoked = false;

var testContext = new TestServiceContext(LoggerFactory);
await using (var server = new TestServer(context => Task.CompletedTask,
testContext,
listenOptions =>
{
listenOptions.UseHttps(_x509Certificate2, httpsOptions =>
{
httpsOptions.TlsClientHelloBytesCallback = (connection, clientHelloBytes) =>
{
Logger.LogDebug("[Received TlsClientHelloBytesCallback] Connection: {0}; TLS client hello buffer: {1}", connection.ConnectionId, clientHelloBytes.Length);
tlsClientHelloCallbackInvoked = true;
Assert.True(clientHelloBytes.Length > 32);
Assert.NotNull(connection);
};
});
}))
{
using (var connection = server.CreateConnection())
{
using (var sslStream = new SslStream(connection.Stream, false, (sender, cert, chain, errors) => true, null))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think these tests should actually send any bytes over the connection. That creates a race where the timeout doesn't actually occur by the time the tls client hello is received and parsed.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

HttpsConnectionMiddleware does catch the exception, so i am only seeing if the connection is closed after the timeout. I changed to

await connection.TransportConnection.Input.WriteAsync(new byte[] { 0x16 });
var readResult = await connection.TransportConnection.Output.ReadAsync();

// HttpsConnectionMiddleware catches the exception, so we can only check the effects of the timeout here
Assert.True(readResult.IsCompleted);

{
var cancellationTokenSource = new CancellationTokenSource(TimeSpan.FromMilliseconds(1));
var token = cancellationTokenSource.Token;

await Assert.ThrowsAnyAsync<OperationCanceledException>(() => sslStream.AuthenticateAsClientAsync(new SslClientAuthenticationOptions
{
TargetHost = "localhost",
EnabledSslProtocols = SslProtocols.None
}, token));
}
}
}

Assert.False(tlsClientHelloCallbackInvoked);
}

[Fact]
public async Task TlsClientHelloBytesCallback_UsesOptionsTimeout()
{
var testContext = new TestServiceContext(LoggerFactory);
await using (var server = new TestServer(context => Task.CompletedTask,
testContext,
listenOptions =>
{
listenOptions.UseHttps(_x509Certificate2, httpsOptions =>
{
httpsOptions.HandshakeTimeout = TimeSpan.FromMilliseconds(1);

httpsOptions.TlsClientHelloBytesCallback = (connection, clientHelloBytes) =>
{
Logger.LogDebug("[Received TlsClientHelloBytesCallback] Connection: {0}; TLS client hello buffer: {1}", connection.ConnectionId, clientHelloBytes.Length);
Assert.True(clientHelloBytes.Length > 32);
Assert.NotNull(connection);
};
});
}))
{
using (var connection = server.CreateConnection())
{
using (var sslStream = new SslStream(connection.Stream, false, (sender, cert, chain, errors) => true, null))
{
try
{
await sslStream.AuthenticateAsClientAsync(new SslClientAuthenticationOptions
{
TargetHost = "localhost",
EnabledSslProtocols = SslProtocols.None
});

var request = Encoding.ASCII.GetBytes("GET / HTTP/1.1\r\nHost:\r\n\r\n");
await sslStream.WriteAsync(request, 0, request.Length);
await sslStream.ReadAsync(new Memory<byte>(new byte[1024]));
}
catch (Exception ex)
when (ex is OperationCanceledException or TaskCanceledException // when cancellation comes from tls listener
or IOException // when the underlying stream is closed due to timeout
)
{
// expected
}
catch (Exception ex)
{
ThrowsException.ForIncorrectExceptionType(typeof(OperationCanceledException), ex);
}
}
}
}
}
}
Loading