Skip to content

Commit d7a61e9

Browse files
committed
Ensure single reader on channel
Keep a single reader task to prevent concurrent reading operations that would fail. Also adds a display name for channels so they are easier to troubleshoot. Fixes #2
1 parent be99215 commit d7a61e9

File tree

5 files changed

+163
-38
lines changed

5 files changed

+163
-38
lines changed

WebSocketChannel.sln

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@ Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "WebSocketChannel", "src\Web
1313
EndProject
1414
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Tests", "src\Tests\Tests.csproj", "{517F1129-4EA6-46FA-827B-42CF5EB0DE09}"
1515
EndProject
16+
Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Benchmark", "src\Benchmark\Benchmark.csproj", "{694ED796-BC51-4B41-85B0-961E79A424DC}"
17+
EndProject
1618
Global
1719
GlobalSection(SolutionConfigurationPlatforms) = preSolution
1820
Debug|Any CPU = Debug|Any CPU
@@ -27,6 +29,10 @@ Global
2729
{517F1129-4EA6-46FA-827B-42CF5EB0DE09}.Debug|Any CPU.Build.0 = Debug|Any CPU
2830
{517F1129-4EA6-46FA-827B-42CF5EB0DE09}.Release|Any CPU.ActiveCfg = Release|Any CPU
2931
{517F1129-4EA6-46FA-827B-42CF5EB0DE09}.Release|Any CPU.Build.0 = Release|Any CPU
32+
{694ED796-BC51-4B41-85B0-961E79A424DC}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
33+
{694ED796-BC51-4B41-85B0-961E79A424DC}.Debug|Any CPU.Build.0 = Debug|Any CPU
34+
{694ED796-BC51-4B41-85B0-961E79A424DC}.Release|Any CPU.ActiveCfg = Release|Any CPU
35+
{694ED796-BC51-4B41-85B0-961E79A424DC}.Release|Any CPU.Build.0 = Release|Any CPU
3036
EndGlobalSection
3137
GlobalSection(SolutionProperties) = preSolution
3238
HideSolutionNode = FALSE

src/Benchmark/Benchmark.csproj

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
<Project Sdk="Microsoft.NET.Sdk">
2+
3+
<PropertyGroup>
4+
<OutputType>Exe</OutputType>
5+
<TargetFramework>net6.0</TargetFramework>
6+
<ImplicitUsings>true</ImplicitUsings>
7+
<GenerateDocumentationFile>false</GenerateDocumentationFile>
8+
</PropertyGroup>
9+
10+
<ItemGroup>
11+
<PackageReference Include="BenchmarkDotNet" Version="0.13.1" />
12+
<PackageReference Include="BenchmarkDotNet.Diagnostics.Windows" Version="0.13.1" />
13+
</ItemGroup>
14+
15+
<ItemGroup>
16+
<ProjectReference Include="..\Tests\Tests.csproj" />
17+
<ProjectReference Include="..\WebSocketChannel\WebSocketChannel.csproj" />
18+
</ItemGroup>
19+
20+
</Project>

src/Benchmark/Program.cs

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
using System.Net.WebSockets;
2+
using System.Text;
3+
using BenchmarkDotNet.Attributes;
4+
using BenchmarkDotNet.Diagnostics.Windows.Configs;
5+
using BenchmarkDotNet.Running;
6+
using Devlooped.Net;
7+
8+
BenchmarkRunner.Run<Benchmarks>();
9+
10+
[NativeMemoryProfiler]
11+
[MemoryDiagnoser]
12+
public class Benchmarks
13+
{
14+
[Params(1000, 2000, 5000/*, 10000, 20000*/)]
15+
public int RunTime = 1000;
16+
17+
[Benchmark]
18+
public async Task ReadAllBytes()
19+
{
20+
var cts = new CancellationTokenSource(RunTime);
21+
using var server = WebSocketServer.Create();
22+
using var client = new ClientWebSocket();
23+
await client.ConnectAsync(server.Uri, CancellationToken.None);
24+
var channel = client.CreateChannel();
25+
26+
try
27+
{
28+
_ = Task.Run(async () =>
29+
{
30+
var mem = Encoding.UTF8.GetBytes(Guid.NewGuid().ToString()).AsMemory();
31+
while (!cts.IsCancellationRequested)
32+
await channel.Writer.WriteAsync(mem);
33+
34+
await server.DisposeAsync();
35+
});
36+
37+
await foreach (var item in channel.Reader.ReadAllAsync(cts.Token))
38+
{
39+
Console.WriteLine(Encoding.UTF8.GetString(item.Span));
40+
}
41+
}
42+
catch (OperationCanceledException)
43+
{
44+
}
45+
}
46+
}

