Skip to content

Commit 5aebe65

Browse files
committed
add failing test and mitigation for OSS sharded subscribe behavior
1 parent 9ee5f6f commit 5aebe65

File tree

4 files changed

+298
-70
lines changed

4 files changed

+298
-70
lines changed

src/StackExchange.Redis/PhysicalConnection.cs

Lines changed: 128 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,8 @@ internal sealed partial class PhysicalConnection : IDisposable
2929

3030
private const int DefaultRedisDatabaseCount = 16;
3131

32-
private static readonly CommandBytes message = "message", pmessage = "pmessage", smessage = "smessage";
32+
private static readonly CommandBytes message = "message", pmessage = "pmessage", smessage = "smessage",
33+
subscribe = "subscribe", sunsubscribe = "sunsubscribe";
3334

3435
private static readonly Message[] ReusableChangeDatabaseCommands = Enumerable.Range(0, DefaultRedisDatabaseCount).Select(
3536
i => Message.Create(i, CommandFlags.FireAndForget, RedisCommand.SELECT)).ToArray();
@@ -1669,6 +1670,36 @@ internal async ValueTask<bool> ConnectedAsync(Socket? socket, ILogger? log, Sock
16691670
}
16701671
}
16711672

1673+
private enum PushKind
1674+
{
1675+
None,
1676+
Message,
1677+
SMessage,
1678+
PMessage,
1679+
Subscribe,
1680+
SUnsubscribe,
1681+
}
1682+
private static PushKind GetPushKind(in Sequence<RawResult> result)
1683+
{
1684+
var len = result.Length;
1685+
if (len >= 1)
1686+
{
1687+
ref readonly RawResult kind = ref result[0];
1688+
if (len >= 3)
1689+
{
1690+
if (kind.IsEqual(message)) return PushKind.Message;
1691+
if (kind.IsEqual(smessage)) return PushKind.SMessage;
1692+
if (len >= 4)
1693+
{
1694+
if (kind.IsEqual(pmessage)) return PushKind.PMessage;
1695+
}
1696+
if (kind.IsEqual(sunsubscribe)) return PushKind.SUnsubscribe;
1697+
}
1698+
if (kind.IsEqual(subscribe)) return PushKind.Subscribe;
1699+
}
1700+
return PushKind.None;
1701+
}
1702+
16721703
private void MatchResult(in RawResult result)
16731704
{
16741705
// check to see if it could be an out-of-band pubsub message
@@ -1679,85 +1710,121 @@ private void MatchResult(in RawResult result)
16791710

16801711
// out of band message does not match to a queued message
16811712
var items = result.GetItems();
1682-
if (items.Length >= 3 && (items[0].IsEqual(message) || items[0].IsEqual(smessage)))
1713+
var kind = GetPushKind(items);
1714+
switch (kind)
16831715
{
1684-
_readStatus = items[0].IsEqual(message) ? ReadStatus.PubSubMessage : ReadStatus.PubSubSMessage;
1716+
case PushKind.Message:
1717+
case PushKind.SMessage:
1718+
_readStatus = kind is PushKind.Message ? ReadStatus.PubSubMessage : ReadStatus.PubSubSMessage;
16851719

1686-
// special-case the configuration change broadcasts (we don't keep that in the usual pub/sub registry)
1687-
var configChanged = muxer.ConfigurationChangedChannel;
1688-
if (configChanged != null && items[1].IsEqual(configChanged))
1689-
{
1690-
EndPoint? blame = null;
1691-
try
1720+
// special-case the configuration change broadcasts (we don't keep that in the usual pub/sub registry)
1721+
var configChanged = muxer.ConfigurationChangedChannel;
1722+
if (configChanged != null && items[1].IsEqual(configChanged))
16921723
{
1693-
if (!items[2].IsEqual(CommonReplies.wildcard))
1724+
EndPoint? blame = null;
1725+
try
1726+
{
1727+
if (!items[2].IsEqual(CommonReplies.wildcard))
1728+
{
1729+
// We don't want to fail here, just trying to identify
1730+
_ = Format.TryParseEndPoint(items[2].GetString(), out blame);
1731+
}
1732+
}
1733+
catch
16941734
{
1695-
// We don't want to fail here, just trying to identify
1696-
_ = Format.TryParseEndPoint(items[2].GetString(), out blame);
1735+
/* no biggie */
16971736
}
1737+
1738+
Trace("Configuration changed: " + Format.ToString(blame));
1739+
_readStatus = ReadStatus.Reconfigure;
1740+
muxer.ReconfigureIfNeeded(blame, true, "broadcast");
16981741
}
1699-
catch { /* no biggie */ }
1700-
Trace("Configuration changed: " + Format.ToString(blame));
1701-
_readStatus = ReadStatus.Reconfigure;
1702-
muxer.ReconfigureIfNeeded(blame, true, "broadcast");
1703-
}
17041742

1705-
// invoke the handlers
1706-
RedisChannel channel;
1707-
if (items[0].IsEqual(message))
1708-
{
1709-
channel = items[1].AsRedisChannel(ChannelPrefix, RedisChannel.RedisChannelOptions.None);
1710-
Trace("MESSAGE: " + channel);
1711-
}
1712-
else // see check on outer-if that restricts to message / smessage
1713-
{
1714-
channel = items[1].AsRedisChannel(ChannelPrefix, RedisChannel.RedisChannelOptions.Sharded);
1715-
Trace("SMESSAGE: " + channel);
1716-
}
1717-
if (!channel.IsNull)
1718-
{
1719-
if (TryGetPubSubPayload(items[2], out var payload))
1743+
// invoke the handlers
1744+
RedisChannel channel;
1745+
if (items[0].IsEqual(message))
17201746
{
1721-
_readStatus = ReadStatus.InvokePubSub;
1722-
muxer.OnMessage(channel, channel, payload);
1747+
channel = items[1].AsRedisChannel(ChannelPrefix, RedisChannel.RedisChannelOptions.None);
1748+
Trace("MESSAGE: " + channel);
17231749
}
1724-
// could be multi-message: https://github.com/StackExchange/StackExchange.Redis/issues/2507
1725-
else if (TryGetMultiPubSubPayload(items[2], out var payloads))
1750+
else // the enclosing case labels restrict this branch to message / smessage
17261751
{
1727-
_readStatus = ReadStatus.InvokePubSub;
1728-
muxer.OnMessage(channel, channel, payloads);
1752+
channel = items[1].AsRedisChannel(ChannelPrefix, RedisChannel.RedisChannelOptions.Sharded);
1753+
Trace("SMESSAGE: " + channel);
1754+
}
1755+
1756+
if (!channel.IsNull)
1757+
{
1758+
if (TryGetPubSubPayload(items[2], out var payload))
1759+
{
1760+
_readStatus = ReadStatus.InvokePubSub;
1761+
muxer.OnMessage(channel, channel, payload);
1762+
}
1763+
// could be multi-message: https://github.com/StackExchange/StackExchange.Redis/issues/2507
1764+
else if (TryGetMultiPubSubPayload(items[2], out var payloads))
1765+
{
1766+
_readStatus = ReadStatus.InvokePubSub;
1767+
muxer.OnMessage(channel, channel, payloads);
1768+
}
17291769
}
1730-
}
1731-
return; // AND STOP PROCESSING!
1732-
}
1733-
else if (items.Length >= 4 && items[0].IsEqual(pmessage))
1734-
{
1735-
_readStatus = ReadStatus.PubSubPMessage;
17361770

1737-
var channel = items[2].AsRedisChannel(ChannelPrefix, RedisChannel.RedisChannelOptions.Pattern);
1771+
return; // AND STOP PROCESSING!
1772+
case PushKind.PMessage:
1773+
_readStatus = ReadStatus.PubSubPMessage;
17381774

1739-
Trace("PMESSAGE: " + channel);
1740-
if (!channel.IsNull)
1741-
{
1742-
if (TryGetPubSubPayload(items[3], out var payload))
1775+
channel = items[2].AsRedisChannel(ChannelPrefix, RedisChannel.RedisChannelOptions.Pattern);
1776+
1777+
Trace("PMESSAGE: " + channel);
1778+
if (!channel.IsNull)
17431779
{
1744-
var sub = items[1].AsRedisChannel(ChannelPrefix, RedisChannel.RedisChannelOptions.Pattern);
1780+
if (TryGetPubSubPayload(items[3], out var payload))
1781+
{
1782+
var sub = items[1].AsRedisChannel(
1783+
ChannelPrefix,
1784+
RedisChannel.RedisChannelOptions.Pattern);
17451785

1746-
_readStatus = ReadStatus.InvokePubSub;
1747-
muxer.OnMessage(sub, channel, payload);
1786+
_readStatus = ReadStatus.InvokePubSub;
1787+
muxer.OnMessage(sub, channel, payload);
1788+
}
1789+
else if (TryGetMultiPubSubPayload(items[3], out var payloads))
1790+
{
1791+
var sub = items[1].AsRedisChannel(
1792+
ChannelPrefix,
1793+
RedisChannel.RedisChannelOptions.Pattern);
1794+
1795+
_readStatus = ReadStatus.InvokePubSub;
1796+
muxer.OnMessage(sub, channel, payloads);
1797+
}
17481798
}
1749-
else if (TryGetMultiPubSubPayload(items[3], out var payloads))
1750-
{
1751-
var sub = items[1].AsRedisChannel(ChannelPrefix, RedisChannel.RedisChannelOptions.Pattern);
17521799

1753-
_readStatus = ReadStatus.InvokePubSub;
1754-
muxer.OnMessage(sub, channel, payloads);
1800+
break;
1801+
case PushKind.SUnsubscribe:
1802+
_readStatus = ReadStatus.PubSubSUnsubscribe;
1803+
channel = items[1].AsRedisChannel(ChannelPrefix, RedisChannel.RedisChannelOptions.Sharded);
1804+
var server = BridgeCouldBeNull?.ServerEndPoint;
1805+
if (server is not null && muxer.TryGetSubscription(channel, out var subscription))
1806+
{
1807+
if (subscription.GetCurrentServer() == server)
1808+
{
1809+
// definitely isn't this connection any more, but we were listening
1810+
subscription.SetCurrentServer(null);
1811+
muxer.ReconfigureIfNeeded(server.EndPoint, fromBroadcast: true, nameof(sunsubscribe));
1812+
}
17551813
}
1756-
}
1757-
return; // AND STOP PROCESSING!
1814+
break;
17581815
}
17591816

1760-
// if it didn't look like "[p|s]message", then we still need to process the pending queue
1817+
switch (kind)
1818+
{
1819+
// we recognized it as a RESP2 OOB, or it was explicitly *any* RESP3 push notification
1820+
// (even if we didn't recognize the kind) - we're done; unless it is "subscribe", which
1821+
// is *technically* a push, but we still want to treat it as a response to the original message
1822+
case PushKind.None when result.Resp3Type != ResultType.Push:
1823+
case PushKind.Subscribe:
1824+
break; // continue, try to match to a pending message
1825+
default:
1826+
return; // we're done with this message (RESP3 OOB, or something we recognized)
1827+
}
17611828
}
17621829
Trace("Matching result...");
17631830

@@ -2168,6 +2235,7 @@ internal enum ReadStatus
21682235
MatchResultComplete,
21692236
ResetArena,
21702237
ProcessBufferComplete,
2238+
PubSubSUnsubscribe,
21712239
NA = -1,
21722240
}
21732241
private volatile ReadStatus _readStatus;

