Skip to content

Commit 4f5b84d

Browse files
committed
feat: implement batch prefetch from store.
1 parent 170ec80 commit 4f5b84d

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

44 files changed

+501
-158
lines changed

xllm/core/common/global_flags.cpp

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -332,14 +332,18 @@ DEFINE_string(store_protocol,
332332
"tcp",
333333
"KV cache store protocol(e.g. tcp, rdma).");
334334

335-
DEFINE_string(store_master_server_entry,
335+
DEFINE_string(store_master_server_address,
336336
"",
337337
"The address information of the store master service.");
338338

339-
DEFINE_string(store_metadata_connstring,
339+
DEFINE_string(store_metadata_server,
340340
"",
341341
"The address of the kv cache store metadata service.");
342342

343+
DEFINE_string(store_local_hostname,
344+
"",
345+
"The local host name of the kv cache store client.");
346+
343347
// --- computation communication parallel config ---
344348

345349
DEFINE_bool(

xllm/core/common/global_flags.h

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -159,9 +159,11 @@ DECLARE_bool(enable_kvcache_store);
159159

160160
DECLARE_string(store_protocol);
161161

162-
DECLARE_string(store_master_server_entry);
162+
DECLARE_string(store_master_server_address);
163163

164-
DECLARE_string(store_metadata_connstring);
164+
DECLARE_string(store_metadata_server);
165+
166+
DECLARE_string(store_local_hostname);
165167

166168
DECLARE_bool(enable_multi_stream_parallel);
167169

xllm/core/common/options.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,8 +53,9 @@ std::string Options::to_string() const {
5353
<< ", enable_cache_upload: " << enable_cache_upload()
5454
<< ", enable_kvcache_store: " << enable_kvcache_store()
5555
<< ", store_protocol: " << store_protocol()
56-
<< ", store_master_server_entry: " << store_master_server_entry()
57-
<< ", store_metadata_connstring: " << store_metadata_connstring()
56+
<< ", store_master_server_address: " << store_master_server_address()
57+
<< ", store_metadata_server: " << store_metadata_server()
58+
<< ", store_local_hostname: " << store_local_hostname()
5859
<< ", enable_multi_stream_parallel: " << enable_multi_stream_parallel()
5960
<< ", enable_continuous_kvcache: " << enable_continuous_kvcache()
6061
<< ", disable_ttft_profiling: " << disable_ttft_profiling()

xllm/core/common/options.h

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -143,9 +143,11 @@ class Options {
143143

144144
PROPERTY(std::string, store_protocol) = "tcp";
145145

146-
PROPERTY(std::string, store_master_server_entry) = "";
146+
PROPERTY(std::string, store_master_server_address) = "";
147147

148-
PROPERTY(std::string, store_metadata_connstring) = "";
148+
PROPERTY(std::string, store_metadata_server) = "";
149+
150+
PROPERTY(std::string, store_local_hostname) = "";
149151

150152
PROPERTY(bool, enable_multi_stream_parallel) = false;
151153

xllm/core/distributed_runtime/comm_channel.cpp

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@ limitations under the License.
1818
#include <brpc/controller.h>
1919
#include <glog/logging.h>
2020

21+
#include <future>
22+
2123
namespace xllm {
2224

2325
bool CommChannel::init_brpc(const std::string& server_address) {
@@ -335,6 +337,94 @@ void CommChannel::transfer_kv_blocks(
335337
stub_->TransferBlocks(&cntl, &pb_block_transfer_info, &response, nullptr);
336338
}
337339

340+
class ClientStreamReceiver : public brpc::StreamInputHandler {
341+
private:
342+
const std::atomic<bool>& termination_flag_;
343+
std::shared_ptr<std::atomic<uint32_t>> success_cnt_;
344+
std::promise<void> close_promise_;
345+
std::atomic<bool> promise_set_{false};
346+
347+
public:
348+
ClientStreamReceiver(const std::atomic<bool>& termination_flag,
349+
std::shared_ptr<std::atomic<uint32_t>>& success_cnt)
350+
: termination_flag_(termination_flag), success_cnt_(success_cnt) {}
351+
352+
~ClientStreamReceiver() {
353+
if (!promise_set_.exchange(true)) {
354+
try {
355+
close_promise_.set_value();
356+
} catch (const std::exception& e) {
357+
LOG(WARNING) << "Exception in destructor: " << e.what();
358+
}
359+
}
360+
}
361+
362+
std::future<void> get_close_future() { return close_promise_.get_future(); }
363+
364+
int on_received_messages(brpc::StreamId id,
365+
butil::IOBuf* const messages[],
366+
size_t size) override {
367+
for (size_t i = 0; i < size; ++i) {
368+
std::string msg_str = messages[i]->to_string();
369+
int32_t success_cnt = std::stoi(msg_str);
370+
371+
if (success_cnt > 0 &&
372+
!termination_flag_.load(std::memory_order_acquire)) {
373+
success_cnt_->fetch_add(success_cnt, std::memory_order_relaxed);
374+
} else {
375+
brpc::StreamClose(id);
376+
if (!promise_set_.exchange(true)) {
377+
close_promise_.set_value();
378+
}
379+
break;
380+
}
381+
}
382+
return 0;
383+
}
384+
385+
virtual void on_idle_timeout(brpc::StreamId id) override {
386+
if (!promise_set_.exchange(true)) {
387+
close_promise_.set_value();
388+
}
389+
}
390+
391+
virtual void on_closed(brpc::StreamId id) override {
392+
if (!promise_set_.exchange(true)) {
393+
close_promise_.set_value();
394+
}
395+
}
396+
};
397+
398+
void CommChannel::prefetch_from_storage(
399+
const std::atomic<bool>& flag,
400+
const std::vector<BlockTransferInfo>& block_transfer_info,
401+
std::shared_ptr<std::atomic<uint32_t>>& success_cnt) {
402+
proto::BlockTransferInfos pb_block_transfer_info;
403+
if (!block_transfer_info_to_proto(
404+
0x0, block_transfer_info, &pb_block_transfer_info)) {
405+
return;
406+
}
407+
ClientStreamReceiver receiver(flag, success_cnt);
408+
brpc::Controller cntl;
409+
brpc::StreamOptions stream_options;
410+
brpc::StreamId stream_id;
411+
proto::Status response;
412+
stream_options.handler = &receiver;
413+
if (brpc::StreamCreate(&stream_id, cntl, &stream_options) != 0) {
414+
LOG(ERROR) << "Failed to create stream";
415+
return;
416+
}
417+
418+
stub_->PrefetchFromStorage(
419+
&cntl, &pb_block_transfer_info, &response, nullptr);
420+
421+
if (cntl.Failed()) {
422+
LOG(ERROR) << "Fail to connect stream, " << cntl.ErrorText();
423+
}
424+
425+
receiver.get_close_future().wait();
426+
}
427+
338428
bool CommChannel::get_last_step_result_async(
339429
folly::Promise<std::optional<RawForwardOutput>>& promise) {
340430
proto::Empty req;

xllm/core/distributed_runtime/comm_channel.h

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -87,10 +87,6 @@ class CommChannel {
8787
const uint64_t kv_cache_size,
8888
const std::vector<std::vector<int64_t>>& kv_cache_shape);
8989

90-
virtual bool load_kv_blocks_from_store_async(
91-
const std::vector<CacheBlockInfo>& cache_block_info,
92-
folly::Promise<uint32_t>& promise);
93-
9490
virtual void transfer_kv_blocks(
9591
const std::vector<BlockTransferInfo>& block_transfer_info,
9692
folly::Promise<uint32_t>& promise);
@@ -99,6 +95,11 @@ class CommChannel {
9995
const uint64_t batch_id,
10096
const std::vector<BlockTransferInfo>& block_transfer_info);
10197

98+
virtual void prefetch_from_storage(
99+
const std::atomic<bool>& flag,
100+
const std::vector<BlockTransferInfo>& block_transfer_info,
101+
std::shared_ptr<std::atomic<uint32_t>>& success_cnt);
102+
102103
virtual bool get_last_step_result_async(
103104
folly::Promise<std::optional<RawForwardOutput>>& promise);
104105

xllm/core/distributed_runtime/remote_worker.cpp

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ limitations under the License.
3535
#include "util/hash_util.h"
3636

3737
namespace xllm {
38+
3839
RemoteWorker::RemoteWorker(int32_t global_rank,
3940
const std::string& server_address,
4041
const torch::Device& d,
@@ -286,7 +287,7 @@ folly::SemiFuture<uint32_t> RemoteWorker::transfer_kv_blocks(
286287
const std::vector<BlockTransferInfo>& block_transfer_info) {
287288
folly::Promise<uint32_t> promise;
288289
auto future = promise.getSemiFuture();
289-
general_threadpool_.schedule(
290+
copy_threadpool_.schedule(
290291
[this,
291292
block_transfer_info = std::move(block_transfer_info),
292293
promise = std::move(promise)]() mutable {
@@ -298,14 +299,27 @@ folly::SemiFuture<uint32_t> RemoteWorker::transfer_kv_blocks(
298299
void RemoteWorker::transfer_kv_blocks(
299300
const uint64_t batch_id,
300301
const std::vector<BlockTransferInfo>& block_transfer_info) {
301-
general_threadpool_.schedule(
302+
copy_threadpool_.schedule(
302303
[this,
303304
batch_id = batch_id,
304305
block_transfer_info = std::move(block_transfer_info)]() mutable {
305306
channel_->transfer_kv_blocks(batch_id, block_transfer_info);
306307
});
307308
}
308309

310+
void RemoteWorker::prefetch_from_storage(
311+
const std::atomic<bool>& flag,
312+
const std::vector<BlockTransferInfo>& block_transfer_info,
313+
std::shared_ptr<std::atomic<uint32_t>>& success_cnt) {
314+
copy_threadpool_.schedule(
315+
[this,
316+
flag = &flag,
317+
block_transfer_info = std::move(block_transfer_info),
318+
success_cnt = success_cnt]() mutable {
319+
channel_->prefetch_from_storage(flag, block_transfer_info, success_cnt);
320+
});
321+
}
322+
309323
const torch::Device& RemoteWorker::device() const {
310324
LOG(ERROR) << "RemoteWorker Method device is UnImplemented.";
311325
}

xllm/core/distributed_runtime/remote_worker.h

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,11 @@ class RemoteWorker : public WorkerClient {
117117
const uint64_t batch_id,
118118
const std::vector<BlockTransferInfo>& block_transfer_info) override;
119119

120+
virtual void prefetch_from_storage(
121+
const std::atomic<bool>& flag,
122+
const std::vector<BlockTransferInfo>& block_transfer_info,
123+
std::shared_ptr<std::atomic<uint32_t>>& success_cnt) override;
124+
120125
// Run the model and return the output.
121126
virtual folly::SemiFuture<std::optional<ForwardOutput>> step_async(
122127
const ForwardInput& inputs) override;
@@ -144,9 +149,8 @@ class RemoteWorker : public WorkerClient {
144149
// connection resource
145150
std::unique_ptr<CommChannel> channel_;
146151
ThreadPool threadpool_;
147-
// general working thread
148-
// do some overlap work with model execute
149-
ThreadPool general_threadpool_{4};
152+
// copy working thread
153+
ThreadPool copy_threadpool_{4};
150154
const torch::Device device_;
151155
};
152156
} // namespace xllm

xllm/core/distributed_runtime/worker_service.cpp

Lines changed: 111 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -419,13 +419,12 @@ void WorkerService::PullKVCache(::google::protobuf::RpcController* controller,
419419

420420
void WorkerService::TransferBlocks(
421421
::google::protobuf::RpcController* controller,
422-
const ::xllm::proto::BlockTransferInfos* req,
423-
::xllm::proto::TransferStatus* resp,
422+
const proto::BlockTransferInfos* req,
423+
proto::TransferStatus* resp,
424424
::google::protobuf::Closure* done) {
425425
brpc::ClosureGuard done_guard(done);
426426
std::vector<BlockTransferInfo> block_transfer_info;
427-
uint64_t batch_id;
428-
proto_to_block_transfer_info(*req, batch_id, block_transfer_info);
427+
uint64_t batch_id = proto_to_block_transfer_info(*req, block_transfer_info);
429428

430429
if (batch_id == 0x0) {
431430
resp->set_success_cnt(worker_->transfer_kv_blocks(block_transfer_info));
@@ -435,6 +434,114 @@ void WorkerService::TransferBlocks(
435434
return;
436435
}
437436

437+
class ServerStreamHandler : public brpc::StreamInputHandler {
438+
private:
439+
std::promise<void> close_promise_;
440+
std::atomic<bool> promise_set_{false};
441+
442+
public:
443+
~ServerStreamHandler() {
444+
if (!promise_set_.exchange(true)) {
445+
try {
446+
close_promise_.set_value();
447+
} catch (const std::exception& e) {
448+
LOG(WARNING) << "Exception in destructor: " << e.what();
449+
}
450+
}
451+
}
452+
453+
std::future<void> get_close_future() { return close_promise_.get_future(); }
454+
455+
int on_received_messages(brpc::StreamId id,
456+
butil::IOBuf* const messages[],
457+
size_t size) override {
458+
LOG(WARNING) << "ServerStreamHandler::on_received_messages not implement.";
459+
return 0;
460+
}
461+
462+
void on_closed(brpc::StreamId id) override {
463+
if (!promise_set_.exchange(true)) {
464+
close_promise_.set_value();
465+
}
466+
}
467+
468+
void on_idle_timeout(brpc::StreamId id) override {
469+
if (!promise_set_.exchange(true)) {
470+
LOG(WARNING) << "Stream idle timeout: " << id;
471+
close_promise_.set_value();
472+
}
473+
}
474+
};
475+
476+
void WorkerService::PrefetchFromStorage(
477+
google::protobuf::RpcController* controller,
478+
const proto::BlockTransferInfos* req,
479+
proto::Status* resp,
480+
google::protobuf::Closure* done) {
481+
brpc::ClosureGuard done_guard(done);
482+
brpc::Controller* cntl = static_cast<brpc::Controller*>(controller);
483+
484+
auto stream_handler = std::make_unique<ServerStreamHandler>();
485+
auto stream_id = std::make_unique<brpc::StreamId>();
486+
brpc::StreamOptions stream_options;
487+
stream_options.handler = stream_handler.get();
488+
if (brpc::StreamAccept(stream_id.get(), *cntl, &stream_options) != 0) {
489+
resp->set_ok(false);
490+
LOG(ERROR) << "Failed to accept stream!";
491+
return;
492+
}
493+
494+
std::vector<BlockTransferInfo> block_transfer_info;
495+
proto_to_block_transfer_info(*req, block_transfer_info);
496+
497+
copy_threadpool_.schedule(
498+
[this,
499+
block_transfer_info = std::move(block_transfer_info),
500+
stream_id = std::move(stream_id),
501+
stream_handler = std::move(stream_handler)]() mutable {
502+
Slice<BlockTransferInfo> transfer_slice{block_transfer_info};
503+
auto close_future = stream_handler->get_close_future();
504+
bool is_completed = false;
505+
506+
for (size_t i = 0; i < transfer_slice.size();
507+
i += stream_copy_batch_size_) {
508+
auto current_slice = transfer_slice.slice(
509+
i, std::min(i + stream_copy_batch_size_, transfer_slice.size()));
510+
511+
auto success_cnt = worker_->prefetch_from_storage(current_slice);
512+
513+
if (success_cnt != current_slice.size() ||
514+
i + stream_copy_batch_size_ >= transfer_slice.size()) {
515+
is_completed = true;
516+
}
517+
518+
butil::IOBuf buf;
519+
buf.append(std::to_string(success_cnt));
520+
if (brpc::StreamWrite(*stream_id.get(), buf) != 0) {
521+
brpc::StreamClose(*stream_id.get());
522+
is_completed = false;
523+
break;
524+
}
525+
526+
if (is_completed) {
527+
if (success_cnt != 0) {
528+
butil::IOBuf buf_end;
529+
buf_end.append("0");
530+
brpc::StreamWrite(*stream_id.get(), buf_end);
531+
}
532+
break;
533+
}
534+
}
535+
if (is_completed) {
536+
close_future.wait();
537+
}
538+
brpc::StreamClose(*stream_id.get());
539+
});
540+
541+
resp->set_ok(true);
542+
return;
543+
}
544+
438545
void WorkerService::GetDeviceInfo(::google::protobuf::RpcController* controller,
439546
const proto::Empty* req,
440547
proto::DeviceInfo* resp,

0 commit comments

Comments
 (0)