Skip to content

Commit 3e902ca

Browse files
authored
WebSocket: observe exceptions in WaitForServerToCloseConnectionAsync (dotnet#114689)
Includes a refactor renaming and moving the methods responsible for logging and observing the exceptions. The renaming is for consistency with `HttpConnectionBase.LogExceptions` in `System.Net.Http`.
1 parent a37502b commit 3e902ca

File tree

4 files changed

+94
-50
lines changed

4 files changed

+94
-50
lines changed

src/libraries/System.Net.WebSockets.Client/tests/CloseTest.cs

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212

1313
using Xunit;
1414
using Xunit.Abstractions;
15+
using Microsoft.DotNet.RemoteExecutor;
1516

1617
namespace System.Net.WebSockets.Client.Tests
1718
{
@@ -523,5 +524,42 @@ await Assert.ThrowsAnyAsync<OperationCanceledException>(async () =>
523524

524525
}), new LoopbackServer.Options { WebSocketEndpoint = true });
525526
}
527+
528+
// Regression test for https://github.com/dotnet/runtime/issues/80116.
529+
[OuterLoop("Uses Task.Delay")]
530+
[ConditionalFact(typeof(RemoteExecutor), nameof(RemoteExecutor.IsSupported))]
531+
public async Task CloseHandshake_ExceptionsAreObserved()
532+
{
533+
await RemoteExecutor.Invoke(static (typeName) =>
534+
{
535+
CloseTest test = (CloseTest)Activator.CreateInstance(typeof(CloseTest).Assembly.GetType(typeName), new object[] { null });
536+
using CancellationTokenSource timeoutCts = new CancellationTokenSource(TimeOutMilliseconds);
537+
538+
Exception unobserved = null;
539+
TaskScheduler.UnobservedTaskException += (obj, args) =>
540+
{
541+
unobserved = args.Exception;
542+
};
543+
544+
TaskCompletionSource clientCompleted = new TaskCompletionSource();
545+
546+
return LoopbackWebSocketServer.RunAsync(async (clientWs, ct) =>
547+
{
548+
await clientWs.CloseOutputAsync(WebSocketCloseStatus.NormalClosure, "", ct);
549+
await clientWs.ReceiveAsync(new byte[16], ct);
550+
await Task.Delay(1500);
551+
GC.Collect(2);
552+
GC.WaitForPendingFinalizers();
553+
clientCompleted.SetResult();
554+
Assert.Null(unobserved);
555+
},
556+
async (serverWs, ct) =>
557+
{
558+
await serverWs.ReceiveAsync(new byte[16], ct);
559+
await serverWs.CloseOutputAsync(WebSocketCloseStatus.NormalClosure, "", ct);
560+
await clientCompleted.Task;
561+
}, new LoopbackWebSocketServer.Options(HttpVersion.Version11, true, test.GetInvoker()), timeoutCts.Token);
562+
}, GetType().FullName).DisposeAsync();
563+
}
526564
}
527565
}

src/libraries/System.Net.WebSockets.Client/tests/System.Net.WebSockets.Client.Tests.csproj

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
<PropertyGroup>
55
<StringResourcesPath>../src/Resources/Strings.resx</StringResourcesPath>
6+
<IncludeRemoteExecutor>true</IncludeRemoteExecutor>
67
<TargetFrameworks>$(NetCoreAppCurrent);$(NetCoreAppCurrent)-browser</TargetFrameworks>
78
<DefineConstants>$(DefineConstants);NETSTANDARD</DefineConstants>
89
</PropertyGroup>
@@ -46,6 +47,7 @@
4647
<Compile Include="$(CommonTestPath)System\Net\Http\HuffmanEncoder.cs" Link="Common\System\Net\Http\HuffmanEncoder.cs" />
4748
<Compile Include="$(CommonTestPath)System\Net\Http\HPackEncoder.cs" Link="Common\System\Net\Http\HPackEncoder.cs" />
4849
<Compile Include="$(CommonTestPath)System\Net\Http\GenericLoopbackServer.cs" Link="Common\System\Net\Http\GenericLoopbackServer.cs" />
50+
<Compile Include="$(CommonTestPath)System\Net\RemoteExecutorExtensions.cs" Link="Common\System\Net\RemoteExecutorExtensions.cs" />
4951
<Compile Include="$(CommonTestPath)System\Threading\Tasks\TaskTimeoutExtensions.cs" Link="Common\System\Threading\Tasks\TaskTimeoutExtensions.cs" />
5052
<Compile Include="AbortTest.cs" />
5153
<Compile Include="AbortTest.Loopback.cs" />

src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.KeepAlive.cs

Lines changed: 2 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ private void UnsolicitedPongHeartBeat()
2929
{
3030
if (NetEventSource.Log.IsEnabled()) NetEventSource.Trace(this);
3131

32-
Observe(
32+
LogExceptions(
3333
TrySendKeepAliveFrameAsync(MessageOpcode.Pong));
3434
}
3535

@@ -98,7 +98,7 @@ private void KeepAlivePingHeartBeat()
9898

9999
if (shouldSendPing)
100100
{
101-
Observe(
101+
LogExceptions(
102102
SendPingAsync(pingPayload));
103103
}
104104
}
@@ -122,52 +122,6 @@ private async ValueTask SendPingAsync(long pingPayload)
122122
if (NetEventSource.Log.IsEnabled()) NetEventSource.KeepAlivePingSent(this, pingPayload);
123123
}
124124

