Skip to content

Commit f36ad68

Browse files
authored
fix(server): Use correct messages for sharded pubsub (#5818)
1 parent 4e49c90 commit f36ad68

16 files changed

+141
-139
lines changed

src/facade/command_id.h

Lines changed: 0 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -88,21 +88,6 @@ class CommandId {
8888
opt_mask_ |= flag;
8989
}
9090

91-
// PUBLISH/SUBSCRIBE/UNSUBSCRIBE variant
92-
bool IsPubSub() const {
93-
return is_pub_sub_;
94-
}
95-
96-
// PSUBSCRIBE/PUNSUBSCRIBE variant
97-
bool IsPSub() const {
98-
return is_p_pub_sub_;
99-
}
100-
101-
// SSUBSCRIBE/SUNSUBSCRIBE variant
102-
bool IsShardedPSub() const {
103-
return is_sharded_pub_sub_;
104-
}
105-
10691
protected:
10792
std::string name_;
10893

@@ -119,10 +104,6 @@ class CommandId {
119104

120105
// Whether the command can only be used by admin connections.
121106
bool restricted_ = false;
122-
123-
bool is_pub_sub_ = false;
124-
bool is_sharded_pub_sub_ = false;
125-
bool is_p_pub_sub_ = false;
126107
};
127108

128109
} // namespace facade

src/facade/dragonfly_connection.cc

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -519,7 +519,7 @@ void Connection::AsyncOperations::operator()(const AclUpdateMessage& msg) {
519519
}
520520

