Skip to content

Commit ce547a6

Browse files
authored
feat: allow migrating client connections via CLIENT MIGRATE (#5551)
* feat: allow migrating client connections via CLIENT MIGRATE --------- Signed-off-by: Roman Gershman <[email protected]>
1 parent f058c4e commit ce547a6

File tree

7 files changed

+116
-57
lines changed

7 files changed

+116
-57
lines changed

src/facade/dragonfly_connection.cc

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -707,6 +707,9 @@ void Connection::OnPreMigrateThread() {
707707
// is marked beforehand.
708708
migration_in_process_ = true;
709709

710+
// Mark as not owned by any thread, as it is going through the dark hole.
711+
self_.reset();
712+
710713
socket_->CancelOnErrorCb();
711714
DCHECK(!async_fb_.IsJoinable()) << GetClientId();
712715
}
@@ -719,6 +722,7 @@ void Connection::OnPostMigrateThread() {
719722
socket_->RegisterOnErrorCb([this](int32_t mask) { this->OnBreakCb(mask); });
720723
}
721724
migration_in_process_ = false;
725+
self_ = {make_shared<std::monostate>(), this}; // Recreate shared_ptr to self.
722726
DCHECK(!async_fb_.IsJoinable());
723727

724728
// If someone had sent Async during the migration, we must create async_fb_.
@@ -1820,7 +1824,7 @@ bool Connection::Migrate(util::fb2::ProactorBase* dest) {
18201824
Connection::WeakRef Connection::Borrow() {
18211825
DCHECK(self_);
18221826

1823-
return WeakRef(self_, socket_->proactor()->GetPoolIndex(), id_);
1827+
return {self_, unsigned(socket_->proactor()->GetPoolIndex()), id_};
18241828
}
18251829

18261830
void Connection::ShutdownThreadLocal() {
@@ -2003,8 +2007,8 @@ facade::ConnectionContext* Connection::cntx() {
20032007
return cc_.get();
20042008
}
20052009

2006-
void Connection::RequestAsyncMigration(util::fb2::ProactorBase* dest) {
2007-
if (!migration_enabled_ || cc_ == nullptr) {
2010+
void Connection::RequestAsyncMigration(util::fb2::ProactorBase* dest, bool force) {
2011+
if ((!force && !migration_enabled_) || cc_ == nullptr) {
20082012
return;
20092013
}
20102014

@@ -2189,22 +2193,17 @@ void Connection::SetPipelineLowBoundStats(unsigned limit) {
21892193

21902194
Connection::WeakRef::WeakRef(std::shared_ptr<Connection> ptr, unsigned thread_id,
21912195
uint32_t client_id)
2192-
: ptr_{std::move(ptr)}, thread_id_{thread_id}, client_id_{client_id} {
2193-
}
2194-
2195-
unsigned Connection::WeakRef::Thread() const {
2196-
return thread_id_;
2196+
: ptr_{std::move(ptr)}, last_known_thread_id_{thread_id}, client_id_{client_id} {
21972197
}
21982198

21992199
Connection* Connection::WeakRef::Get() const {
2200-
// We should never access the connection object from other threads.
2201-
DCHECK_EQ(ProactorBase::me()->GetPoolIndex(), int(thread_id_));
2200+
auto sptr = ptr_.lock();
22022201

22032202
// The connection can only be deleted on this thread, so
22042203
// this pointer is valid until the next suspension.
22052204
// Note: keeping a shared_ptr doesn't prolong the lifetime because
22062205
// it doesn't manage the underlying connection. See definition of `self_`.
2207-
return ptr_.lock().get();
2206+
return sptr.get();
22082207
}
22092208

22102209
bool Connection::WeakRef::IsExpired() const {

src/facade/dragonfly_connection.h

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ class Connection : public util::Connection {
7070

7171
// PubSub message, either incoming message for active subscription or reply for new subscription.
7272
struct PubMessage {
73-
std::string pattern{}; // non-empty for pattern subscriber
73+
std::string pattern; // non-empty for pattern subscriber
7474
std::shared_ptr<char[]> buf; // stores channel name and message
7575
std::string_view channel, message; // channel and message parts from buf
7676
bool should_unsubscribe = false; // unsubscribe from channel after sending the message
@@ -176,15 +176,16 @@ class Connection : public util::Connection {
176176
static_assert(sizeof(MessageHandle) <= 80,
177177
"Big structs should use indirection to avoid wasting deque space!");
178178

179-
enum Phase { SETUP, READ_SOCKET, PROCESS, SHUTTING_DOWN, PRECLOSE, NUM_PHASES };
179+
enum Phase : uint8_t { SETUP, READ_SOCKET, PROCESS, SHUTTING_DOWN, PRECLOSE, NUM_PHASES };
180180

181181
// Weak reference to a connection, invalidated upon connection close.
182182
// Used to dispatch async operations for the connection without worrying about pointer lifetime.
183183
struct WeakRef {
184184
public:
185185
// Get the last known residing thread of the connection. Thread-safe.
186-
unsigned Thread() const;
187-
186+
unsigned LastKnownThreadId() const {
187+
return last_known_thread_id_;
188+
}
188189
// Get pointer to connection if still valid, nullptr if expired.
189190
// Can only be called from connection's thread. Validity is guaranteed
190191
// only until the next suspension point.
@@ -205,7 +206,7 @@ class Connection : public util::Connection {
205206
WeakRef(std::shared_ptr<Connection> ptr, unsigned thread_id, uint32_t client_id);
206207

207208
std::weak_ptr<Connection> ptr_;
208-
unsigned thread_id_;
209+
unsigned last_known_thread_id_;
209210
uint32_t client_id_;
210211
};
211212

@@ -288,8 +289,9 @@ class Connection : public util::Connection {
288289
ConnectionContext* cntx();
289290

290291
// Requests that at some point, this connection will be migrated to `dest` thread.
291-
// Connections will migrate at most once, and only when the flag --migrate_connections is true.
292-
void RequestAsyncMigration(util::fb2::ProactorBase* dest);
292+
// If force is false, the connection will migrate at most once,
293+
// and only when the flag --migrate_connections is true.
294+
void RequestAsyncMigration(util::fb2::ProactorBase* dest, bool force);
293295

294296
// Starts traffic logging in the calling thread. Must be a proactor thread.
295297
// Each thread creates its own log file combining requests from all the connections in
@@ -335,7 +337,7 @@ class Connection : public util::Connection {
335337
std::unique_ptr<ConnectionContext> cc_; // Null for http connections
336338

337339
private:
338-
enum ParserStatus { OK, NEED_MORE, ERROR };
340+
enum ParserStatus : uint8_t { OK, NEED_MORE, ERROR };
339341

340342
struct AsyncOperations;
341343

@@ -443,6 +445,7 @@ class Connection : public util::Connection {
443445

444446
uint32_t id_;
445447
Protocol protocol_;
448+
Phase phase_ = SETUP;
446449

447450
struct {
448451
size_t read_cnt = 0; // total number of read calls
@@ -459,7 +462,6 @@ class Connection : public util::Connection {
459462
ServiceInterface* service_;
460463

461464
time_t creation_time_, last_interaction_;
462-
Phase phase_ = SETUP;
463465
std::string name_;
464466

465467
std::string lib_name_;

src/server/channel_store.cc

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -48,11 +48,11 @@ auto BuildSender(string_view channel, facade::ArgRange messages, bool unsubscrib
4848
} // namespace
4949

5050
bool ChannelStore::Subscriber::ByThread(const Subscriber& lhs, const Subscriber& rhs) {
51-
return ByThreadId(lhs, rhs.Thread());
51+
return ByThreadId(lhs, rhs.LastKnownThreadId());
5252
}
5353

5454
bool ChannelStore::Subscriber::ByThreadId(const Subscriber& lhs, const unsigned thread) {
55-
return lhs.Thread() < thread;
55+
return lhs.LastKnownThreadId() < thread;
5656
}
5757

5858
ChannelStore::UpdatablePointer::UpdatablePointer(const UpdatablePointer& other) {
@@ -128,7 +128,7 @@ unsigned ChannelStore::SendMessages(std::string_view channel, facade::ArgRange m
128128
int32_t last_thread = -1;
129129

130130
for (auto& sub : subscribers) {
131-
int sub_thread = sub.Thread();
131+
int sub_thread = sub.LastKnownThreadId();
132132
DCHECK_LE(last_thread, sub_thread);
133133
if (last_thread == sub_thread) // skip same thread
134134
continue;
@@ -139,15 +139,15 @@ unsigned ChannelStore::SendMessages(std::string_view channel, facade::ArgRange m
139139
// Make sure the connection thread has enough memory budget to accept the message.
140140
// This is a heuristic and not entirely hermetic since the connection memory might
141141
// get filled again.
142-
facade::Connection::EnsureMemoryBudget(sub.Thread());
142+
facade::Connection::EnsureMemoryBudget(sub_thread);
143143
last_thread = sub_thread;
144144
}
145145

146146
auto subscribers_ptr = make_shared<decltype(subscribers)>(std::move(subscribers));
147147
auto cb = [subscribers_ptr, send = BuildSender(channel, messages)](unsigned idx, auto*) {
148148
auto it = lower_bound(subscribers_ptr->begin(), subscribers_ptr->end(), idx,
149149
ChannelStore::Subscriber::ByThreadId);
150-
while (it != subscribers_ptr->end() && it->Thread() == idx) {
150+
while (it != subscribers_ptr->end() && it->LastKnownThreadId() == idx) {
151151
if (auto* ptr = it->Get(); ptr && ptr->cntx() != nullptr)
152152
send(ptr, it->pattern);
153153
it++;
@@ -227,7 +227,7 @@ void ChannelStore::UnsubscribeConnectionsFromDeletedSlots(const ChannelsSubMap&
227227

228228
auto it = lower_bound(subscribers.begin(), subscribers.end(), idx,
229229
ChannelStore::Subscriber::ByThreadId);
230-
while (it != subscribers.end() && it->Thread() == idx) {
230+
while (it != subscribers.end() && it->LastKnownThreadId() == idx) {
231231
// if ptr->cntx() is null, a connection might have closed or be in the process of closing
232232
if (auto* ptr = it->Get(); ptr && ptr->cntx() != nullptr) {
233233
DCHECK(it->pattern.empty());

src/server/db_slice.cc

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1569,19 +1569,20 @@ void DbSlice::QueueInvalidationTrackingMessageAtomic(std::string_view key) {
15691569
auto [pend_it, inserted] = pending_send_map_.emplace(key, std::move(moved_set));
15701570
if (!inserted) {
15711571
ConnectionHashSet& client_set = pend_it->second;
1572-
for (auto& client : moved_set) {
1573-
client_set.insert(client);
1572+
for (auto& weak_ref : moved_set) {
1573+
client_set.insert(weak_ref);
15741574
}
15751575
}
15761576
}
15771577

1578-
void DbSlice::SendQueuedInvalidationMessagesCb(const TrackingMap& track_map, unsigned idx) const {
1578+
void DbSlice::SendQueuedInvalidationMessagesCb(const TrackingMap& track_map,
1579+
unsigned calling_thread_id) const {
15791580
for (auto& [key, client_list] : track_map) {
1580-
for (auto& client : client_list) {
1581-
if (client.IsExpired() || (client.Thread() != idx)) {
1582-
continue;
1581+
for (auto& weak_ref : client_list) {
1582+
if (weak_ref.IsExpired() || (weak_ref.LastKnownThreadId() != calling_thread_id)) {
1583+
continue; // Expired or migrated.
15831584
}
1584-
auto* conn = client.Get();
1585+
auto* conn = weak_ref.Get();
15851586
auto* cntx = static_cast<ConnectionContext*>(conn->cntx());
15861587
if (cntx && cntx->conn_state.tracking_info_.IsTrackingOn()) {
15871588
conn->SendInvalidationMessageAsync({key});
@@ -1597,8 +1598,8 @@ void DbSlice::SendQueuedInvalidationMessages() {
15971598
// Notify all the clients. This function is not efficient,
15981599
// because it broadcasts to all threads unrelated to the subscribers for the key.
15991600
auto local_map = std::move(pending_send_map_);
1600-
auto cb = [&](unsigned idx, util::ProactorBase*) {
1601-
SendQueuedInvalidationMessagesCb(local_map, idx);
1601+
auto cb = [&](unsigned thread_id, util::ProactorBase*) {
1602+
SendQueuedInvalidationMessagesCb(local_map, thread_id);
16021603
};
16031604

16041605
shard_set->pool()->AwaitBrief(std::move(cb));

src/server/main_service.cc

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1378,8 +1378,7 @@ OpResult<void> OpTrackKeys(const OpArgs slice_args, const facade::Connection::We
13781378
return OpStatus::OK;
13791379
}
13801380

1381-
DVLOG(2) << "Start tracking keys for client ID: " << conn_ref.GetClientId()
1382-
<< " with thread ID: " << conn_ref.Thread();
1381+
DVLOG(2) << "Start tracking keys for client ID: " << conn_ref.GetClientId();
13831382

13841383
auto& db_slice = slice_args.GetDbSlice();
13851384
// TODO: There is a bug here that we track all arguments instead of tracking only keys.
@@ -2125,7 +2124,7 @@ void Service::EvalInternal(CmdArgList args, const EvalArgs& eval_args, Interpret
21252124
if (*sid != ServerState::tlocal()->thread_index()) {
21262125
VLOG(2) << "Migrating connection " << cntx->conn() << " from "
21272126
<< ProactorBase::me()->GetPoolIndex() << " to " << *sid;
2128-
cntx->conn()->RequestAsyncMigration(shard_set->pool()->at(*sid));
2127+
cntx->conn()->RequestAsyncMigration(shard_set->pool()->at(*sid), false);
21292128
}
21302129
} else {
21312130
Transaction::MultiMode script_mode = DetermineMultiMode(*params);

src/server/server_family.cc

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -596,6 +596,49 @@ void ClientKill(CmdArgList args, absl::Span<facade::Listener*> listeners, SinkRe
596596
}
597597
}
598598

599+
void ClientMigrate(CmdArgList args, absl::Span<facade::Listener*> listeners,
600+
SinkReplyBuilder* builder, ConnectionContext* cntx) {
601+
if (args.size() != 2) {
602+
return builder->SendError(kSyntaxErr);
603+
}
604+
605+
uint32_t id;
606+
if (!absl::SimpleAtoi(args[0], &id)) {
607+
return builder->SendError("Invalid client id");
608+
}
609+
610+
uint32_t tid = 0;
611+
if (!absl::SimpleAtoi(args[1], &tid) || tid >= shard_set->pool()->size()) {
612+
return builder->SendError("Invalid thread id");
613+
}
614+
615+
unsigned migrated = 0;
616+
auto cb_brief = [&](unsigned current_tid, ProactorBase* p) {
617+
if (current_tid == tid) {
618+
return; // we should not migrate to the same thread
619+
}
620+
621+
auto traverse_cb = [&](unsigned, util::Connection* conn) {
622+
facade::Connection* dconn = static_cast<facade::Connection*>(conn);
623+
if (dconn->GetClientId() == id) {
624+
++migrated;
625+
dconn->RequestAsyncMigration(shard_set->pool()->at(tid), true /* force */);
626+
}
627+
};
628+
629+
for (auto* listener : listeners) {
630+
if (listener->IsPrivilegedInterface())
631+
continue; // skip privileged interfaces
632+
633+
listener->TraverseConnectionsOnThread(traverse_cb, UINT32_MAX, nullptr);
634+
}
635+
};
636+
637+
shard_set->pool()->AwaitBrief(cb_brief);
638+
639+
return builder->SendLong(migrated);
640+
}
641+
599642
std::string_view GetOSString() {
600643
// Call uname() only once since it can be expensive. Cache the final result in a static string.
601644
static string os_string = []() {
@@ -2096,6 +2139,8 @@ void ClientHelp(SinkReplyBuilder* builder) {
20962139
" * LIB-VER: the client lib version.",
20972140
"TRACKING (ON|OFF) [OPTIN] [OPTOUT] [NOLOOP]",
20982141
" Control server assisted client side caching.",
2142+
"MIGRATE <client-id> <tid>",
2143+
" Migrates connection specified by client-id to the specified thread id.",
20992144
"HELP",
21002145
" Print this help."};
21012146
auto* rb = static_cast<RedisReplyBuilder*>(builder);
@@ -2128,6 +2173,8 @@ void ServerFamily::Client(CmdArgList args, const CommandContext& cmd_cntx) {
21282173
return ClientSetInfo(sub_args, builder, cntx);
21292174
} else if (sub_cmd == "ID") {
21302175
return ClientId(sub_args, builder, cntx);
2176+
} else if (sub_cmd == "MIGRATE") {
2177+
return ClientMigrate(sub_args, absl::MakeSpan(listeners_), builder, cntx);
21312178
} else if (sub_cmd == "HELP") {
21322179
return ClientHelp(builder);
21332180
}

0 commit comments

Comments
 (0)