Skip to content

Commit ae1d9cb

Browse files
committed
Merge in 'release/6.0' changes
2 parents d530c7e + 8a8f51d commit ae1d9cb

File tree

4 files changed

+227
-2
lines changed

4 files changed

+227
-2
lines changed
Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
// Licensed to the .NET Foundation under one or more agreements.
2+
// The .NET Foundation licenses this file to you under the MIT license.
3+
4+
using System.Net.WebSockets;
5+
using Microsoft.AspNetCore.Http;
6+
7+
namespace Microsoft.AspNetCore.WebSockets
8+
{
9+
/// <summary>
10+
/// Used in ASP.NET Core to wrap a WebSocket with its associated HttpContext so that when the WebSocket is aborted
11+
/// the underlying HttpContext is aborted. All other methods are delegated to the underlying WebSocket.
12+
/// </summary>
13+
internal sealed class ServerWebSocket : WebSocket
14+
{
15+
private readonly WebSocket _wrappedSocket;
16+
private readonly HttpContext _context;
17+
18+
internal ServerWebSocket(WebSocket wrappedSocket, HttpContext context)
19+
{
20+
ArgumentNullException.ThrowIfNull(wrappedSocket);
21+
ArgumentNullException.ThrowIfNull(context);
22+
23+
_wrappedSocket = wrappedSocket;
24+
_context = context;
25+
}
26+
27+
public override WebSocketCloseStatus? CloseStatus => _wrappedSocket.CloseStatus;
28+
29+
public override string? CloseStatusDescription => _wrappedSocket.CloseStatusDescription;
30+
31+
public override WebSocketState State => _wrappedSocket.State;
32+
33+
public override string? SubProtocol => _wrappedSocket.SubProtocol;
34+
35+
public override void Abort()
36+
{
37+
_wrappedSocket.Abort();
38+
_context.Abort();
39+
}
40+
41+
public override Task CloseAsync(WebSocketCloseStatus closeStatus, string? statusDescription, CancellationToken cancellationToken)
42+
{
43+
return _wrappedSocket.CloseAsync(closeStatus, statusDescription, cancellationToken);
44+
}
45+
46+
public override Task CloseOutputAsync(WebSocketCloseStatus closeStatus, string? statusDescription, CancellationToken cancellationToken)
47+
{
48+
return _wrappedSocket.CloseOutputAsync(closeStatus, statusDescription, cancellationToken);
49+
}
50+
51+
public override void Dispose()
52+
{
53+
_wrappedSocket.Dispose();
54+
}
55+
56+
public override Task<WebSocketReceiveResult> ReceiveAsync(ArraySegment<byte> buffer, CancellationToken cancellationToken)
57+
{
58+
return _wrappedSocket.ReceiveAsync(buffer, cancellationToken);
59+
}
60+
61+
public override ValueTask<ValueWebSocketReceiveResult> ReceiveAsync(Memory<byte> buffer, CancellationToken cancellationToken)
62+
{
63+
return _wrappedSocket.ReceiveAsync(buffer, cancellationToken);
64+
}
65+
66+
public override ValueTask SendAsync(ReadOnlyMemory<byte> buffer, WebSocketMessageType messageType, bool endOfMessage, CancellationToken cancellationToken)
67+
{
68+
return _wrappedSocket.SendAsync(buffer, messageType, endOfMessage, cancellationToken);
69+
}
70+
71+
public override ValueTask SendAsync(ReadOnlyMemory<byte> buffer, WebSocketMessageType messageType, WebSocketMessageFlags messageFlags, CancellationToken cancellationToken)
72+
{
73+
return _wrappedSocket.SendAsync(buffer, messageType, messageFlags, cancellationToken);
74+
}
75+
76+
public override Task SendAsync(ArraySegment<byte> buffer, WebSocketMessageType messageType, bool endOfMessage, CancellationToken cancellationToken)
77+
{
78+
return _wrappedSocket.SendAsync(buffer, messageType, endOfMessage, cancellationToken);
79+
}
80+
}
81+
}
82+

src/Middleware/WebSockets/src/WebSocketMiddleware.cs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -194,13 +194,15 @@ public async Task<WebSocket> AcceptAsync(WebSocketAcceptContext acceptContext)
194194

195195
Stream opaqueTransport = await _upgradeFeature.UpgradeAsync(); // Sets status code to 101
196196

197-
return WebSocket.CreateFromStream(opaqueTransport, new WebSocketCreationOptions()
197+
var wrappedSocket = WebSocket.CreateFromStream(opaqueTransport, new WebSocketCreationOptions()
198198
{
199199
IsServer = true,
200200
KeepAliveInterval = keepAliveInterval,
201201
SubProtocol = subProtocol,
202202
DangerousDeflateOptions = deflateOptions
203203
});
204+
205+
return new ServerWebSocket(wrappedSocket, _context);
204206
}
205207

206208
public static bool CheckSupportedWebSocketRequest(string method, IHeaderDictionary requestHeaders)

src/Middleware/WebSockets/test/UnitTests/WebSocketMiddlewareTests.cs

Lines changed: 141 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
using System.Net.Http;
77
using System.Net.WebSockets;
88
using System.Text;
9+
using Microsoft.AspNetCore.Connections;
910
using System.Threading;
1011
using System.Threading.Tasks;
1112
using Microsoft.AspNetCore.Testing;
@@ -499,6 +500,146 @@ public async Task CloseFromCloseReceived_Success()
499500
}
500501
}
501502