125-
// "Observe" either a ValueTask result, or any exception, ignoring it
126-
// to prevent the unobserved exception event from being raised.
127-
private void Observe(ValueTask t)
128-
{
129-
if (t.IsCompletedSuccessfully)
130-
{
131-
t.GetAwaiter().GetResult();
132-
}
133-
else
134-
{
135-
Observe(t.AsTask());
136-
}
137-
}
138-
139-
// "Observe" any exception, ignoring it to prevent the unobserved task
140-
// exception event from being raised.
141-
private void Observe(Task t)
142-
{
143-
if (t.IsCompleted)
144-
{
145-
if (t.IsFaulted)
146-
{
147-
LogFaulted(t, this);
148-
}
149-
}
150-
else
151-
{
152-
t.ContinueWith(
153-
LogFaulted,
154-
this,
155-
CancellationToken.None,
156-
TaskContinuationOptions.OnlyOnFaulted | TaskContinuationOptions.ExecuteSynchronously,
157-
TaskScheduler.Default);
158-
}
159-
160-
static void LogFaulted(Task task, object? thisObj)
161-
{
162-
Debug.Assert(task.IsFaulted);
163-
164-
// accessing exception to observe it regardless of whether the tracing is enabled
165-
Exception e = task.Exception!.InnerException!;
166-
167-
if (NetEventSource.Log.IsEnabled()) NetEventSource.TraceException(thisObj, e);
168-
}
169-
}
170-
171125
private sealed class KeepAlivePingState
172126
{
173127
internal const int PingPayloadSize = sizeof(long);

src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/ManagedWebSocket.cs

Lines changed: 52 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1144,23 +1144,27 @@ private async ValueTask WaitForServerToCloseConnectionAsync(CancellationToken ca
11441144
// additional data, but at this point we're about to close the connection and we're just stalling
11451145
// to try to get the server to close first.
11461146
ValueTask<int> finalReadTask = _stream.ReadAsync(_receiveBuffer, cancellationToken);
1147+
11471148
if (finalReadTask.IsCompletedSuccessfully)
11481149
{
11491150
finalReadTask.GetAwaiter().GetResult();
11501151
}
11511152
else
11521153
{
11531154
const int WaitForCloseTimeoutMs = 1_000; // arbitrary amount of time to give the server (same duration as .NET Framework)
1155+
Task task = finalReadTask.AsTask();
1156+
11541157
try
11551158
{
11561159
#pragma warning disable CA2016 // Token was already provided to the ReadAsync
1157-
await finalReadTask.AsTask().WaitAsync(TimeSpan.FromMilliseconds(WaitForCloseTimeoutMs)).ConfigureAwait(false);
1160+
await task.WaitAsync(TimeSpan.FromMilliseconds(WaitForCloseTimeoutMs)).ConfigureAwait(false);
11581161
#pragma warning restore CA2016
11591162
}
11601163
catch
11611164
{
1165+
// Eat any resulting exceptions. We were going to close the connection, anyway.
1166+
LogExceptions(task);
11621167
Abort();
1163-
// Eat any resulting exceptions. We were going to close the connection, anyway.
11641168
}
11651169
}
11661170
}
@@ -1851,6 +1855,52 @@ private static bool TryValidateUtf8(ReadOnlySpan<byte> span, bool endOfMessage,
18511855
return !endOfMessage || !state.SequenceInProgress;
18521856
}
18531857

1858+
// "Observe" either a ValueTask result, or any exception, logging and ignoring it
1859+
// to prevent the unobserved exception event from being raised.
1860+
private void LogExceptions(ValueTask t)
1861+
{
1862+
if (t.IsCompletedSuccessfully)
1863+
{
1864+
t.GetAwaiter().GetResult();
1865+
}
1866+
else
1867+
{
1868+
LogExceptions(t.AsTask());
1869+
}
1870+
}
1871+
1872+
// "Observe" and log any exception, ignoring it to prevent the unobserved task
1873+
// exception event from being raised.
1874+
private void LogExceptions(Task t)
1875+
{
1876+
if (t.IsCompleted)
1877+
{
1878+
if (t.IsFaulted)
1879+
{
1880+
LogFaulted(t, this);
1881+
}
1882+
}
1883+
else
1884+
{
1885+
t.ContinueWith(
1886+
LogFaulted,
1887+
this,
1888+
CancellationToken.None,
1889+
TaskContinuationOptions.OnlyOnFaulted | TaskContinuationOptions.ExecuteSynchronously,
1890+
TaskScheduler.Default);
1891+
}
1892+
1893+
static void LogFaulted(Task task, object? thisObj)
1894+
{
1895+
Debug.Assert(task.IsFaulted);
1896+
1897+
// accessing exception to observe it regardless of whether the tracing is enabled
1898+
Exception e = task.Exception!.InnerException!;
1899+
1900+
if (NetEventSource.Log.IsEnabled()) NetEventSource.TraceException(thisObj, e);
1901+
}
1902+
}
1903+
18541904
private sealed class Utf8MessageState
18551905
{
18561906
internal bool SequenceInProgress;

0 commit comments

Comments
 (0)