521521
void Connection::AsyncOperations::operator()(const PubMessage& pub_msg) {
522-
RedisReplyBuilder* rbuilder = (RedisReplyBuilder*)builder;
522+
RedisReplyBuilder* rb = static_cast<RedisReplyBuilder*>(builder);
523523

524524
// Discard stale messages to not break the protocol after exiting "pubsub" mode.
525525
// Even after removing all subscriptions, we still can receive messages delayed
@@ -529,20 +529,19 @@ void Connection::AsyncOperations::operator()(const PubMessage& pub_msg) {
529529
!base::_in(pub_msg.channel, {"unsubscribe", "punsubscribe"}))
530530
return;
531531

532-
if (pub_msg.should_unsubscribe) {
533-
rbuilder->StartCollection(3, RedisReplyBuilder::CollectionType::PUSH);
534-
rbuilder->SendBulkString("unsubscribe");
535-
rbuilder->SendBulkString(pub_msg.channel);
536-
rbuilder->SendLong(0);
537-
auto* cntx = self->cntx();
538-
cntx->Unsubscribe(pub_msg.channel);
532+
if (pub_msg.force_unsubscribe) {
533+
rb->StartCollection(3, RedisReplyBuilder::CollectionType::PUSH);
534+
rb->SendBulkString("sunsubscribe");
535+
rb->SendBulkString(pub_msg.channel);
536+
rb->SendLong(0);
537+
self->cntx()->Unsubscribe(pub_msg.channel);
539538
return;
540539
}
541540

542541
unsigned i = 0;
543542
array<string_view, 4> arr;
544543
if (pub_msg.pattern.empty()) {
545-
arr[i++] = "message";
544+
arr[i++] = pub_msg.is_sharded ? "smessage" : "message";
546545
} else {
547546
arr[i++] = "pmessage";
548547
arr[i++] = pub_msg.pattern;
@@ -551,8 +550,8 @@ void Connection::AsyncOperations::operator()(const PubMessage& pub_msg) {
551550
arr[i++] = pub_msg.channel;
552551
arr[i++] = pub_msg.message;
553552

554-
rbuilder->SendBulkStrArr(absl::Span<string_view>{arr.data(), i},
555-
RedisReplyBuilder::CollectionType::PUSH);
553+
rb->SendBulkStrArr(absl::Span<string_view>{arr.data(), i},
554+
RedisReplyBuilder::CollectionType::PUSH);
556555
}
557556

558557
void Connection::AsyncOperations::operator()(Connection::PipelineMessage& msg) {

src/facade/dragonfly_connection.h

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,10 @@ class Connection : public util::Connection {
7373
std::string pattern; // non-empty for pattern subscriber
7474
std::shared_ptr<char[]> buf; // stores channel name and message
7575
std::string_view channel, message; // channel and message parts from buf
76-
bool should_unsubscribe = false; // unsubscribe from channel after sending the message
76+
bool is_sharded = false;
77+
78+
// Unsubscribe simultaneously when sending unsubscribe message. Used for cluster migrations
79+
bool force_unsubscribe = false;
7780
};
7881

7982
// Pipeline message, accumulated Redis command to be executed.

src/facade/facade.cc

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -138,13 +138,6 @@ CommandId::CommandId(const char* name, uint32_t mask, int8_t arity, int8_t first
138138
first_key_(first_key),
139139
last_key_(last_key),
140140
acl_categories_(acl_categories) {
141-
if (name_ == "PUBLISH" || name_ == "SUBSCRIBE" || name_ == "UNSUBSCRIBE") {
142-
is_pub_sub_ = true;
143-
} else if (name_ == "PSUBSCRIBE" || name_ == "PUNSUBSCRIBE") {
144-
is_p_pub_sub_ = true;
145-
} else if (name_ == "SPUBLISH" || name_ == "SSUBSCRIBE" || name_ == "SUNSUBSCRIBE") {
146-
is_sharded_pub_sub_ = true;
147-
}
148141
}
149142

150143
static bool ParseHumanReadableBytes(std::string_view str, int64_t* num_bytes) {

src/server/acl/validator.cc

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -78,10 +78,10 @@ bool ValidateCommand(const std::vector<uint64_t>& acl_commands, const CommandId&
7878

7979
std::pair<bool, AclLog::Reason> auth_res;
8080

81-
if (id.IsPubSub() || id.IsShardedPSub()) {
82-
auth_res = IsPubSubCommandAuthorized(false, cntx.acl_commands, cntx.pub_sub, tail_args, id);
83-
} else if (id.IsPSub()) {
84-
auth_res = IsPubSubCommandAuthorized(true, cntx.acl_commands, cntx.pub_sub, tail_args, id);
81+
if (auto pkind = id.PubSubKind(); pkind) {
82+
bool is_pattern = *pkind == CO::PubSubKind::PATTERN;
83+
auth_res =
84+
IsPubSubCommandAuthorized(is_pattern, cntx.acl_commands, cntx.pub_sub, tail_args, id);
8585
} else {
8686
auth_res = IsUserAllowedToInvokeCommandGeneric(cntx, id, tail_args);
8787
}

src/server/channel_store.cc

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,8 @@ using namespace std;
1919
namespace {
2020

2121
// Build functor for sending messages to connection
22-
auto BuildSender(string_view channel, facade::ArgRange messages, bool unsubscribe = false) {
22+
auto BuildSender(string_view channel, facade::ArgRange messages, bool sharded = false,
23+
bool unsubscribe = false) {
2324
absl::FixedArray<string_view, 1> views(messages.Size());
2425
size_t messages_size = accumulate(messages.begin(), messages.end(), 0,
2526
[](int sum, string_view str) { return sum + str.size(); });
@@ -36,11 +37,12 @@ auto BuildSender(string_view channel, facade::ArgRange messages, bool unsubscrib
3637
}
3738
}
3839

39-
return [channel, buf = std::move(buf), views = std::move(views), unsubscribe](
40+
return [channel, buf = std::move(buf), views = std::move(views), sharded, unsubscribe](
4041
facade::Connection* conn, string pattern) {
4142
string_view channel_view{buf.get(), channel.size()};
4243
for (std::string_view message_view : views) {
43-
conn->SendPubMessageAsync({std::move(pattern), buf, channel_view, message_view, unsubscribe});
44+
conn->SendPubMessageAsync(
45+
{std::move(pattern), buf, channel_view, message_view, sharded, unsubscribe});
4446
}
4547
};
4648
}
@@ -117,7 +119,8 @@ void ChannelStore::Destroy() {
117119

118120
ChannelStore::ControlBlock ChannelStore::control_block;
119121

120-
unsigned ChannelStore::SendMessages(std::string_view channel, facade::ArgRange messages) const {
122+
unsigned ChannelStore::SendMessages(std::string_view channel, facade::ArgRange messages,
123+
bool sharded) const {
121124
vector<Subscriber> subscribers = FetchSubscribers(channel);
122125
if (subscribers.empty())
123126
return 0;
@@ -144,7 +147,7 @@ unsigned ChannelStore::SendMessages(std::string_view channel, facade::ArgRange m
144147
}
145148

146149
auto subscribers_ptr = make_shared<decltype(subscribers)>(std::move(subscribers));
147-
auto cb = [subscribers_ptr, send = BuildSender(channel, messages)](unsigned idx, auto*) {
150+
auto cb = [subscribers_ptr, send = BuildSender(channel, messages, sharded)](unsigned idx, auto*) {
148151
auto it = lower_bound(subscribers_ptr->begin(), subscribers_ptr->end(), idx,
149152
ChannelStore::Subscriber::ByThreadId);
150153
while (it != subscribers_ptr->end() && it->LastKnownThreadId() == idx) {
@@ -217,13 +220,14 @@ void ChannelStore::UnsubscribeAfterClusterSlotMigration(const cluster::SlotSet&
217220
csu.ApplyAndUnsubscribe();
218221
}
219222

223+
// TODO: Reuse common code with Send function
224+
// TODO: Find proper solution to hacky `force_unsubscribe` flag or at least move logic out of io
220225
void ChannelStore::UnsubscribeConnectionsFromDeletedSlots(const ChannelsSubMap& sub_map,
221226
uint32_t idx) {
222-
const bool should_unsubscribe = true;
223227
for (const auto& [channel, subscribers] : sub_map) {
224228
// ignored by pub sub handler because should_unsubscribe is true
225229
std::string msg = "__ignore__";
226-
auto send = BuildSender(channel, {facade::ArgSlice{msg}}, should_unsubscribe);
230+
auto send = BuildSender(channel, {facade::ArgSlice{msg}}, false, true);
227231

228232
auto it = lower_bound(subscribers.begin(), subscribers.end(), idx,
229233
ChannelStore::Subscriber::ByThreadId);

src/server/channel_store.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ class ChannelStore {
5959
ChannelStore();
6060

6161
// Send messages to channel, block on connection backpressure
62-
unsigned SendMessages(std::string_view channel, facade::ArgRange messages) const;
62+
unsigned SendMessages(std::string_view channel, facade::ArgRange messages, bool sharded) const;
6363

6464
// Fetch all subscribers for channel, including matching patterns.
6565
std::vector<Subscriber> FetchSubscribers(std::string_view channel) const;

src/server/cluster/cluster_family_test.cc

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -670,6 +670,34 @@ TEST_F(ClusterFamilyTest, ClusterModePubSubNotAllowed) {
670670
ErrArg("PUNSUBSCRIBE is not supported in cluster mode yet"));
671671
}
672672

673+
// SSUBSCRIBE and SPUBLISH work in cluster mode
674+
TEST_F(ClusterFamilyTest, ClusterModePubSub) {
675+
single_response_ = false;
676+
ConfigSingleNodeCluster(GetMyId());
677+
678+
// Ssubscribe works as expected
679+
auto resp = pp_->at(1)->Await([&] { return Run({"SSUBSCRIBE", "cluster-channel"}); });
680+
EXPECT_THAT(resp, RespElementsAre("ssubscribe", "cluster-channel", IntArg(1)));
681+
682+
// Send-receive a single message
683+
resp = pp_->at(0)->Await([&] {
684+
return Run({"SPUBLISH", "cluster-channel", "a simple message"});
685+
});
686+
EXPECT_THAT(resp, IntArg(1));
687+
688+
pp_->AwaitFiberOnAll([](util::ProactorBase* pb) {});
689+
690+
ASSERT_EQ(1, SubscriberMessagesLen("IO1"));
691+
const auto& msg = GetPublishedMessage("IO1", 0);
692+
EXPECT_TRUE(msg.is_sharded);
693+
EXPECT_EQ("cluster-channel", msg.channel);
694+
EXPECT_EQ("a simple message", msg.message);
695+
696+
// Sunsubscribe
697+
resp = pp_->at(1)->Await([&] { return Run({"SUNSUBSCRIBE", "cluster-channel"}); });
698+
EXPECT_THAT(resp, RespElementsAre("sunsubscribe", "cluster-channel", IntArg(0)));
699+
}
700+
673701
TEST_F(ClusterFamilyTest, ClusterFirstConfigCallDropsEntriesNotOwnedByNode) {
674702
InitWithDbFilename();
675703

src/server/command_registry.cc

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,15 +4,16 @@
44

55
#include "server/command_registry.h"
66

7+
#include <absl/container/inlined_vector.h>
8+
#include <absl/strings/match.h>
9+
#include <absl/strings/str_cat.h>
710
#include <absl/strings/str_split.h>
811
#include <absl/time/clock.h>
912

10-
#include "absl/container/inlined_vector.h"
11-
#include "absl/strings/match.h"
12-
#include "absl/strings/str_cat.h"
1313
#include "base/bits.h"
1414
#include "base/flags.h"
1515
#include "base/logging.h"
16+
#include "base/stl_util.h"
1617
#include "facade/dragonfly_connection.h"
1718
#include "facade/error.h"
1819
#include "server/acl/acl_commands_def.h"
@@ -137,6 +138,17 @@ CommandId::CommandId(const char* name, uint32_t mask, int8_t arity, int8_t first
137138
kLatencyHistogramPrecision, &hist);
138139
CHECK_EQ(init_result, 0) << "failed to initialize histogram for command " << name;
139140
latency_histogram_ = hist;
141+
142+
if (name_.rfind("EVAL", 0) == 0)
143+
kind_multi_ctr_ = CO::MultiControlKind::EVAL;
144+
else if (base::_in(name_, {"EXEC", "MULTI", "DISCARD"}))
145+
kind_multi_ctr_ = CO::MultiControlKind::EXEC;
146+
else if (base::_in(name_, {"PUBLISH", "SUBSCRIBE", "UNSUBSCRIBE"}))
147+
kind_pubsub_ = CO::PubSubKind::REGULAR;
148+
else if (base::_in(name_, {"PSUBSCRIBE", "PUNSUBSCRIBE"}))
149+
kind_pubsub_ = CO::PubSubKind::PATTERN;
150+
else if (base::_in(name_, {"SPUBLISH", "SSUBSCRIBE", "SUNSUBSCRIBE"}))
151+
kind_pubsub_ = CO::PubSubKind::SHARDED;
140152
}
141153

142154
CommandId::~CommandId() {
@@ -174,7 +186,7 @@ bool CommandId::IsTransactional() const {
174186
}
175187

176188
bool CommandId::IsMultiTransactional() const {
177-
return CO::IsTransKind(name()) || CO::IsEvalKind(name());
189+
return kind_multi_ctr_.has_value();
178190
}
179191

180192
uint64_t CommandId::Invoke(CmdArgList args, const CommandContext& cmd_cntx) const {

src/server/command_registry.h

Lines changed: 20 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -55,17 +55,14 @@ enum CommandOpt : uint32_t {
5555
IDEMPOTENT = 1U << 18,
5656
};
5757

58-
constexpr inline bool IsEvalKind(std::string_view name) {
59-
return name.compare(0, 4, "EVAL") == 0;
60-
}
58+
enum class PubSubKind : uint8_t { REGULAR = 0, PATTERN = 1, SHARDED = 2 };
6159

62-
constexpr inline bool IsTransKind(std::string_view name) {
63-
return (name == "EXEC") || (name == "MULTI") || (name == "DISCARD");
64-
}
65-
66-
static_assert(IsEvalKind("EVAL") && IsEvalKind("EVAL_RO") && IsEvalKind("EVALSHA") &&
67-
IsEvalKind("EVALSHA_RO"));
68-
static_assert(!IsEvalKind(""));
60+
// Commands controlling any multi command execution.
61+
// They often need to be handled separately from regular commands in many contexts
62+
enum class MultiControlKind : uint8_t {
63+
EVAL, // EVAL, EVAL_RO, EVALSHA, EVALSHA_RO
64+
EXEC, // EXEC, MULTI, DISCARD
65+
};
6966

7067
}; // namespace CO
7168

@@ -114,6 +111,7 @@ class CommandId : public facade::CommandId {
114111
// server_state.h)
115112
CommandId(const char* name, uint32_t mask, int8_t arity, int8_t first_key, int8_t last_key,
116113
std::optional<uint32_t> acl_categories = std::nullopt);
114+
117115
CommandId(CommandId&& o) = default;
118116

119117
~CommandId();
@@ -182,7 +180,19 @@ class CommandId : public facade::CommandId {
182180

183181
hdr_histogram* LatencyHist() const;
184182

183+
std::optional<CO::PubSubKind> PubSubKind() const {
184+
return kind_pubsub_;
185+
}
186+
187+
// Returns value if this command controls multi command execution (EVAL, EXEC & helpers)
188+
std::optional<CO::MultiControlKind> MultiControlKind() const {
189+
return kind_multi_ctr_;
190+
}
191+
185192
private:
193+
std::optional<CO::PubSubKind> kind_pubsub_;
194+
std::optional<CO::MultiControlKind> kind_multi_ctr_;
195+
186196
// The following fields must copy manually in the move constructor.
187197
bool implicit_acl_;
188198
bool is_alias_{false};

0 commit comments

Comments
 (0)