Skip to content

Commit ba2f74e

Browse files
Fix possible missed completion if cancellation occured before WriteAll.
1 parent ab3697a commit ba2f74e

File tree

6 files changed

+114
-16
lines changed

6 files changed

+114
-16
lines changed

Open.ChannelExtensions.Tests/ExceptionTests.cs

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
using System;
22
using System.Linq;
33
using System.Threading;
4+
using System.Threading.Channels;
45
using System.Threading.Tasks;
56
using Xunit;
67

@@ -71,4 +72,35 @@ await range
7172
Assert.Equal(1, count);
7273
Assert.NotEqual(testSize, total);
7374
}
75+
76+
[Fact]
77+
public static void ChannelClosed()
78+
{
79+
var channel = Channel.CreateBounded<int>(new BoundedChannelOptions(1000)
80+
{
81+
SingleWriter = true,
82+
SingleReader = true,
83+
});
84+
channel.Writer.Complete();
85+
// Needs to throw immediately if true.
86+
Assert.Throws<ChannelClosedException>(() => channel.Writer.WaitToWriteAndThrowIfClosedAsync());
87+
}
88+
89+
[Fact]
90+
public static async Task WriteAllThrowIfClosed()
91+
{
92+
var channel = Channel.CreateBounded<int>(new BoundedChannelOptions(1000)
93+
{
94+
SingleWriter = true,
95+
SingleReader = true,
96+
});
97+
var reader = channel.Source(Enumerable.Range(0, 10_000));
98+
await reader.ReadAll(_ => { });
99+
100+
await Assert.ThrowsAsync<ChannelClosedException>(async ()=>
101+
{
102+
channel.Source(Enumerable.Range(0, 10_000), out var completion);
103+
await completion;
104+
});
105+
}
74106
}
Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
using System;
2+
using System.Linq;
3+
using System.Threading;
4+
using System.Threading.Tasks;
5+
using Xunit;
6+
7+
namespace Open.ChannelExtensions.Tests;
8+
public static class SourceTests
9+
{
10+
[Fact]
11+
public static async Task ToChannelCancelledAfterwriteStarts()
12+
{
13+
var cts = new CancellationTokenSource();
14+
var reader = Enumerable.Range(0, 10_000).ToChannel(10, true, cts.Token);
15+
cts.Cancel();
16+
17+
try
18+
{
19+
await reader.ReadAll(_ => { }, cts.Token);
20+
}
21+
catch (OperationCanceledException)
22+
{ }
23+
24+
await reader.ReadAll(_ => { });
25+
await Assert.ThrowsAsync<TaskCanceledException>(()=>reader.Completion);
26+
}
27+
28+
[Fact]
29+
public static async Task ToChannelCancelledBeforeWriteStarts()
30+
{
31+
var cts = new CancellationTokenSource();
32+
cts.Cancel();
33+
var reader = Enumerable.Range(0, 10_000).ToChannel(10, true, cts.Token);
34+
35+
var count = await reader.ReadAll(_ => { });
36+
Assert.Equal(0, count);
37+
await Assert.ThrowsAsync<TaskCanceledException>(() => reader.Completion);
38+
}
39+
}

