Skip to content

Commit 79a2b3a

Browse files
BrennanConroywtgodbe
authored andcommitted
Merged PR 31898: Avoid Redis pattern matching
Avoid Redis pattern matching
1 parent e49d53e commit 79a2b3a

File tree

5 files changed

+127
-10
lines changed

5 files changed

+127
-10
lines changed

src/SignalR/server/StackExchangeRedis/src/Internal/RedisChannels.cs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
// The .NET Foundation licenses this file to you under the MIT license.
33

44
using System.Runtime.CompilerServices;
5+
using StackExchange.Redis;
56

67
namespace Microsoft.AspNetCore.SignalR.StackExchangeRedis.Internal
78
{
@@ -71,5 +72,7 @@ public string Ack(string serverName)
7172
{
7273
return _prefix + ":internal:ack:" + serverName;
7374
}
75+
76+
public static RedisChannel GetChannel(string channelName) => new RedisChannel(channelName, RedisChannel.PatternMode.Literal);
7477
}
7578
}

src/SignalR/server/StackExchangeRedis/src/RedisHubLifetimeManager.cs

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,7 @@ public override Task OnDisconnectedAsync(HubConnectionContext connection)
114114

115115
var connectionChannel = _channels.Connection(connection.ConnectionId);
116116
RedisLog.Unsubscribe(_logger, connectionChannel);
117-
tasks.Add(_bus!.UnsubscribeAsync(connectionChannel));
117+
tasks.Add(_bus!.UnsubscribeAsync(RedisChannels.GetChannel(connectionChannel)));
118118

119119
var feature = connection.Features.Get<IRedisFeature>()!;
120120
var groupNames = feature.Groups;
@@ -315,7 +315,7 @@ private async Task PublishAsync(string channel, byte[] payload)
315315
{
316316
await EnsureRedisServerConnection();
317317
RedisLog.PublishToChannel(_logger, channel);
318-
await _bus!.PublishAsync(channel, payload);
318+
await _bus!.PublishAsync(RedisChannels.GetChannel(channel), payload);
319319
}
320320

321321
private Task AddGroupAsyncCore(HubConnectionContext connection, string groupName)
@@ -347,7 +347,7 @@ private async Task RemoveGroupAsyncCore(HubConnectionContext connection, string
347347
await _groups.RemoveSubscriptionAsync(groupChannel, connection, channelName =>
348348
{
349349
RedisLog.Unsubscribe(_logger, channelName);
350-
return _bus!.UnsubscribeAsync(channelName);
350+
return _bus!.UnsubscribeAsync(RedisChannels.GetChannel(channelName));
351351
});
352352

353353
var feature = connection.Features.Get<IRedisFeature>()!;
@@ -379,7 +379,7 @@ private Task RemoveUserAsync(HubConnectionContext connection)
379379
return _users.RemoveSubscriptionAsync(userChannel, connection, channelName =>
380380
{
381381
RedisLog.Unsubscribe(_logger, channelName);
382-
return _bus!.UnsubscribeAsync(channelName);
382+
return _bus!.UnsubscribeAsync(RedisChannels.GetChannel(channelName));
383383
});
384384
}
385385

