diff --git a/src/server/conn_context.h b/src/server/conn_context.h
index 2ad04341e937..86100572da9a 100644
--- a/src/server/conn_context.h
+++ b/src/server/conn_context.h
@@ -154,6 +154,7 @@ struct ConnectionState {
     std::string repl_ip_address;
     uint32_t repl_listening_port = 0;
     DflyVersion repl_version = DflyVersion::VER1;
+    bool is_valkey = false;
   };
 
   struct SquashingInfo {
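The new `is_valkey` flag is set by the `REPLCONF VERSION` handler in server_family.cc below and consumed by the PSYNC path. A minimal sketch of how a call site could branch on it; the function and its callers are illustrative only, not part of this patch:

```cpp
// Hypothetical call site (illustration only, not in this diff).
void StartSyncFor(dfly::ConnectionContext* cntx, dfly::DflyCmd* dfly_cmd) {
  if (cntx->conn_state.replication_info.is_valkey) {
    // Valkey replicas speak the vanilla PSYNC/RDB protocol.
    dfly_cmd->StartValkeySync();
  }
  // Dragonfly replicas negotiate the multi-flow DFLY SYNC protocol instead.
}
```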
<< "There is no valkey replica to sync with"; + + // Since we do not know the size of rdb up front, use the EOF protocol, send + // "$EOF:<40-random-chars>\n" first, then the same 40 chars at the end + std::string eof_mark(40, 'X'); + std::string eof_mark_with_prefix = absl::StrCat("$EOF:", eof_mark, "\n"); + + Write(eof_mark_with_prefix); + + for (unsigned i = 0; i < shard_set->size(); ++i) { + Pipe p; + auto cb = [&] { + std::array backing; + const io::MutableBytes mb{backing}; + while (!p.done) { + auto n = p.Read(mb); + if (!n.has_value() || n.value() == 0) { + break; + } + CHECK(!_valkey_replica->conn->socket()->Write(mb.subspan(0, n.value()))); + } + + if (auto n = p.Read(mb); n.has_value() && n.value()) { + CHECK(!_valkey_replica->conn->socket()->Write(mb.subspan(0, n.value()))); + } + }; + auto drain_fb = fb2::Fiber("replica-drain-fb", cb); + + shard_set->Await(i, [&p, this, i] { + auto shard = EngineShard::tlocal(); + RdbSaver saver{&p, SaveMode::SINGLE_SHARD, false, ""}; + if (i == 0) { + CHECK(!saver.SaveHeader(saver.GetGlobalData(&sf_->service()))); + } + + saver.StartSnapshotInShard(false, &_valkey_replica->exec_st, shard); + bool skip_epilog = i < shard_set->size() - 1; + CHECK(!saver.WaitSnapshotInShard(shard, skip_epilog)); + p.Stop(); + VLOG(1) << "finished writing snapshot for shard " << shard->shard_id(); + }); + + drain_fb.JoinIfNeeded(); + } + + Write(eof_mark); + + // Stable sync + VLOG(1) << "Entering stable sync.."; + + std::vector> channels(shard_set->size()); + fb2::EventCount ec; + JournalReader reader{nullptr, 0}; + + auto cb = [&channels, &ec, this](EngineShard* shard) { + auto& channel = channels[shard->shard_id()]; + sf_->journal()->StartInThread(); + channel.reset(new ShardJournalChannel(ec, sf_->journal())); + VLOG(1) << "Set channel for shard " << shard->shard_id(); + }; + shard_set->RunBlockingInParallel(cb); + + RedisReplyBuilder rb{_valkey_replica->conn->socket()}; + DbIndex current_dbid = std::numeric_limits::max(); + + while (true) { + ec.await([&channels] { + return std::any_of(channels.begin(), channels.end(), + [](const auto& channel) { return channel->HasData(); }); + }); + for (const auto& channel : channels) { + if (channel->HasData()) { + for (auto& entry : channel->Read()) { + if (entry.dbid != current_dbid) { + VLOG(1) << "Database changed from " << current_dbid << " to " << entry.dbid; + std::string entry_dbid = std::to_string(entry.dbid); + std::vector select_cmd = {"SELECT", entry_dbid}; + + VLOG(1) << "sending command: " << select_cmd; + rb.SendBulkStrArr(select_cmd); + current_dbid = entry.dbid; + } + VLOG(1) << "sending command: " << entry.ToString() << " of size " << entry.cmd.cmd_len; + rb.SendBulkStrArr(entry.cmd.cmd_args); + } + } + } + } +} + OpStatus DflyCmd::StartFullSyncInThread(FlowInfo* flow, ExecutionState* exec_st, EngineShard* shard) { DCHECK(shard); @@ -730,6 +962,12 @@ void DflyCmd::StartStableSyncInThread(FlowInfo* flow, ExecutionState* exec_st, E }; } +void DflyCmd::CreateValkeySyncSession(facade::Connection* conn) { + CHECK(!_valkey_replica.has_value()); + fb2::LockGuard lk(mu_); + _valkey_replica.emplace(conn, [](const GenericError&) {}); +} + auto DflyCmd::CreateSyncSession(ConnectionState* state) -> std::pair { util::fb2::LockGuard lk(mu_); unsigned sync_id = next_sync_id_++; diff --git a/src/server/dflycmd.h b/src/server/dflycmd.h index 530d176eea8d..6b2ef99b0b1c 100644 --- a/src/server/dflycmd.h +++ b/src/server/dflycmd.h @@ -129,6 +129,13 @@ class DflyCmd { util::fb2::SharedMutex shared_mu; // See top of header 
diff --git a/src/server/dflycmd.h b/src/server/dflycmd.h
index 530d176eea8d..6b2ef99b0b1c 100644
--- a/src/server/dflycmd.h
+++ b/src/server/dflycmd.h
@@ -129,6 +129,13 @@ class DflyCmd {
     util::fb2::SharedMutex shared_mu;  // See top of header for locking levels.
   };
 
+  struct ValkeyReplica {
+    ValkeyReplica(facade::Connection* conn, ExecutionState::ErrHandler h) : conn{conn}, exec_st{h} {
+    }
+    facade::Connection* conn = nullptr;
+    ExecutionState exec_st;
+  };
+
  public:
   DflyCmd(ServerFamily* server_family);
@@ -142,6 +149,7 @@ class DflyCmd {
   // Create new sync session. Returns (session_id, number of flows)
   std::pair<uint32_t, unsigned> CreateSyncSession(ConnectionState* state) ABSL_LOCKS_EXCLUDED(mu_);
+  void CreateValkeySyncSession(facade::Connection* conn);
 
   // Master side access method to replication info of that connection.
   std::shared_ptr<ReplicaInfo> GetReplicaInfoFromConnection(ConnectionState* state);
@@ -156,6 +164,7 @@ class DflyCmd {
   // Tries to break those flows that stuck on socket write for too long time.
   void BreakStalledFlowsInShard() ABSL_NO_THREAD_SAFETY_ANALYSIS;
+  void StartValkeySync();
 
  private:
   using RedisReplyBuilder = facade::RedisReplyBuilder;
@@ -238,6 +247,8 @@ class DflyCmd {
   using ReplicaInfoMap = absl::btree_map<uint32_t, std::shared_ptr<ReplicaInfo>>;
   ReplicaInfoMap replica_infos_ ABSL_GUARDED_BY(mu_);
 
+  std::optional<ValkeyReplica> _valkey_replica = std::nullopt;
+
   mutable util::fb2::Mutex mu_;  // Guard global operations. See header top for locking levels.
 };
diff --git a/src/server/rdb_save.cc b/src/server/rdb_save.cc
index edbbac0e66e1..e3f5e6f548d3 100644
--- a/src/server/rdb_save.cc
+++ b/src/server/rdb_save.cc
@@ -1449,8 +1449,13 @@ void RdbSaver::StartSnapshotInShard(bool stream_journal, ExecutionState* cntx, E
   impl_->StartSnapshotting(stream_journal, cntx, shard);
 }
 
-error_code RdbSaver::WaitSnapshotInShard(EngineShard* shard) {
+error_code RdbSaver::WaitSnapshotInShard(EngineShard* shard, bool skip_epilog) {
   impl_->WaitForSnapshottingFinish(shard);
+  if (skip_epilog) {
+    RETURN_ON_ERR(impl_->FlushSerializer());
+    return impl_->FlushSink();
+  }
+
   return SaveEpilog();
 }
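Why the new `skip_epilog` flag matters: all shards serialize into the same socket, so the RDB terminator may appear only once, after the last shard. Assuming the standard RDB framing (the exact bytes `SaveEpilog` emits are not shown in this diff), the concatenated stream looks roughly like this:

```cpp
// Illustrative layout of the single-socket snapshot stream (assumed, not
// verified against SaveEpilog's exact output):
//
//   shard 0      : RDB header + aux fields + entries   (skip_epilog = true)
//   shard 1..N-2 : entries only                        (skip_epilog = true)
//   shard N-1    : entries + RDB_OPCODE_EOF (0xFF)
//                  + 8-byte CRC64 footer               (skip_epilog = false)
//
// Intermediate shards therefore only flush their serializer and sink.
```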
diff --git a/src/server/rdb_save.h b/src/server/rdb_save.h
index 71d6e444d896..fa7df72c27b3 100644
--- a/src/server/rdb_save.h
+++ b/src/server/rdb_save.h
@@ -99,7 +99,7 @@ class RdbSaver {
   std::error_code StopFullSyncInShard(EngineShard* shard);
 
   // Wait for snapshotting finish in shard thread. Called from save flows in shard thread.
-  std::error_code WaitSnapshotInShard(EngineShard* shard);
+  std::error_code WaitSnapshotInShard(EngineShard* shard, bool skip_epilog = false);
 
   // Stores auxiliary (meta) values and header_info
   std::error_code SaveHeader(const GlobalData& header_info);
diff --git a/src/server/server_family.cc b/src/server/server_family.cc
index 84f3bc6896ea..c91f57b9294d 100644
--- a/src/server/server_family.cc
+++ b/src/server/server_family.cc
@@ -3769,6 +3769,14 @@ void ServerFamily::ReplTakeOver(CmdArgList args, const CommandContext& cmd_cntx)
   return builder->SendOk();
 }
 
+void ServerFamily::PSync(CmdArgList args, const CommandContext& cmd_cntx) {
+  auto* rb = static_cast<RedisReplyBuilder*>(cmd_cntx.rb);
+  auto response = absl::StrFormat("FULLRESYNC %s %ld", master_replid_, 0);
+  rb->SendSimpleString(response);
+
+  dfly_cmd_->StartValkeySync();
+}
+
 void ServerFamily::ReplConf(CmdArgList args, const CommandContext& cmd_cntx) {
   auto* builder = cmd_cntx.rb;
   {
@@ -3854,6 +3862,9 @@ void ServerFamily::ReplConf(CmdArgList args, const CommandContext& cmd_cntx) {
     VLOG(2) << "Received client ACK=" << ack;
     cntx->replication_flow->last_acked_lsn = ack;
     return;
+  } else if (cmd == "VERSION" && args.size() == 2) {
+    cntx->conn_state.replication_info.is_valkey = true;
+    dfly_cmd_->CreateValkeySyncSession(cntx->conn());
   } else {
     VLOG(1) << "Error " << cmd << " " << arg << " " << args.size();
     return err_cb();
@@ -4159,7 +4170,8 @@ void ServerFamily::Register(CommandRegistry* registry) {
       << CI{"SLOWLOG", CO::ADMIN | CO::FAST, -2, 0, 0, acl::kSlowLog}.HFUNC(SlowLog)
       << CI{"SCRIPT", CO::NOSCRIPT | CO::NO_KEY_TRANSACTIONAL, -2, 0, 0, acl::kScript}.HFUNC(Script)
       << CI{"DFLY", CO::ADMIN | CO::GLOBAL_TRANS | CO::HIDDEN, -2, 0, 0, acl::kDfly}.HFUNC(Dfly)
-      << CI{"MODULE", CO::ADMIN, 2, 0, 0, acl::kModule}.HFUNC(Module);
+      << CI{"MODULE", CO::ADMIN, 2, 0, 0, acl::kModule}.HFUNC(Module)
+      << CI{"PSYNC", CO::ADMIN | CO::GLOBAL_TRANS, -2, 0, 0, acl::kDfly}.HFUNC(PSync);
 }
 
 }  // namespace dfly
diff --git a/src/server/server_family.h b/src/server/server_family.h
index 0a06ab011314..3434b202dcf1 100644
--- a/src/server/server_family.h
+++ b/src/server/server_family.h
@@ -363,6 +363,7 @@ class ServerFamily {
   void Script(CmdArgList args, const CommandContext& cmd_cntx);
   void SlowLog(CmdArgList args, const CommandContext& cmd_cntx);
   void Module(CmdArgList args, const CommandContext& cmd_cntx);
+  void PSync(CmdArgList args, const CommandContext& cmd_cntx);
 
   void SyncGeneric(std::string_view repl_master_id, uint64_t offs, ConnectionContext* cntx);
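Putting the pieces together, the session this patch enables looks roughly as follows. The exact REPLCONF sequence a Valkey replica sends varies by version; treating `REPLCONF VERSION` as the discriminator for Valkey clients is the assumption this patch makes:

```cpp
// Annotated handshake sketch (assumed flow; replica lines prefixed with ->,
// master replies with <-):
//
//   -> PING
//   -> REPLCONF listening-port 6380
//   -> REPLCONF capa eof capa psync2
//   -> REPLCONF VERSION 8.0.0        // sets is_valkey, creates the session
//   -> PSYNC ? -1
//   <- +FULLRESYNC <replid> 0        // ServerFamily::PSync
//   <- $EOF:XXXX...X                 // 40-char mark, then the RDB payload
//   <- ... RDB bytes ... XXXX...X    // mark repeated: end of snapshot
//   <- (stable sync: journal entries streamed as RESP command arrays)
```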