src/WebSocketChannel/WebSocketChannel.cs

Lines changed: 87 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
1-
using System.Diagnostics.CodeAnalysis;
1+
using System.Diagnostics;
2+
using System.Diagnostics.CodeAnalysis;
23

34
namespace Devlooped.Net;
45

@@ -17,29 +18,41 @@ static partial class WebSocketChannel
1718
/// purposes.
1819
/// </summary>
1920
/// <param name="webSocket">The <see cref="WebSocket"/> to create the channel over.</param>
21+
/// <param name="displayName">Optional friendly name to identify this channel while debugging or troubleshooting.</param>
2022
/// <returns>A channel to read/write the given <paramref name="webSocket"/>.</returns>
21-
public static Channel<ReadOnlyMemory<byte>> Create(WebSocket webSocket)
22-
=> new DefaultWebSocketChannel(webSocket);
23+
public static Channel<ReadOnlyMemory<byte>> Create(WebSocket webSocket, string? displayName = default)
24+
=> new DefaultWebSocketChannel(webSocket, displayName);
2325

2426
class DefaultWebSocketChannel : Channel<ReadOnlyMemory<byte>>
2527
{
26-
static readonly Exception defaultDoneWritting = new Exception(nameof(defaultDoneWritting));
28+
#if DEBUG
29+
static readonly TimeSpan closeTimeout = TimeSpan.FromMilliseconds(
30+
Debugger.IsAttached ? int.MaxValue : 250);
31+
#else
32+
static readonly TimeSpan closeTimeout = TimeSpan.FromMilliseconds(250);
33+
#endif
34+
static readonly Exception defaultDoneWritting = new(nameof(defaultDoneWritting));
2735
static readonly Exception socketClosed = new WebSocketException(WebSocketError.ConnectionClosedPrematurely, "WebSocket was closed by the remote party.");
2836

29-
static readonly TimeSpan closeTimeout = TimeSpan.FromMilliseconds(250);
37+
readonly CancellationTokenSource completionCancellation = new();
3038
readonly TaskCompletionSource<bool> completion = new();
3139
readonly object syncObj = new();
3240
Exception? done;
3341

3442
WebSocket webSocket;
3543

36-
public DefaultWebSocketChannel(WebSocket webSocket)
44+
public DefaultWebSocketChannel(WebSocket webSocket, string? displayName = default)
3745
{
3846
this.webSocket = webSocket;
47+
DisplayName = displayName;
3948
Reader = new WebSocketChannelReader(this);
4049
Writer = new WebSocketChannelWriter(this);
4150
}
4251

52+
public string? DisplayName { get; }
53+
54+
public override string ToString() => DisplayName ?? base.ToString();
55+
4356
void Complete()
4457
{
4558
if (done is OperationCanceledException oce)
@@ -66,6 +79,8 @@ void Complete()
6679
;
6780
}
6881
}
82+
83+
completionCancellation.Cancel();
6984
}
7085

7186
async ValueTask Close(string? description = default)
@@ -76,14 +91,25 @@ async ValueTask Close(string? description = default)
7691
webSocket.CloseOutputAsync(description != null ? WebSocketCloseStatus.InternalServerError : WebSocketCloseStatus.NormalClosure, description, default);
7792

7893
// Don't wait indefinitely for the close to be acknowledged
79-
await Task.WhenAny(closeTask, Task.Delay(closeTimeout));
94+
await Task.WhenAny(closeTask, Task.Delay(closeTimeout)).ConfigureAwait(false);
8095
}
8196