@@ -396,7 +396,7 @@ public void Dispose()
396396
private async Task SubscribeToAll()
397397
{
398398
RedisLog.Subscribing(_logger, _channels.All);
399-
var channel = await _bus!.SubscribeAsync(_channels.All);
399+
var channel = await _bus!.SubscribeAsync(RedisChannels.GetChannel(_channels.All));
400400
channel.OnMessage(async channelMessage =>
401401
{
402402
try
@@ -426,7 +426,7 @@ private async Task SubscribeToAll()
426426

427427
private async Task SubscribeToGroupManagementChannel()
428428
{
429-
var channel = await _bus!.SubscribeAsync(_channels.GroupManagement);
429+
var channel = await _bus!.SubscribeAsync(RedisChannels.GetChannel(_channels.GroupManagement));
430430
channel.OnMessage(async channelMessage =>
431431
{
432432
try
@@ -463,7 +463,7 @@ private async Task SubscribeToGroupManagementChannel()
463463
private async Task SubscribeToAckChannel()
464464
{
465465
// Create server specific channel in order to send an ack to a single server
466-
var channel = await _bus!.SubscribeAsync(_channels.Ack(_serverName));
466+
var channel = await _bus!.SubscribeAsync(RedisChannels.GetChannel(_channels.Ack(_serverName)));
467467
channel.OnMessage(channelMessage =>
468468
{
469469
var ackId = _protocol.ReadAck((byte[])channelMessage.Message);
@@ -477,7 +477,7 @@ private async Task SubscribeToConnection(HubConnectionContext connection)
477477
var connectionChannel = _channels.Connection(connection.ConnectionId);
478478

479479
RedisLog.Subscribing(_logger, connectionChannel);
480-
var channel = await _bus!.SubscribeAsync(connectionChannel);
480+
var channel = await _bus!.SubscribeAsync(RedisChannels.GetChannel(connectionChannel));
481481
channel.OnMessage(channelMessage =>
482482
{
483483
var invocation = _protocol.ReadInvocation((byte[])channelMessage.Message);
@@ -492,7 +492,7 @@ private Task SubscribeToUser(HubConnectionContext connection)
492492
return _users.AddSubscriptionAsync(userChannel, connection, async (channelName, subscriptions) =>
493493
{
494494
RedisLog.Subscribing(_logger, channelName);
495-
var channel = await _bus!.SubscribeAsync(channelName);
495+
var channel = await _bus!.SubscribeAsync(RedisChannels.GetChannel(channelName));
496496
channel.OnMessage(async channelMessage =>
497497
{
498498
try
@@ -518,7 +518,7 @@ private Task SubscribeToUser(HubConnectionContext connection)
518518
private async Task SubscribeToGroupAsync(string groupChannel, HubConnectionStore groupConnections)
519519
{
520520
RedisLog.Subscribing(_logger, groupChannel);
521-
var channel = await _bus!.SubscribeAsync(groupChannel);
521+
var channel = await _bus!.SubscribeAsync(RedisChannels.GetChannel(groupChannel));
522522
channel.OnMessage(async (channelMessage) =>
523523
{
524524
try

src/SignalR/server/StackExchangeRedis/test/RedisEndToEnd.cs

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,73 @@ public async Task CanSendAndReceiveUserMessagesWhenOneConnectionWithUserDisconne
149149
}
150150
}
151151

152+
[ConditionalTheory]
153+
[SkipIfDockerNotPresent]
154+
[MemberData(nameof(TransportTypesAndProtocolTypes))]
155+
public async Task HubConnectionCanSendAndReceiveGroupMessagesGroupNameWithPatternIsTreatedAsLiteral(HttpTransportType transportType, string protocolName)
156+
{
157+
using (StartVerifiableLog())
158+
{
159+
var protocol = HubProtocolHelpers.GetHubProtocol(protocolName);
160+
161+
var connection = CreateConnection(_serverFixture.FirstServer.Url + "/echo", transportType, protocol, LoggerFactory);
162+
var secondConnection = CreateConnection(_serverFixture.SecondServer.Url + "/echo", transportType, protocol, LoggerFactory);
163+
164+
var tcs = new TaskCompletionSource<string>();
165+
connection.On<string>("Echo", message => tcs.TrySetResult(message));
166+
var tcs2 = new TaskCompletionSource<string>();
167+
secondConnection.On<string>("Echo", message => tcs2.TrySetResult(message));
168+
169+
var groupName = $"TestGroup_{transportType}_{protocolName}_{Guid.NewGuid()}";
170+
171+
await secondConnection.StartAsync().DefaultTimeout();
172+
await connection.StartAsync().DefaultTimeout();
173+
await connection.InvokeAsync("AddSelfToGroup", "*").DefaultTimeout();
174+
await secondConnection.InvokeAsync("AddSelfToGroup", groupName).DefaultTimeout();
175+
await connection.InvokeAsync("EchoGroup", groupName, "Hello, World!").DefaultTimeout();
176+
177+
Assert.Equal("Hello, World!", await tcs2.Task.DefaultTimeout());
178+
Assert.False(tcs.Task.IsCompleted);
179+
180+
await connection.InvokeAsync("EchoGroup", "*", "Hello, World!").DefaultTimeout();
181+
Assert.Equal("Hello, World!", await tcs.Task.DefaultTimeout());
182+
183+
await connection.DisposeAsync().DefaultTimeout();
184+
}
185+
}
186+
187+
[ConditionalTheory]
188+
[SkipIfDockerNotPresent]
189+
[MemberData(nameof(TransportTypesAndProtocolTypes))]
190+
public async Task CanSendAndReceiveUserMessagesUserNameWithPatternIsTreatedAsLiteral(HttpTransportType transportType, string protocolName)
191+
{
192+
using (StartVerifiableLog())
193+
{
194+
var protocol = HubProtocolHelpers.GetHubProtocol(protocolName);
195+
196+
var connection = CreateConnection(_serverFixture.FirstServer.Url + "/echo", transportType, protocol, LoggerFactory, userName: "*");
197+
var secondConnection = CreateConnection(_serverFixture.SecondServer.Url + "/echo", transportType, protocol, LoggerFactory, userName: "userA");
198+
199+
var tcs = new TaskCompletionSource<string>();
200+
connection.On<string>("Echo", message => tcs.TrySetResult(message));
201+
var tcs2 = new TaskCompletionSource<string>();
202+
secondConnection.On<string>("Echo", message => tcs2.TrySetResult(message));
203+
204+
await secondConnection.StartAsync().DefaultTimeout();
205+
await connection.StartAsync().DefaultTimeout();
206+
await connection.InvokeAsync("EchoUser", "userA", "Hello, World!").DefaultTimeout();
207+
208+
Assert.Equal("Hello, World!", await tcs2.Task.DefaultTimeout());
209+
Assert.False(tcs.Task.IsCompleted);
210+
211+
await connection.InvokeAsync("EchoUser", "*", "Hello, World!").DefaultTimeout();
212+
Assert.Equal("Hello, World!", await tcs.Task.DefaultTimeout());
213+
214+
await connection.DisposeAsync().DefaultTimeout();
215+
await secondConnection.DisposeAsync().DefaultTimeout();
216+
}
217+
}
218+
152219
private static HubConnection CreateConnection(string url, HttpTransportType transportType, IHubProtocol protocol, ILoggerFactory loggerFactory, string userName = null)
153220
{
154221
var hubConnectionBuilder = new HubConnectionBuilder()

src/SignalR/server/StackExchangeRedis/test/RedisHubLifetimeManagerTests.cs

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,26 @@ public async Task CamelCasedJsonIsPreservedAcrossRedisBoundary()
8383
}
8484
}
8585

86+
// Smoke test that Debug.Asserts in TestSubscriber aren't hit
87+
[Fact]
88+
public async Task PatternGroupAndUser()
89+
{
90+
var server = new TestRedisServer();
91+
using (var client = new TestClient())
92+
{
93+
var manager = CreateLifetimeManager(server);
94+
95+
var connection = HubConnectionContextUtils.Create(client.Connection);
96+
connection.UserIdentifier = "*";
97+
98+
await manager.OnConnectedAsync(connection).DefaultTimeout();
99+
100+
var groupName = "*";
101+
102+
await manager.AddToGroupAsync(connection.ConnectionId, groupName).DefaultTimeout();
103+
}
104+
}
105+
86106
public override TestRedisServer CreateBackplane()
87107
{
88108
return new TestRedisServer();

src/SignalR/server/StackExchangeRedis/test/TestConnectionMultiplexer.cs

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,15 @@
44
using System;
55
using System.Collections.Concurrent;
66
using System.Collections.Generic;
7+
using System.Diagnostics;
78
using System.IO;
89
using System.Net;
910
using System.Reflection;
1011
using System.Threading;
1112
using System.Threading.Tasks;
1213
using StackExchange.Redis;
1314
using StackExchange.Redis.Profiling;
15+
using Xunit;
1416

1517
namespace Microsoft.AspNetCore.SignalR.Tests
1618
{
@@ -230,6 +232,8 @@ public class TestRedisServer
230232

231233
public long Publish(RedisChannel channel, RedisValue message, CommandFlags flags = CommandFlags.None)
232234
{
235+
AssertRedisChannel(channel);
236+
233237
if (_subscriptions.TryGetValue(channel, out var handlers))
234238
{
235239
foreach (var (_, handler) in handlers)
@@ -243,6 +247,8 @@ public long Publish(RedisChannel channel, RedisValue message, CommandFlags flags
243247

244248
public void Subscribe(ChannelMessageQueue messageQueue, int subscriberId, CommandFlags flags = CommandFlags.None)
245249
{
250+
AssertRedisChannel(messageQueue.Channel);
251+
246252
Action<RedisChannel, RedisValue> handler = (channel, value) =>
247253
{
248254
// Workaround for https://github.com/StackExchange/StackExchange.Redis/issues/969
@@ -260,11 +266,20 @@ public void Subscribe(ChannelMessageQueue messageQueue, int subscriberId, Comman
260266

261267
public void Unsubscribe(RedisChannel channel, int subscriberId, CommandFlags flags = CommandFlags.None)
262268
{
269+
AssertRedisChannel(channel);
270+
263271
if (_subscriptions.TryGetValue(channel, out var list))
264272
{
265273
list.RemoveAll((item) => item.Item1 == subscriberId);
266274
}
267275
}
276+
277+
internal static void AssertRedisChannel(RedisChannel channel)
278+
{
279+
var patternField = typeof(RedisChannel).GetField("IsPatternBased", BindingFlags.NonPublic | BindingFlags.Instance);
280+
Assert.NotNull(patternField);
281+
Assert.False((bool)patternField.GetValue(channel));
282+
}
268283
}
269284

270285
public class TestSubscriber : ISubscriber
@@ -310,11 +325,15 @@ public Task<TimeSpan> PingAsync(CommandFlags flags = CommandFlags.None)
310325

311326
public long Publish(RedisChannel channel, RedisValue message, CommandFlags flags = CommandFlags.None)
312327
{
328+
TestRedisServer.AssertRedisChannel(channel);
329+
313330
return _server.Publish(channel, message, flags);
314331
}
315332

316333
public async Task<long> PublishAsync(RedisChannel channel, RedisValue message, CommandFlags flags = CommandFlags.None)
317334
{
335+
TestRedisServer.AssertRedisChannel(channel);
336+
318337
await Task.Yield();
319338
return Publish(channel, message, flags);
320339
}
@@ -326,6 +345,8 @@ public void Subscribe(RedisChannel channel, Action<RedisChannel, RedisValue> han
326345

327346
public Task SubscribeAsync(RedisChannel channel, Action<RedisChannel, RedisValue> handler, CommandFlags flags = CommandFlags.None)
328347
{
348+
TestRedisServer.AssertRedisChannel(channel);
349+
329350
Subscribe(channel, handler, flags);
330351
return Task.CompletedTask;
331352
}
@@ -342,6 +363,8 @@ public bool TryWait(Task task)
342363

343364
public void Unsubscribe(RedisChannel channel, Action<RedisChannel, RedisValue> handler = null, CommandFlags flags = CommandFlags.None)
344365
{
366+
TestRedisServer.AssertRedisChannel(channel);
367+
345368
_server.Unsubscribe(channel, _id, flags);
346369
}
347370

@@ -357,6 +380,8 @@ public Task UnsubscribeAllAsync(CommandFlags flags = CommandFlags.None)
357380

358381
public Task UnsubscribeAsync(RedisChannel channel, Action<RedisChannel, RedisValue> handler = null, CommandFlags flags = CommandFlags.None)
359382
{
383+
TestRedisServer.AssertRedisChannel(channel);
384+
360385
Unsubscribe(channel, handler, flags);
361386
return Task.CompletedTask;
362387
}
@@ -391,6 +416,8 @@ public ChannelMessageQueue Subscribe(RedisChannel channel, CommandFlags flags =
391416

392417
public Task<ChannelMessageQueue> SubscribeAsync(RedisChannel channel, CommandFlags flags = CommandFlags.None)
393418
{
419+
TestRedisServer.AssertRedisChannel(channel);
420+
394421
var t = Subscribe(channel, flags);
395422
return Task.FromResult(t);
396423
}

0 commit comments

Comments
 (0)