Skip to content

Commit a09f543

Browse files
committed
refactor: async device/host block copying, remove sync waits.
1 parent eee3ee9 commit a09f543

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

41 files changed

+1038
-560
lines changed

xllm/core/distributed_runtime/comm_channel.cpp

Lines changed: 34 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -306,21 +306,33 @@ bool CommChannel::allocate_kv_cache_with_transfer(
306306
return true;
307307
}
308308

309-
bool CommChannel::load_kv_blocks_from_store_async(
310-
const std::vector<CacheBlockInfo>& cache_block_info,
309+
void CommChannel::transfer_kv_blocks(
310+
const std::vector<BlockTransferInfo>& block_transfer_info,
311311
folly::Promise<uint32_t>& promise) {
312-
proto::CacheBlockInfos pb_cache_block_info;
313-
if (!cache_block_info_to_proto(cache_block_info, &pb_cache_block_info)) {
312+
proto::BlockTransferInfos pb_block_transfer_info;
313+
if (!block_transfer_info_to_proto(
314+
0x0, block_transfer_info, &pb_block_transfer_info)) {
314315
promise.setValue(0);
315-
return false;
316+
return;
316317
}
317318

318-
auto done = new LoadKVCacheFromStoreClosure();
319+
auto done = new TransferBlocksClosure();
319320
done->promise = std::move(promise);
320-
stub_->LoadKVCacheFromStore(
321-
&done->cntl, &pb_cache_block_info, &done->response, done);
321+
stub_->TransferBlocks(
322+
&done->cntl, &pb_block_transfer_info, &done->response, done);
323+
}
322324

323-
return true;
325+
void CommChannel::transfer_kv_blocks(
326+
const uint64_t batch_id,
327+
const std::vector<BlockTransferInfo>& block_transfer_info) {
328+
proto::BlockTransferInfos pb_block_transfer_info;
329+
if (!block_transfer_info_to_proto(
330+
batch_id, block_transfer_info, &pb_block_transfer_info)) {
331+
return;
332+
}
333+
brpc::Controller cntl;
334+
proto::TransferStatus response;
335+
stub_->TransferBlocks(&cntl, &pb_block_transfer_info, &response, nullptr);
324336
}
325337

326338
bool CommChannel::get_last_step_result_async(
@@ -397,18 +409,6 @@ bool CommChannel::execute_model_with_brpc(
397409
return true;
398410
}
399411

400-
void LoadKVCacheFromStoreClosure::Run() {
401-
std::unique_ptr<LoadKVCacheFromStoreClosure> self_guard(this);
402-
403-
bool success = !cntl.Failed();
404-
if (!success) {
405-
promise.setValue(0);
406-
} else {
407-
promise.setValue(response.success_cnt());
408-
}
409-
return;
410-
}
411-
412412
void ExecuteModelClosure::Run() {
413413
std::unique_ptr<ExecuteModelClosure> self_guard(this);
414414

@@ -437,4 +437,17 @@ void InitModelClosure::Run() {
437437

438438
return;
439439
}
440+
441+
void TransferBlocksClosure::Run() {
442+
std::unique_ptr<TransferBlocksClosure> self_guard(this);
443+
444+
bool success = !cntl.Failed();
445+
if (!success) {
446+
promise.setValue(0);
447+
} else {
448+
promise.setValue(response.success_cnt());
449+
}
450+
return;
451+
}
452+
440453
} // namespace xllm

xllm/core/distributed_runtime/comm_channel.h

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,14 @@ class CommChannel {
9191
const std::vector<CacheBlockInfo>& cache_block_info,
9292
folly::Promise<uint32_t>& promise);
9393

94+
virtual void transfer_kv_blocks(
95+
const std::vector<BlockTransferInfo>& block_transfer_info,
96+
folly::Promise<uint32_t>& promise);
97+
98+
virtual void transfer_kv_blocks(
99+
const uint64_t batch_id,
100+
const std::vector<BlockTransferInfo>& block_transfer_info);
101+
94102
virtual bool get_last_step_result_async(
95103
folly::Promise<std::optional<RawForwardOutput>>& promise);
96104

@@ -128,11 +136,11 @@ class ExecuteModelClosure : public google::protobuf::Closure {
128136
folly::Promise<std::optional<RawForwardOutput>> promise;
129137
};
130138

131-
class LoadKVCacheFromStoreClosure : public google::protobuf::Closure {
139+
class TransferBlocksClosure : public google::protobuf::Closure {
132140
public:
133141
void Run();
134142

135-
proto::StoreResponse response;
143+
proto::TransferStatus response;
136144
brpc::Controller cntl;
137145
folly::Promise<uint32_t> promise;
138146
};

xllm/core/distributed_runtime/remote_worker.cpp

Lines changed: 19 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -282,18 +282,30 @@ folly::SemiFuture<bool> RemoteWorker::pull_kv_blocks_async(
282282
return future;
283283
}
284284

285-
folly::SemiFuture<uint32_t> RemoteWorker::load_kv_blocks_from_store_async(
286-
const std::vector<CacheBlockInfo> cache_block_info) {
285+
folly::SemiFuture<uint32_t> RemoteWorker::transfer_kv_blocks(
286+
const std::vector<BlockTransferInfo>& block_transfer_info) {
287287
folly::Promise<uint32_t> promise;
288288
auto future = promise.getSemiFuture();
289-
general_threadpool_.schedule([this,
290-
cache_block_info = std::move(cache_block_info),
291-
promise = std::move(promise)]() mutable {
292-
channel_->load_kv_blocks_from_store_async(cache_block_info, promise);
293-
});
289+
general_threadpool_.schedule(
290+
[this,
291+
block_transfer_info = std::move(block_transfer_info),
292+
promise = std::move(promise)]() mutable {
293+
channel_->transfer_kv_blocks(block_transfer_info, promise);
294+
});
294295
return future;
295296
}
296297

298+
void RemoteWorker::transfer_kv_blocks(
299+
const uint64_t batch_id,
300+
const std::vector<BlockTransferInfo>& block_transfer_info) {
301+
general_threadpool_.schedule(
302+
[this,
303+
batch_id = batch_id,
304+
block_transfer_info = std::move(block_transfer_info)]() mutable {
305+
channel_->transfer_kv_blocks(batch_id, block_transfer_info);
306+
});
307+
}
308+
297309
const torch::Device& RemoteWorker::device() const {
298310
LOG(ERROR) << "RemoteWorker Method device is UnImplemented.";
299311
}

xllm/core/distributed_runtime/remote_worker.h

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -110,8 +110,12 @@ class RemoteWorker : public WorkerClient {
110110
const std::vector<uint64_t>& src_blocks,
111111
const std::vector<uint64_t>& dst_blocks);
112112

113-
virtual folly::SemiFuture<uint32_t> load_kv_blocks_from_store_async(
114-
const std::vector<CacheBlockInfo> cache_block_info);
113+
virtual folly::SemiFuture<uint32_t> transfer_kv_blocks(
114+
const std::vector<BlockTransferInfo>& block_transfer_info) override;
115+
116+
virtual void transfer_kv_blocks(
117+
const uint64_t batch_id,
118+
const std::vector<BlockTransferInfo>& block_transfer_info) override;
115119

116120
// Run the model and return the output.
117121
virtual folly::SemiFuture<std::optional<ForwardOutput>> step_async(

xllm/core/distributed_runtime/worker_service.cpp

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -417,18 +417,21 @@ void WorkerService::PullKVCache(::google::protobuf::RpcController* controller,
417417
return;
418418
}
419419

420-
void WorkerService::LoadKVCacheFromStore(
420+
void WorkerService::TransferBlocks(
421421
::google::protobuf::RpcController* controller,
422-
const ::xllm::proto::CacheBlockInfos* req,
423-
::xllm::proto::StoreResponse* resp,
422+
const ::xllm::proto::BlockTransferInfos* req,
423+
::xllm::proto::TransferStatus* resp,
424424
::google::protobuf::Closure* done) {
425425
brpc::ClosureGuard done_guard(done);
426-
std::vector<CacheBlockInfo> dst_blocks;
427-
proto_to_cache_block_info(*req, dst_blocks);
426+
std::vector<BlockTransferInfo> block_transfer_info;
427+
uint64_t batch_id;
428+
proto_to_block_transfer_info(*req, batch_id, block_transfer_info);
428429

429-
auto future = worker_->load_kv_blocks_from_store_async(dst_blocks);
430-
431-
resp->set_success_cnt(std::move(future).get());
430+
if (batch_id == 0x0) {
431+
resp->set_success_cnt(worker_->transfer_kv_blocks(block_transfer_info));
432+
} else {
433+
worker_->transfer_kv_blocks(batch_id, std::move(block_transfer_info));
434+
}
432435
return;
433436
}
434437

xllm/core/distributed_runtime/worker_service.h

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -80,11 +80,10 @@ class WorkerService : public proto::DistributeWorker {
8080
proto::Status* resp,
8181
::google::protobuf::Closure* done) override;
8282

83-
virtual void LoadKVCacheFromStore(
84-
::google::protobuf::RpcController* controller,
85-
const ::xllm::proto::CacheBlockInfos* req,
86-
::xllm::proto::StoreResponse* resp,
87-
::google::protobuf::Closure* done) override;
83+
virtual void TransferBlocks(::google::protobuf::RpcController* controller,
84+
const ::xllm::proto::BlockTransferInfos* req,
85+
::xllm::proto::TransferStatus* resp,
86+
::google::protobuf::Closure* done) override;
8887

8988
void GetDeviceInfo(::google::protobuf::RpcController* controller,
9089
const proto::Empty* req,

xllm/core/framework/batch/batch.cpp

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -73,9 +73,8 @@ ForwardInput Batch::prepare_forward_input(uint32_t num_decoding_tokens,
7373
allowed_max_tokens_,
7474
input_embeddings_vec_,
7575
mm_data_vec_,
76-
copy_in_cache_block_infos_,
77-
copy_out_cache_block_infos_,
78-
swap_cache_block_infos_,
76+
swap_block_transfer_infos_,
77+
batch_id_,
7978
&args);
8079
return builder.build_forward_input(num_decoding_tokens,
8180
min_decoding_batch_size);
@@ -88,9 +87,8 @@ RawForwardInput Batch::prepare_forward_input(uint32_t start_idx,
8887
allowed_max_tokens_,
8988
input_embeddings_vec_,
9089
mm_data_vec_,
91-
copy_in_cache_block_infos_,
92-
copy_out_cache_block_infos_,
93-
swap_cache_block_infos_,
90+
swap_block_transfer_infos_,
91+
batch_id_,
9492
nullptr,
9593
thread_pool);
9694
return builder.build_raw_forward_input(start_idx, end_idx);

xllm/core/framework/batch/batch.h

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@ limitations under the License.
1616

1717
#pragma once
1818

19+
#include <absl/time/clock.h>
20+
#include <absl/time/time.h>
1921
#include <torch/torch.h>
2022

2123
#include <limits>
@@ -48,20 +50,18 @@ class Batch {
4850
sequence_groups_.push_back(sequence_group);
4951
}
5052

51-
void set_copy_in_cache_block_infos(
52-
std::vector<CacheBlockInfo>* copy_in_cache_block_infos) {
53-
copy_in_cache_block_infos_ = copy_in_cache_block_infos;
53+
void set_swap_block_transfer_infos(
54+
std::vector<BlockTransferInfo>* swap_block_transfer_infos) {
55+
swap_block_transfer_infos_ = swap_block_transfer_infos;
5456
}
5557

56-
void set_copy_out_cache_block_infos(
57-
std::vector<CacheBlockInfo>* copy_out_cache_block_infos) {
58-
copy_out_cache_block_infos_ = copy_out_cache_block_infos;
58+
void set_batch_id() {
59+
if (batch_id_ == 0x0) {
60+
batch_id_ = absl::ToUnixMicros(absl::Now());
61+
}
5962
}
6063

61-
void set_swap_cache_block_infos(
62-
std::vector<CacheBlockInfo>* swap_cache_block_infos) {
63-
swap_cache_block_infos_ = swap_cache_block_infos;
64-
}
64+
uint64_t batch_id() const { return batch_id_; }
6565

6666
// get the number of sequences in the batch
6767
size_t size() const { return sequences_.size(); }
@@ -123,9 +123,7 @@ class Batch {
123123

124124
std::vector<Sequence*> sequences_;
125125
std::vector<SequencesGroup*> sequence_groups_;
126-
std::vector<CacheBlockInfo>* copy_in_cache_block_infos_ = nullptr;
127-
std::vector<CacheBlockInfo>* copy_out_cache_block_infos_ = nullptr;
128-
std::vector<CacheBlockInfo>* swap_cache_block_infos_ = nullptr;
126+
std::vector<BlockTransferInfo>* swap_block_transfer_infos_ = nullptr;
129127

130128
// max number of tokens to process for each sequence
131129
// default to max value
@@ -138,6 +136,8 @@ class Batch {
138136

139137
// all sequences in this batch are in prefill stage
140138
bool all_seqs_in_prefill_ = false;
139+
140+
uint64_t batch_id_ = 0x0;
141141
};
142142

143143
} // namespace xllm

xllm/core/framework/batch/batch_factory.cpp

Lines changed: 5 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -33,9 +33,7 @@ std::vector<Batch> BatchFactory::create_batches(
3333
const std::vector<std::shared_ptr<Request>>& running_requests,
3434
const std::vector<Sequence*>& running_sequences,
3535
const std::vector<size_t>& running_sequences_budgets,
36-
std::vector<std::vector<CacheBlockInfo>>* copy_in_cache_block_infos,
37-
std::vector<std::vector<CacheBlockInfo>>* copy_out_cache_block_infos,
38-
std::vector<std::vector<CacheBlockInfo>>* swap_cache_block_infos) {
36+
std::vector<std::vector<BlockTransferInfo>>* swap_block_transfer_infos) {
3937
size_t num_prompt_tokens = 0;
4038
size_t num_generated_tokens = 0;
4139
std::vector<Batch> batches(dp_size_);
@@ -74,19 +72,10 @@ std::vector<Batch> BatchFactory::create_batches(
7472

7573
for (int i = 0; i < dp_size_; i++) {
7674
if (!batches[i].empty()) {
77-
if (copy_in_cache_block_infos != nullptr &&
78-
copy_in_cache_block_infos->size() == dp_size_) {
79-
batches[i].set_copy_in_cache_block_infos(
80-
&(copy_in_cache_block_infos->at(i)));
81-
}
82-
if (copy_out_cache_block_infos != nullptr &&
83-
copy_out_cache_block_infos->size() == dp_size_) {
84-
batches[i].set_copy_out_cache_block_infos(
85-
&(copy_out_cache_block_infos->at(i)));
86-
}
87-
if (swap_cache_block_infos != nullptr &&
88-
swap_cache_block_infos->size() == dp_size_) {
89-
batches[i].set_swap_cache_block_infos(&(swap_cache_block_infos->at(i)));
75+
if (swap_block_transfer_infos != nullptr &&
76+
swap_block_transfer_infos->size() == dp_size_) {
77+
batches[i].set_swap_block_transfer_infos(
78+
&(swap_block_transfer_infos->at(i)));
9079
}
9180
}
9281
}

xllm/core/framework/batch/batch_factory.h

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -31,14 +31,8 @@ class BatchFactory {
3131
const std::vector<std::shared_ptr<Request>>& running_requests,
3232
const std::vector<Sequence*>& running_sequences,
3333
const std::vector<size_t>& running_sequences_budgets,
34-
// for global kv cache copy block from host to device
35-
std::vector<std::vector<CacheBlockInfo>>* copy_in_cache_block_infos =
36-
nullptr,
37-
// for global kv cache copy block from device to host
38-
std::vector<std::vector<CacheBlockInfo>>* copy_out_cache_block_infos =
39-
nullptr,
4034
// for beam-search
41-
std::vector<std::vector<CacheBlockInfo>>* swap_cache_block_infos =
35+
std::vector<std::vector<BlockTransferInfo>>* swap_block_transfer_infos =
4236
nullptr);
4337

4438
private:

0 commit comments

Comments
 (0)