src/StackExchange.Redis/RedisSubscriber.cs

Lines changed: 18 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -145,6 +145,14 @@ internal long EnsureSubscriptions(CommandFlags flags = CommandFlags.None)
145145
return count;
146146
}
147147

148+
internal void EnsureSubscription(Subscription sub, in RedisChannel channel, CommandFlags flags)
149+
{
150+
if (!sub.IsConnected)
151+
{
152+
DefaultSubscriber.EnsureSubscribedToServer(sub, channel, flags, true);
153+
}
154+
}
155+
148156
internal enum SubscriptionAction
149157
{
150158
Subscribe,
@@ -404,7 +412,7 @@ public ChannelMessageQueue Subscribe(RedisChannel channel, CommandFlags flags =
404412
return queue;
405413
}
406414

407-
public bool Subscribe(RedisChannel channel, Action<RedisChannel, RedisValue>? handler, ChannelMessageQueue? queue, CommandFlags flags)
415+
private bool Subscribe(RedisChannel channel, Action<RedisChannel, RedisValue>? handler, ChannelMessageQueue? queue, CommandFlags flags)
408416
{
409417
ThrowIfNull(channel);
410418
if (handler == null && queue == null) { return true; }
@@ -428,32 +436,34 @@ internal bool EnsureSubscribedToServer(Subscription sub, RedisChannel channel, C
428436
Task ISubscriber.SubscribeAsync(RedisChannel channel, Action<RedisChannel, RedisValue> handler, CommandFlags flags)
429437
=> SubscribeAsync(channel, handler, null, flags);
430438

431-
public async Task<ChannelMessageQueue> SubscribeAsync(RedisChannel channel, CommandFlags flags = CommandFlags.None)
439+
Task<ChannelMessageQueue> ISubscriber.SubscribeAsync(RedisChannel channel, CommandFlags flags) => SubscribeAsync(channel, flags);
440+
441+
public async Task<ChannelMessageQueue> SubscribeAsync(RedisChannel channel, CommandFlags flags = CommandFlags.None, ServerEndPoint? server = null)
432442
{
433443
var queue = new ChannelMessageQueue(channel, this);
434-
await SubscribeAsync(channel, null, queue, flags).ForAwait();
444+
await SubscribeAsync(channel, null, queue, flags, server).ForAwait();
435445
return queue;
436446
}
437447

438-
public Task<bool> SubscribeAsync(RedisChannel channel, Action<RedisChannel, RedisValue>? handler, ChannelMessageQueue? queue, CommandFlags flags)
448+
private Task<bool> SubscribeAsync(RedisChannel channel, Action<RedisChannel, RedisValue>? handler, ChannelMessageQueue? queue, CommandFlags flags, ServerEndPoint? server = null)
439449
{
440450
ThrowIfNull(channel);
441451
if (handler == null && queue == null) { return CompletedTask<bool>.Default(null); }
442452

443453
var sub = multiplexer.GetOrAddSubscription(channel, flags);
444454
sub.Add(handler, queue);
445-
return EnsureSubscribedToServerAsync(sub, channel, flags, false);
455+
return EnsureSubscribedToServerAsync(sub, channel, flags, false, server);
446456
}
447457

448-
public Task<bool> EnsureSubscribedToServerAsync(Subscription sub, RedisChannel channel, CommandFlags flags, bool internalCall)
458+
public Task<bool> EnsureSubscribedToServerAsync(Subscription sub, RedisChannel channel, CommandFlags flags, bool internalCall, ServerEndPoint? server = null)
449459
{
450460
if (sub.IsConnected) { return CompletedTask<bool>.Default(null); }
451461

452462
// TODO: Cleanup old hangers here?
453463
sub.SetCurrentServer(null); // we're not appropriately connected, so blank it out for eligible reconnection
454464
var message = sub.GetMessage(channel, SubscriptionAction.Subscribe, flags, internalCall);
455-
var selected = multiplexer.SelectServer(message);
456-
return ExecuteAsync(message, sub.Processor, selected);
465+
server ??= multiplexer.SelectServer(message);
466+
return ExecuteAsync(message, sub.Processor, server);
457467
}
458468

459469
public EndPoint? SubscribedEndpoint(RedisChannel channel) => multiplexer.GetSubscribedServer(channel)?.EndPoint;

src/StackExchange.Redis/ServerSelectionStrategy.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -328,7 +328,7 @@ private ServerEndPoint[] MapForMutation()
328328
return arr;
329329
}
330330

331-
private ServerEndPoint? Select(int slot, RedisCommand command, CommandFlags flags, bool allowDisconnected)
331+
internal ServerEndPoint? Select(int slot, RedisCommand command, CommandFlags flags, bool allowDisconnected)
332332
{
333333
// Only interested in primary/replica preferences
334334
flags = Message.GetPrimaryReplicaFlags(flags);

0 commit comments

Comments
 (0)