503+
[Fact]
504+
public async Task WebSocket_Abort_Interrupts_Pending_ReceiveAsync()
505+
{
506+
WebSocket serverSocket = null;
507+
508+
// Events that we want to sequence execution across client and server.
509+
var socketWasAccepted = new ManualResetEventSlim();
510+
var socketWasAborted = new ManualResetEventSlim();
511+
var firstReceiveOccured = new ManualResetEventSlim();
512+
var secondReceiveInitiated = new ManualResetEventSlim();
513+
514+
Exception receiveException = null;
515+
516+
await using (var server = KestrelWebSocketHelpers.CreateServer(LoggerFactory, out var port, async context =>
517+
{
518+
Assert.True(context.WebSockets.IsWebSocketRequest);
519+
serverSocket = await context.WebSockets.AcceptWebSocketAsync();
520+
socketWasAccepted.Set();
521+
522+
var serverBuffer = new byte[1024];
523+
524+
try
525+
{
526+
while (serverSocket.State is WebSocketState.Open or WebSocketState.CloseSent)
527+
{
528+
if (firstReceiveOccured.IsSet)
529+
{
530+
var pendingResponse = serverSocket.ReceiveAsync(serverBuffer, default);
531+
secondReceiveInitiated.Set();
532+
var response = await pendingResponse;
533+
}
534+
else
535+
{
536+
var response = await serverSocket.ReceiveAsync(serverBuffer, default);
537+
firstReceiveOccured.Set();
538+
}
539+
}
540+
}
541+
catch (ConnectionAbortedException ex)
542+
{
543+
socketWasAborted.Set();
544+
receiveException = ex;
545+
}
546+
catch (Exception ex)
547+
{
548+
// Capture this exception so a test failure can give us more information.
549+
receiveException = ex;
550+
}
551+
finally
552+
{
553+
Assert.IsType<ConnectionAbortedException>(receiveException);
554+
}
555+
}))
556+
{
557+
var clientBuffer = new byte[1024];
558+
559+
using (var client = new ClientWebSocket())
560+
{
561+
await client.ConnectAsync(new Uri($"ws://127.0.0.1:{port}/"), CancellationToken.None);
562+
563+
var socketWasAcceptedDidNotTimeout = socketWasAccepted.Wait(10000);
564+
Assert.True(socketWasAcceptedDidNotTimeout, "Socket was not accepted within the allotted time.");
565+
566+
await client.SendAsync(clientBuffer, WebSocketMessageType.Binary, false, default);
567+
568+
var firstReceiveOccuredDidNotTimeout = firstReceiveOccured.Wait(10000);
569+
Assert.True(firstReceiveOccuredDidNotTimeout, "First receive did not occur within the allotted time.");
570+
571+
var secondReceiveInitiatedDidNotTimeout = secondReceiveInitiated.Wait(10000);
572+
Assert.True(secondReceiveInitiatedDidNotTimeout, "Second receive was not initiated within the allotted time.");
573+
574+
serverSocket.Abort();
575+
576+
var socketWasAbortedDidNotTimeout = socketWasAborted.Wait(1000); // Give it a second to process the abort.
577+
Assert.True(socketWasAbortedDidNotTimeout, "Abort did not occur within the allotted time.");
578+
}
579+
}
580+
}
581+
582+
[Fact]
583+
public async Task WebSocket_AllowsCancelling_Pending_ReceiveAsync_When_CancellationTokenProvided()
584+
{
585+
WebSocket serverSocket = null;
586+
CancellationTokenSource cts = new CancellationTokenSource();
587+
588+
var socketWasAccepted = new ManualResetEventSlim();
589+
var operationWasCancelled = new ManualResetEventSlim();
590+
var firstReceiveOccured = new ManualResetEventSlim();
591+
592+
await using (var server = KestrelWebSocketHelpers.CreateServer(LoggerFactory, out var port, async context =>
593+
{
594+
Assert.True(context.WebSockets.IsWebSocketRequest);
595+
serverSocket = await context.WebSockets.AcceptWebSocketAsync();
596+
socketWasAccepted.Set();
597+
598+
var serverBuffer = new byte[1024];
599+
600+
var finishedWithOperationCancelled = false;
601+
602+
try
603+
{
604+
while (serverSocket.State is WebSocketState.Open or WebSocketState.CloseSent)
605+
{
606+
var response = await serverSocket.ReceiveAsync(serverBuffer, cts.Token);
607+
firstReceiveOccured.Set();
608+
}
609+
}
610+
catch (OperationCanceledException)
611+
{
612+
operationWasCancelled.Set();
613+
finishedWithOperationCancelled = true;
614+
}
615+
finally
616+
{
617+
Assert.True(finishedWithOperationCancelled);
618+
}
619+
}))
620+
{
621+
var clientBuffer = new byte[1024];
622+
623+
using (var client = new ClientWebSocket())
624+
{
625+
await client.ConnectAsync(new Uri($"ws://127.0.0.1:{port}/"), CancellationToken.None);
626+
627+
var socketWasAcceptedDidNotTimeout = socketWasAccepted.Wait(10000);
628+
Assert.True(socketWasAcceptedDidNotTimeout, "Socket was not accepted within the allotted time.");
629+
630+
await client.SendAsync(clientBuffer, WebSocketMessageType.Binary, false, default);
631+
632+
var firstReceiveOccuredDidNotTimeout = firstReceiveOccured.Wait(10000);
633+
Assert.True(firstReceiveOccuredDidNotTimeout, "First receive did not occur within the allotted time.");
634+
635+
cts.Cancel();
636+
637+
var operationWasCancelledDidNotTimeout = operationWasCancelled.Wait(1000); // Give it a second to process the abort.
638+
Assert.True(operationWasCancelledDidNotTimeout, "Cancel did not occur within the allotted time.");
639+
}
640+
}
641+
}
642+
502643
[Theory]
503644
[InlineData(HttpStatusCode.OK, null)]
504645
[InlineData(HttpStatusCode.Forbidden, "")]

0 commit comments

Comments
 (0)