Open.ChannelExtensions/Extensions.Source.cs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ public static partial class Extensions
2020
/// <param name="completion">The underlying ValueTask used to pass the data from the source to the channel.</param>
2121
/// <param name="deferredExecution">If true, calls await Task.Yield() before writing to the channel.</param>
2222
/// <param name="cancellationToken">An optional cancellation token.</param>
23+
/// <remarks>Calling this method does not throw if the channel is already closed.</remarks>
2324
/// <returns>The channel reader.</returns>
2425
public static ChannelReader<TRead> SourceAsync<TWrite, TRead>(
2526
this Channel<TWrite, TRead> target,

Open.ChannelExtensions/Extensions.Write.cs

Lines changed: 20 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -34,10 +34,12 @@ public static async ValueTask<long> WriteAllAsync<T>(
3434
if (source is null) throw new ArgumentNullException(nameof(source));
3535
Contract.EndContractBlock();
3636

37-
await target.WaitToWriteAndThrowIfClosedAsync(ChannelClosedMessage, deferredExecution, cancellationToken).ConfigureAwait(false);
38-
3937
try
4038
{
39+
await target
40+
.WaitToWriteAndThrowIfClosedAsync(ChannelClosedMessage, deferredExecution, cancellationToken)
41+
.ConfigureAwait(false);
42+
4143
long count = 0;
4244
var next = new ValueTask();
4345
foreach (ValueTask<T> e in source)
@@ -50,19 +52,20 @@ public static async ValueTask<long> WriteAllAsync<T>(
5052
await next.ConfigureAwait(false);
5153
return count;
5254
}
55+
catch(ChannelClosedException) { throw; }
5356
catch (Exception ex)
5457
{
5558
if (complete)
5659
{
57-
target.Complete(ex);
60+
target.TryComplete(ex);
5861
complete = false;
5962
}
6063
throw;
6164
}
6265
finally
6366
{
6467
if (complete)
65-
target.Complete();
68+
target.TryComplete();
6669
}
6770
}
6871

@@ -246,11 +249,11 @@ public static async ValueTask<long> WriteAllLines(
246249
if (source is null) throw new ArgumentNullException(nameof(source));
247250
Contract.EndContractBlock();
248251

249-
ValueTask next = target.WaitToWriteAndThrowIfClosedAsync(ChannelClosedMessage, deferredExecution, cancellationToken);
250-
await next.ConfigureAwait(false);
251-
252252
try
253253
{
254+
ValueTask next = target.WaitToWriteAndThrowIfClosedAsync(ChannelClosedMessage, deferredExecution, cancellationToken);
255+
await next.ConfigureAwait(false);
256+
254257
long count = 0;
255258
bool more = false; // if it completed and actually returned false, no need to bubble the cancellation since it actually completed.
256259
while (!cancellationToken.IsCancellationRequested)
@@ -275,19 +278,20 @@ public static async ValueTask<long> WriteAllLines(
275278
if (more) cancellationToken.ThrowIfCancellationRequested();
276279
return count;
277280
}
281+
catch (ChannelClosedException) { throw; }
278282
catch (Exception ex)
279283
{
280284
if (complete)
281285
{
282-
target.Complete(ex);
286+
target.TryComplete(ex);
283287
complete = false;
284288
}
285289
throw;
286290
}
287291
finally
288292
{
289293
if (complete)
290-
target.Complete();
294+
target.TryComplete();
291295
}
292296
}
293297

@@ -330,12 +334,12 @@ public static async ValueTask<long> WriteAllAsync<T>(
330334
if (source is null) throw new ArgumentNullException(nameof(source));
331335
Contract.EndContractBlock();
332336

333-
await target
334-
.WaitToWriteAndThrowIfClosedAsync(ChannelClosedMessage, deferredExecution, cancellationToken)
335-
.ConfigureAwait(false);
336-
337337
try
338338
{
339+
await target
340+
.WaitToWriteAndThrowIfClosedAsync(ChannelClosedMessage, deferredExecution, cancellationToken)
341+
.ConfigureAwait(false);
342+
339343
long count = 0;
340344
var next = new ValueTask();
341345
await foreach (T? value in source)
@@ -347,19 +351,20 @@ await target
347351
await next.ConfigureAwait(false);
348352
return count;
349353
}
354+
catch (ChannelClosedException) { throw; }
350355
catch (Exception ex)
351356
{
352357
if (complete)
353358
{
354-
target.Complete(ex);
359+
target.TryComplete(ex);
355360
complete = false;
356361
}
357362
throw;
358363
}
359364
finally
360365
{
361366
if (complete)
362-
target.Complete();
367+
target.TryComplete();
363368
}
364369
}
365370

Open.ChannelExtensions/Extensions._.cs

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,27 @@ internal static IEnumerable<ValueTask<T>> WrapValueTask<T>(this IEnumerable<Func
6666
yield return new ValueTask<T>(e());
6767
}
6868

69+
70+
/// <summary>
71+
/// Uses <see cref="ChannelWriter{T}.WaitToWriteAsync(CancellationToken)"/> to peek and see if the channel can still be written to.
72+
/// </summary>
73+
/// <typeparam name="T">The type being written to the channel</typeparam>
74+
/// <param name="writer">The channel writer.</param>
75+
/// <param name="ifClosedMessage">The message to include with the ChannelClosedException if thrown.</param>
76+
/// <exception cref="ChannelClosedException">If the channel writer will no longer accept messages.</exception>
77+
public static void ThrowIfClosed<T>(this ChannelWriter<T> writer, string? ifClosedMessage = null)
78+
{
79+
if (writer is null) throw new ArgumentNullException(nameof(writer));
80+
ValueTask<bool> waitForWrite = writer.WaitToWriteAsync();
81+
if (!waitForWrite.IsCompletedSuccessfully || waitForWrite.Result)
82+
return;
83+
84+
if (string.IsNullOrWhiteSpace(ifClosedMessage))
85+
throw new ChannelClosedException();
86+
87+
throw new ChannelClosedException(ifClosedMessage);
88+
}
89+
6990
/// <summary>
7091
/// Waits for opportunity to write to a channel and throws a ChannelClosedException if the channel is closed.
7192
/// </summary>

Open.ChannelExtensions/Open.ChannelExtensions.csproj

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
<RepositoryType>git</RepositoryType>
2323
<GeneratePackageOnBuild>true</GeneratePackageOnBuild>
2424
<GenerateDocumentationFile>true</GenerateDocumentationFile>
25-
<Version>6.2.1</Version>
25+
<Version>6.2.2</Version>
2626
<PackageReleaseNotes></PackageReleaseNotes>
2727
<PackageLicenseExpression>MIT</PackageLicenseExpression>
2828
<PublishRepositoryUrl>true</PublishRepositoryUrl>

0 commit comments

Comments
 (0)