8297
class WebSocketChannelReader : ChannelReader<ReadOnlyMemory<byte>>
8398
{
99+
#if DEBUG
100+
static readonly TimeSpan tryReadTimeout = TimeSpan.FromMilliseconds(
101+
Debugger.IsAttached ? int.MaxValue : 250);
102+
#else
84103
static readonly TimeSpan tryReadTimeout = TimeSpan.FromMilliseconds(250);
104+
#endif
105+
85106
readonly DefaultWebSocketChannel channel;
86-
readonly SemaphoreSlim semaphore = new SemaphoreSlim(1, 1);
107+
readonly object syncObj = new();
108+
109+
readonly SemaphoreSlim semaphore = new(1, 1);
110+
111+
IMemoryOwner<byte>? memoryOwner;
112+
ValueTask<ReadOnlyMemory<byte>>? readingTask = default;
87113

88114
public WebSocketChannelReader(DefaultWebSocketChannel channel) => this.channel = channel;
89115

@@ -113,15 +139,53 @@ public override bool TryRead([MaybeNullWhen(false)] out ReadOnlyMemory<byte> ite
113139
if (channel.webSocket.State != WebSocketState.Open)
114140
return false;
115141

142+
// We keep a singleton ongoing reading task at a time (single reader),
143+
// since that's how the underlying websocket has to be used (no concurrent
144+
// Receive calls should be performed).
145+
if (readingTask == null)
146+
{
147+
lock (syncObj)
148+
{
149+
if (readingTask == null)
150+
readingTask = ReadCoreAsync(channel.completionCancellation.Token);
151+
}
152+
}
153+
154+
// Don't lock the call for more than a small timeout time. This allows
155+
// this method, which is not async and cannot be cancelled to signal that
156+
// it couldn't read within an acceptable timeout. This is important considering
157+
// the ReadAllAsync extension method on ChannelReader<T>, which is implemented
158+
// as follows:
159+
//while (await WaitToReadAsync())
160+
// while (TryRead(out T? item))
161+
// yield return item;
162+
// NOTE: our WaitToReadAsync will continue to return true as long as the
163+
// websocket is open, so the underlying reading task can complete.
116164
var cts = new CancellationTokenSource(tryReadTimeout);
117-
var result = ReadCoreAsync(cts.Token);
118-
while (!result.IsCompleted)
165+
while (readingTask != null && readingTask?.IsCompleted != true && !cts.IsCancellationRequested)
119166
;
120167

121-
if (result.IsCompletedSuccessfully)
122-
item = result.Result;
168+
if (readingTask == null)
169+
return false;
123170

124-
return result.IsCompletedSuccessfully && channel.webSocket.State == WebSocketState.Open;
171+
lock (syncObj)
172+
{
173+
if (readingTask == null)
174+
return false;
175+
176+
if (readingTask.Value.IsCompletedSuccessfully == true &&
177+
readingTask.Value.Result.IsEmpty == false)
178+
{
179+
item = readingTask.Value.Result;
180+
readingTask = null;
181+
return true;
182+
}
183+
184+
if (readingTask.Value.IsCompleted == true)
185+
readingTask = null;
186+
187+
return false;
188+
}
125189
}
126190

127191
public override ValueTask<bool> WaitToReadAsync(CancellationToken cancellationToken = default)
@@ -141,44 +205,32 @@ public override ValueTask<bool> WaitToReadAsync(CancellationToken cancellationTo
141205

142206
async ValueTask<ReadOnlyMemory<byte>> ReadCoreAsync(CancellationToken cancellation)
143207
{
144-
await semaphore.WaitAsync(cancellation);
208+
await semaphore.WaitAsync(cancellation).ConfigureAwait(false);
145209
try
146210
{
147-
using var owner = MemoryPool<byte>.Shared.Rent(512);
148-
var received = await channel.webSocket.ReceiveAsync(owner.Memory, cancellation).ConfigureAwait(false);
211+
memoryOwner?.Dispose();
212+
memoryOwner = MemoryPool<byte>.Shared.Rent(512);
213+
var received = await channel.webSocket.ReceiveAsync(memoryOwner.Memory, cancellation).ConfigureAwait(false);
149214
var count = received.Count;
150215
while (!cancellation.IsCancellationRequested && !received.EndOfMessage && received.MessageType != WebSocketMessageType.Close)
151216
{
152217
if (received.Count == 0)
153218
break;
154219

155-
received = await channel.webSocket.ReceiveAsync(owner.Memory, cancellation).ConfigureAwait(false);
220+
received = await channel.webSocket.ReceiveAsync(memoryOwner.Memory.Slice(count), cancellation).ConfigureAwait(false);
156221
count += received.Count;
157222
}
158223

159224
cancellation.ThrowIfCancellationRequested();
160225

161226
// We didn't get a complete message, we can't flush partial message.
162227
if (received.MessageType == WebSocketMessageType.Close)
163-
{
164-
// Server requested closure.
165-
lock (channel.syncObj)
166-
{
167-
if (channel.done == null)
168-
{
169-
channel.done = socketClosed;
170-
channel.Complete();
171-
}
172-
}
173228
throw socketClosed;
174-
}
175229

176-
// Only return from the whole buffer, the slice of bytes that we actually received.
177-
return owner.Memory.Slice(0, count);
230+
return memoryOwner.Memory.Slice(0, count);
178231
}
179232
// Don't re-throw the expected socketClosed exception we throw when Close received.
180-
catch (Exception ex) when (ex != socketClosed &&
181-
(ex is WebSocketException || ex is InvalidOperationException))
233+
catch (Exception ex) when (ex is WebSocketException || ex is InvalidOperationException)
182234
{
183235
// We consider premature closure just as an explicit closure.
184236
if (ex is WebSocketException wex && wex.WebSocketErrorCode == WebSocketError.ConnectionClosedPrematurely)
@@ -203,7 +255,7 @@ async ValueTask<ReadOnlyMemory<byte>> ReadCoreAsync(CancellationToken cancellati
203255
class WebSocketChannelWriter : ChannelWriter<ReadOnlyMemory<byte>>
204256
{
205257
readonly DefaultWebSocketChannel channel;
206-
readonly SemaphoreSlim semaphore = new SemaphoreSlim(1, 1);
258+
readonly SemaphoreSlim semaphore = new(1, 1);
207259

208260
public WebSocketChannelWriter(DefaultWebSocketChannel channel) => this.channel = channel;
209261

@@ -258,10 +310,10 @@ public override ValueTask<bool> WaitToWriteAsync(CancellationToken cancellationT
258310

259311
async ValueTask WriteAsyncCore(ReadOnlyMemory<byte> item, CancellationToken cancellationToken = default)
260312
{
261-
await semaphore.WaitAsync(cancellationToken);
313+
await semaphore.WaitAsync(cancellationToken).ConfigureAwait(false);
262314
try
263315
{
264-
await channel.webSocket.SendAsync(item, WebSocketMessageType.Binary, true, cancellationToken);
316+
await channel.webSocket.SendAsync(item, WebSocketMessageType.Binary, true, cancellationToken).ConfigureAwait(false);
265317
}
266318
finally
267319
{

src/WebSocketChannel/WebSocketExtensions.cs

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
namespace System.Net.WebSockets;
55

66
/// <summary>
7-
/// Provides the <see cref="CreateChannel(WebSocket)"/> extension method for
7+
/// Provides the <see cref="CreateChannel"/> extension method for
88
/// reading/writing to a <see cref="WebSocket"/> using the <see cref="Channel{T}"/>
99
/// API.
1010
/// </summary>
@@ -16,8 +16,9 @@ static partial class WebSocketExtensions
1616
/// purposes.
1717
/// </summary>
1818
/// <param name="webSocket">The <see cref="WebSocket"/> to create the channel over.</param>
19+
/// <param name="displayName">Optional friendly name to identify this channel while debugging or troubleshooting.</param>
1920
/// <returns>A channel to read/write the given <paramref name="webSocket"/>.</returns>
20-
public static Channel<ReadOnlyMemory<byte>> CreateChannel(this WebSocket webSocket)
21-
=> WebSocketChannel.Create(webSocket);
21+
public static Channel<ReadOnlyMemory<byte>> CreateChannel(this WebSocket webSocket, string? displayName = default)
22+
=> WebSocketChannel.Create(webSocket, displayName);
2223
}
2324

0 commit comments

Comments
 (0)