Skip to content

Commit a71b4a4

Browse files
authored
Merge pull request #49 from AArnott/fix46
Fix race condition when accepting a channel by name that is canceled
2 parents 6d33d95 + b5688f8 commit a71b4a4

File tree

1 file changed

+60
-43
lines changed

1 file changed

+60
-43
lines changed

src/Nerdbank.Streams/MultiplexingStream.cs

Lines changed: 60 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -426,55 +426,62 @@ public async Task<Channel> AcceptChannelAsync(string name, ChannelOptions option
426426
Requires.NotNull(name, nameof(name));
427427
Verify.NotDisposed(this);
428428

429-
Channel channel = null;
430-
TaskCompletionSource<Channel> pendingAcceptChannel = null;
431-
lock (this.syncObject)
429+
while (true)
432430
{
433-
if (this.channelsOfferedByThemByName.TryGetValue(name, out var channelsOfferedByThem))
431+
Channel channel = null;
432+
TaskCompletionSource<Channel> pendingAcceptChannel = null;
433+
lock (this.syncObject)
434434
{
435-
while (channel == null && channelsOfferedByThem.Count > 0)
435+
if (this.channelsOfferedByThemByName.TryGetValue(name, out var channelsOfferedByThem))
436436
{
437-
channel = channelsOfferedByThem.Dequeue();
438-
if (channel.Acceptance.IsCompleted)
437+
while (channel == null && channelsOfferedByThem.Count > 0)
439438
{
440-
channel = null;
441-
continue;
439+
channel = channelsOfferedByThem.Dequeue();
440+
if (channel.Acceptance.IsCompleted)
441+
{
442+
channel = null;
443+
continue;
444+
}
445+
446+
if (this.TraceSource.Switch.ShouldTrace(TraceEventType.Information))
447+
{
448+
this.TraceSource.TraceEvent(TraceEventType.Information, (int)TraceEventId.AcceptChannelAlreadyOffered, "Accepting channel {1} \"{0}\" which is already offered by the other side.", name, channel.Id);
449+
}
442450
}
451+
}
443452

453+
if (channel == null)
454+
{
444455
if (this.TraceSource.Switch.ShouldTrace(TraceEventType.Information))
445456
{
446-
this.TraceSource.TraceEvent(TraceEventType.Information, (int)TraceEventId.AcceptChannelAlreadyOffered, "Accepting channel {1} \"{0}\" which is already offered by the other side.", name, channel.Id);
457+
this.TraceSource.TraceEvent(TraceEventType.Information, (int)TraceEventId.AcceptChannelWaiting, "Waiting to accept channel \"{0}\", when offered by the other side.", name);
447458
}
459+
460+
if (!this.acceptingChannels.TryGetValue(name, out var acceptingChannels))
461+
{
462+
this.acceptingChannels.Add(name, acceptingChannels = new Queue<TaskCompletionSource<Channel>>());
463+
}
464+
465+
pendingAcceptChannel = new TaskCompletionSource<Channel>(options);
466+
acceptingChannels.Enqueue(pendingAcceptChannel);
448467
}
449468
}
450469

451-
if (channel == null)
470+
if (channel != null)
452471
{
453-
if (this.TraceSource.Switch.ShouldTrace(TraceEventType.Information))
454-
{
455-
this.TraceSource.TraceEvent(TraceEventType.Information, (int)TraceEventId.AcceptChannelWaiting, "Waiting to accept channel \"{0}\", when offered by the other side.", name);
456-
}
457-
458-
if (!this.acceptingChannels.TryGetValue(name, out var acceptingChannels))
472+
// In a race condition with the channel offer being canceled, we may fail to accept the channel.
473+
// In that case, we'll just loop back around and wait for another one.
474+
if (this.TryAcceptChannel(channel, options))
459475
{
460-
this.acceptingChannels.Add(name, acceptingChannels = new Queue<TaskCompletionSource<Channel>>());
476+
return channel;
461477
}
462-
463-
pendingAcceptChannel = new TaskCompletionSource<Channel>(options);
464-
acceptingChannels.Enqueue(pendingAcceptChannel);
465478
}
466-
}
467-
468-
if (channel != null)
469-
{
470-
this.AcceptChannelOrThrow(channel, options);
471-
return channel;
472-
}
473-
else
474-
{
475-
using (cancellationToken.Register(this.AcceptChannelCanceled, Tuple.Create(pendingAcceptChannel, name), false))
479+
else
476480
{
477-
return await pendingAcceptChannel.Task.ConfigureAwait(false);
481+
using (cancellationToken.Register(this.AcceptChannelCanceled, Tuple.Create(pendingAcceptChannel, name), false))
482+
{
483+
return await pendingAcceptChannel.Task.ConfigureAwait(false);
484+
}
478485
}
479486
}
480487
}
@@ -784,25 +791,35 @@ private async ValueTask OnOffer(int channelId, Memory<byte> payloadBuffer, Cance
784791
this.OnChannelOffered(args);
785792
}
786793

787-
private void AcceptChannelOrThrow(Channel channel, ChannelOptions options)
794+
private bool TryAcceptChannel(Channel channel, ChannelOptions options)
788795
{
789796
Requires.NotNull(channel, nameof(channel));
790797

791798
if (channel.TryAcceptOffer(options))
792799
{
793800
this.SendFrame(ControlCode.OfferAccepted, channel.Id);
801+
return true;
794802
}
795-
else if (channel.IsAccepted)
796-
{
797-
throw new InvalidOperationException("Channel is already accepted.");
798-
}
799-
else if (channel.IsRejectedOrCanceled)
800-
{
801-
throw new InvalidOperationException("Channel is no longer available for acceptance.");
802-
}
803-
else
803+
804+
return false;
805+
}
806+
807+
private void AcceptChannelOrThrow(Channel channel, ChannelOptions options)
808+
{
809+
if (!this.TryAcceptChannel(channel, options))
804810
{
805-
throw new InvalidOperationException("Channel could not be accepted.");
811+
if (channel.IsAccepted)
812+
{
813+
throw new InvalidOperationException("Channel is already accepted.");
814+
}
815+
else if (channel.IsRejectedOrCanceled)
816+
{
817+
throw new InvalidOperationException("Channel is no longer available for acceptance.");
818+
}
819+
else
820+
{
821+
throw new InvalidOperationException("Channel could not be accepted.");
822+
}
806823
}
807824
}
808825

0 commit comments

Comments
 (0)