Skip to content
35 changes: 32 additions & 3 deletions src/SignalR/server/Core/src/DefaultHubLifetimeManager.cs
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,18 @@ public override Task AddToGroupAsync(string connectionId, string groupName, Canc
return Task.CompletedTask;
}

_groups.Add(connection, groupName);
// Track groups in the connection object
lock (connection.GroupNames)
{
if (!connection.GroupNames.Add(groupName))
{
// Connection already in group
return Task.CompletedTask;
}

_groups.Add(connection, groupName);
}

// Connection disconnected while adding to group, remove it in case the Add was called after OnDisconnectedAsync removed items from the group
if (connection.ConnectionAborted.IsCancellationRequested)
{
Expand All @@ -64,7 +75,17 @@ public override Task RemoveFromGroupAsync(string connectionId, string groupName,
return Task.CompletedTask;
}

_groups.Remove(connectionId, groupName);
// Remove from previously saved groups
lock (connection.GroupNames)
{
if (!connection.GroupNames.Remove(groupName))
{
// Connection not in group
return Task.CompletedTask;
}

_groups.Remove(connectionId, groupName);
}

return Task.CompletedTask;
}
Expand Down Expand Up @@ -277,8 +298,16 @@ public override Task OnConnectedAsync(HubConnectionContext connection)
/// <inheritdoc />
public override Task OnDisconnectedAsync(HubConnectionContext connection)
{
lock (connection.GroupNames)
{
// Remove from tracked groups one by one
foreach (var groupName in connection.GroupNames)
{
_groups.Remove(connection.ConnectionId, groupName);
}
}

_connections.Remove(connection);
_groups.RemoveDisconnectedConnection(connection.ConnectionId);

return Task.CompletedTask;
}
Expand Down
3 changes: 3 additions & 0 deletions src/SignalR/server/Core/src/HubConnectionContext.cs
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,9 @@ public partial class HubConnectionContext
[MemberNotNullWhen(true, nameof(_messageBuffer))]
internal bool UsingStatefulReconnect() => _useStatefulReconnect;

// Tracks groups that the connection has been added to
internal HashSet<string> GroupNames { get; } = new HashSet<string>();

/// <summary>
/// Initializes a new instance of the <see cref="HubConnectionContext"/> class.
/// </summary>
Expand Down
10 changes: 0 additions & 10 deletions src/SignalR/server/Core/src/Internal/HubGroupList.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@

using System.Collections;
using System.Collections.Concurrent;
using System.Linq;

namespace Microsoft.AspNetCore.SignalR.Internal;

Expand Down Expand Up @@ -43,15 +42,6 @@ public void Remove(string connectionId, string groupName)
}
}

public void RemoveDisconnectedConnection(string connectionId)
{
var groupNames = _groups.Where(x => x.Value.ContainsKey(connectionId)).Select(x => x.Key);
foreach (var groupName in groupNames)
{
Remove(connectionId, groupName);
}
}

public int Count => _groups.Count;

public IEnumerator<ConcurrentDictionary<string, HubConnectionContext>> GetEnumerator()
Expand Down