
Commit 9f6d444

refactor: change host KV cache memory layout from layer-wise to block-wise.
1 parent 40db395 commit 9f6d444
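
Note: the heart of this refactor is the indexing order of the host-side KV cache. Previously host_kv_caches_ held one pinned CPU tensor pair per model layer, with all cached blocks stacked inside each layer tensor; now it holds one tensor pair per cache block, with that block's per-layer K/V slices stacked inside. A minimal sketch of the two addressing schemes (shapes and indices are illustrative assumptions, not values from this commit):

#include <torch/torch.h>
#include <vector>

int main() {
  const int64_t num_layers = 4, num_blocks = 8, block_tokens = 16, dim = 128;

  // Layer-wise (old): one tensor per layer; a block is a slice of it.
  std::vector<torch::Tensor> layer_wise;
  for (int64_t l = 0; l < num_layers; ++l)
    layer_wise.push_back(torch::empty({num_blocks, block_tokens, dim}));

  // Block-wise (new): one tensor per block; a layer is a slice of it.
  std::vector<torch::Tensor> block_wise;
  for (int64_t b = 0; b < num_blocks; ++b)
    block_wise.push_back(torch::empty({num_layers, block_tokens, dim}));

  // The same (layer 2, block 5) cell under both layouts:
  void* old_ptr = layer_wise[2][5].data_ptr();
  void* new_ptr = block_wise[5][2].data_ptr();

  // Block-wise, everything belonging to one block is contiguous, so the
  // whole block can go to the KV store as a single region:
  void* whole_block = block_wise[5].data_ptr();
  (void)old_ptr; (void)new_ptr; (void)whole_block;
}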

File tree

5 files changed: +140 −129 lines changed

xllm/core/framework/kv_cache/kv_cache_store.cpp

Lines changed: 31 additions & 53 deletions
@@ -43,29 +43,24 @@ bool KVCacheStore::init(const StoreConfig& config,
   }
   client_ptr_ = client_opt.value();

-  auto key_tensor_one_layer = host_kv_caches_->at(0).get_k_cache();
-  auto value_tensor_one_layer = host_kv_caches_->at(0).get_v_cache();
+  auto k_tensor_one_block = host_kv_caches_->at(0).get_k_cache();
+  auto v_tensor_one_block = host_kv_caches_->at(0).get_v_cache();

-  key_cache_size_per_layer_ =
-      key_tensor_one_layer[0].numel() * key_tensor_one_layer[0].element_size();
-  value_cache_size_per_layer_ = value_tensor_one_layer[0].numel() *
-                                value_tensor_one_layer[0].element_size();
+  k_cache_size_per_block_ =
+      k_tensor_one_block.numel() * k_tensor_one_block.element_size();
+  v_cache_size_per_block_ =
+      v_tensor_one_block.numel() * v_tensor_one_block.element_size();

-  auto key_cache_host_size =
-      key_tensor_one_layer.numel() * key_tensor_one_layer.element_size();
-  auto value_cache_host_size =
-      value_tensor_one_layer.numel() * value_tensor_one_layer.element_size();
-
-  LOG(INFO) << "key_cache_size_per_layer: " << key_cache_size_per_layer_;
-  LOG(INFO) << "value_cache_size_per_layer: " << value_cache_size_per_layer_;
+  LOG(INFO) << "k_cache_size_per_block: " << k_cache_size_per_block_;
+  LOG(INFO) << "v_cache_size_per_block: " << v_cache_size_per_block_;

   if (config_.protocol == "rdma") {
-    for (int layer = 0; layer < host_kv_caches_->size(); layer++) {
+    for (int block = 0; block < host_kv_caches_->size(); block++) {
       void* key_cache = static_cast<char*>(
-          host_kv_caches_->at(layer).get_k_cache().data_ptr());
+          host_kv_caches_->at(block).get_k_cache().data_ptr());

       auto register_k_result = client_ptr_->RegisterLocalMemory(
-          key_cache, key_cache_host_size, "cpu:0", false, false);
+          key_cache, k_cache_size_per_block_, "cpu:0", false, false);

       if (!register_k_result.has_value()) {
         LOG(ERROR) << "Failed to register local memory for key cache: "
@@ -74,10 +69,10 @@ bool KVCacheStore::init(const StoreConfig& config,
       }

       void* value_cache = static_cast<char*>(
-          host_kv_caches_->at(layer).get_v_cache().data_ptr());
+          host_kv_caches_->at(block).get_v_cache().data_ptr());

       auto register_v_result = client_ptr_->RegisterLocalMemory(
-          value_cache, value_cache_host_size, "cpu:0", false, false);
+          value_cache, v_cache_size_per_block_, "cpu:0", false, false);

       if (!register_v_result.has_value()) {
         LOG(ERROR) << "Failed to register local memory for value cache: "
@@ -119,23 +114,14 @@ uint32_t KVCacheStore::batch_put(
     str_keys.emplace_back(str_key);

-    std::vector<mooncake::Slice> slice;
-    slice.reserve(host_kv_caches_->size() * 2);
-    for (int layer = 0; layer < host_kv_caches_->size(); layer++) {
-      void* key_cache =
-          static_cast<char*>(
-              host_kv_caches_->at(layer).get_k_cache().data_ptr()) +
-          block_info.dst_block_id * key_cache_size_per_layer_;
-      slice.emplace_back(mooncake::Slice{key_cache, key_cache_size_per_layer_});
-
-      void* value_cache =
-          static_cast<char*>(
-              host_kv_caches_->at(layer).get_v_cache().data_ptr()) +
-          block_info.dst_block_id * value_cache_size_per_layer_;
-      slice.emplace_back(
-          mooncake::Slice{value_cache, value_cache_size_per_layer_});
-    }
-    slices.emplace_back(std::move(slice));
+    void* k_cache =
+        host_kv_caches_->at(block_info.dst_block_id).get_k_cache().data_ptr();
+    void* v_cache =
+        host_kv_caches_->at(block_info.dst_block_id).get_v_cache().data_ptr();
+
+    slices.emplace_back(std::vector<mooncake::Slice>{
+        mooncake::Slice{k_cache, k_cache_size_per_block_},
+        mooncake::Slice{v_cache, v_cache_size_per_block_}});
   }

   if (str_keys.size() == 0) {
@@ -177,24 +163,16 @@ uint32_t KVCacheStore::batch_get(
     str_keys.emplace_back(str_key);

-    slices.insert(std::make_pair(str_key, std::vector<mooncake::Slice>()));
-
-    slices[str_key].reserve(host_kv_caches_->size() * 2);
-    for (int layer = 0; layer < host_kv_caches_->size(); layer++) {
-      void* key_cache =
-          static_cast<char*>(
-              host_kv_caches_->at(layer).get_k_cache().data_ptr()) +
-          block_info.dst_block_id * key_cache_size_per_layer_;
-      slices[str_key].emplace_back(
-          mooncake::Slice{key_cache, key_cache_size_per_layer_});
-
-      void* value_cache =
-          static_cast<char*>(
-              host_kv_caches_->at(layer).get_v_cache().data_ptr()) +
-          block_info.dst_block_id * value_cache_size_per_layer_;
-      slices[str_key].emplace_back(
-          mooncake::Slice{value_cache, value_cache_size_per_layer_});
-    }
+    void* k_cache =
+        host_kv_caches_->at(block_info.dst_block_id).get_k_cache().data_ptr();
+    void* v_cache =
+        host_kv_caches_->at(block_info.dst_block_id).get_v_cache().data_ptr();
+
+    slices.insert(
+        std::make_pair(str_key,
+                       std::vector<mooncake::Slice>{
+                           mooncake::Slice{k_cache, k_cache_size_per_block_},
+                           mooncake::Slice{v_cache, v_cache_size_per_block_}}));
   }

   if (str_keys.size() == 0) {
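
Design note: with one tensor pair per block, a block's K cache and V cache are each a single contiguous buffer, so batch_put and batch_get hand the store exactly two mooncake::Slice entries per key instead of 2 * num_layers pointer/size pairs, and init() registers one RDMA region per block tensor rather than one spanning region per layer.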

xllm/core/framework/kv_cache/kv_cache_store.h

Lines changed: 2 additions & 2 deletions
@@ -69,8 +69,8 @@ class KVCacheStore {

   std::vector<xllm::KVCache>* host_kv_caches_;

-  uint64_t key_cache_size_per_layer_;
-  uint64_t value_cache_size_per_layer_;
+  uint64_t k_cache_size_per_block_;
+  uint64_t v_cache_size_per_block_;

   std::shared_ptr<mooncake::Client> client_ptr_;
 };

xllm/core/framework/request/sequence.h

Lines changed: 2 additions & 2 deletions
@@ -239,10 +239,10 @@ class Sequence final {

   void sync_result() {
     if (futures_.has_value()) {
-      auto success_cnt = host_kv_state_.num_kv_blocks();
+      uint32_t success_cnt = host_kv_state_.num_kv_blocks();
       for (auto& future : futures_.value()) {
         if (future.isReady()) {
-          success_cnt = std::min(success_cnt, size_t(future.value()));
+          success_cnt = std::min(success_cnt, future.value());
         } else {
           return;
         }
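
The sequence.h change is a type cleanup around std::min: the futures carry uint32_t counts, and std::min deduces a single template argument, so mixing a size_t counter with a uint32_t value forces a cast at every call site. Declaring success_cnt as uint32_t up front drops the cast. A standalone illustration (toy values, not from this commit):

#include <algorithm>
#include <cstddef>
#include <cstdint>

int main() {
  std::size_t total = 8;    // e.g. a container size
  std::uint32_t done = 5;   // e.g. a future's value
  // std::min(total, done);                      // ill-formed: two types
  auto a = std::min(total, std::size_t(done));   // old approach: cast per call
  std::uint32_t total32 = total;                 // new approach: one type
  auto b = std::min(total32, done);
  return int(a + b);
}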

xllm/core/runtime/worker_impl.cpp

Lines changed: 101 additions & 72 deletions
@@ -131,17 +131,19 @@ bool WorkerImpl::allocate_host_kv_cache(

   CHECK(model_ != nullptr) << "Model is not initialized.";
   CHECK(host_kv_caches_.empty()) << "KV caches are already initialized.";
+  CHECK(device_kv_cache_shape[0][0] == device_kv_cache_shape[1][0]);

   std::vector<std::vector<int64_t>> host_kv_cache_shape = device_kv_cache_shape;
-  host_kv_cache_shape[0][0] =
+  const int64_t num_layers = context_.get_model_args().n_layers();
+  int64_t host_block_size =
       device_kv_cache_shape[0][0] * options_.host_blocks_factor();
-  host_kv_cache_shape[1][0] =
-      device_kv_cache_shape[1][0] * options_.host_blocks_factor();
+  host_kv_cache_shape[0][0] = num_layers;
+  host_kv_cache_shape[1][0] = num_layers;

-  // create a KVCache for each layer
-  const int64_t num_layers = context_.get_model_args().n_layers();
-  host_kv_caches_.reserve(num_layers);
-  for (int64_t i = 0; i < num_layers; ++i) {
+  // create one KVCache per host block, each of shape [layers, token, head, dim]
+  host_kv_caches_.reserve(host_block_size);
+
+  for (int64_t i = 0; i < host_block_size; ++i) {
     torch::Tensor key_cache, value_cache;
     key_cache = torch::empty(host_kv_cache_shape[0],
                              torch::dtype(dtype_).device(torch::kCPU))
@@ -151,8 +153,7 @@ bool WorkerImpl::allocate_host_kv_cache(
                              .pin_memory();
     host_kv_caches_.emplace_back(key_cache, value_cache);
   }
-  LOG(INFO) << "Initializing host k cache size: " << host_kv_cache_shape[0][0];
-  LOG(INFO) << "Initializing host v cache size: " << host_kv_cache_shape[1][0];
+  LOG(INFO) << "Initializing host kv block num: " << host_block_size;

   int32_t device_id = device_.index();
   h2d_attrs_.dstLoc.id = device_id;
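
In allocate_host_kv_cache, dimension 0 of the device shape is the device block count, so the new CHECK asserts that the K and V shapes agree on it. Note the ordering subtlety: host_block_size is computed from device_kv_cache_shape before host_kv_cache_shape[0][0] and [1][0] are overwritten with num_layers, since each host entry now holds all layers of one block. A toy arithmetic sketch (all values assumed):

#include <cstdint>
#include <iostream>

int main() {
  const int64_t device_blocks = 1024;   // device_kv_cache_shape[0][0]
  const double host_blocks_factor = 2;  // options_.host_blocks_factor()
  const int64_t num_layers = 32;        // model_args.n_layers()

  // One host KVCache per block...
  const int64_t host_block_num = int64_t(device_blocks * host_blocks_factor);
  // ...whose K/V tensors have the layer count as dimension 0.
  std::cout << host_block_num << " host blocks, tensor dim0 = " << num_layers
            << std::endl;
}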
@@ -687,22 +688,8 @@ uint32_t WorkerImpl::transfer_kv_blocks(

   switch (block_transfer_info[0].transfer_type) {
     case TransferType::G2H: {
-      folly::Promise<uint32_t> promise;
-      auto future = promise.getSemiFuture();
-
-      batchget_threadpool_.schedule(
-          [this, &block_transfer_info, promise = std::move(promise)]() mutable {
-            promise.setValue(
-                KVCacheStore::get_instance().batch_get(block_transfer_info));
-          });
-
-      try {
-        auto timeout = std::chrono::seconds(KVSTORE_TIMEOUT);
-        return std::move(future).wait(timeout);
-      } catch (const folly::FutureTimeout& e) {
-        LOG(WARNING) << "BatchGet operation timed out";
-        return 0;
-      }
+      Slice<BlockTransferInfo> info_slice{block_transfer_info};
+      return load_from_store(info_slice);
     }
     case TransferType::D2G:
       return offload_kv_blocks(block_transfer_info);
@@ -792,23 +779,7 @@ uint32_t WorkerImpl::offload_kv_blocks(
                          promise = std::move(promise),
                          slice = std::move(slice)]() mutable {
        bool ret = d2h_batch_copy(slice);
-       uint32_t success_cnt = 0;
-
-       folly::Promise<uint32_t> store_promise;
-       auto future = store_promise.getSemiFuture();
-
-       batchput_threadpool_.schedule(
-           [this, &slice, promise = std::move(store_promise)]() mutable {
-             promise.setValue(KVCacheStore::get_instance().batch_put(slice));
-           });
-
-       try {
-         auto timeout = std::chrono::seconds(KVSTORE_TIMEOUT);
-         success_cnt = std::move(future).wait(timeout);
-       } catch (const folly::FutureTimeout& e) {
-         LOG(WARNING) << "BatchPut operation timed out";
-       }
-
+       auto success_cnt = offload_to_store(slice);
        if (success_cnt != slice.size()) {
          LOG(WARNING) << "KVCacheStore not all put success: " << success_cnt
                       << "/" << slice.size();
@@ -894,6 +865,7 @@ bool WorkerImpl::d2h_batch_copy(Slice<BlockTransferInfo>& block_transfer_info) {
 #if defined(USE_NPU)
   CHECK(copy_stream_.count(std::this_thread::get_id()) != 0)
       << "WorkerImpl::d2h_batch_copy can only be called in copy_threadpool_.";
+
   const int64_t num_layers = context_.get_model_args().n_layers();
   uint32_t num_batches = block_transfer_info.size() * num_layers * 2;
   void** srcs = new void*[num_batches];
@@ -903,26 +875,25 @@ bool WorkerImpl::d2h_batch_copy(Slice<BlockTransferInfo>& block_transfer_info) {
   size_t attrs_indexes[1] = {0};
   size_t fail_index;
   uint32_t curr_index = 0;
-  for (int layer_id = 0; layer_id < num_layers; layer_id++) {
-    auto src_k_cache = kv_caches_.at(layer_id).get_k_cache();
-    auto dst_k_cache = host_kv_caches_.at(layer_id).get_k_cache();
-    auto src_v_cache = kv_caches_.at(layer_id).get_v_cache();
-    auto dst_v_cache = host_kv_caches_.at(layer_id).get_v_cache();
-
-    for (int idx = 0; idx < block_transfer_info.size(); idx++) {
-      srcs[curr_index] =
-          src_k_cache[block_transfer_info[idx].src_block_id].data_ptr();
-      dsts[curr_index] =
-          dst_k_cache[block_transfer_info[idx].dst_block_id].data_ptr();

+  for (const auto& info : block_transfer_info) {
+    auto dst_k_cache = host_kv_caches_.at(info.dst_block_id).get_k_cache();
+    auto dst_v_cache = host_kv_caches_.at(info.dst_block_id).get_v_cache();
+
+    for (int layer_id = 0; layer_id < num_layers; layer_id++) {
+      auto src_k_cache = kv_caches_.at(layer_id).get_k_cache();
+      auto src_v_cache = kv_caches_.at(layer_id).get_v_cache();
+
+      srcs[curr_index] = src_k_cache[info.src_block_id].data_ptr();
+      dsts[curr_index] = dst_k_cache[layer_id].data_ptr();
       copy_size[curr_index] = key_cache_size_per_layer_;
+
       curr_index++;

-      srcs[curr_index] =
-          src_v_cache[block_transfer_info[idx].src_block_id].data_ptr();
-      dsts[curr_index] =
-          dst_v_cache[block_transfer_info[idx].dst_block_id].data_ptr();
+      srcs[curr_index] = src_v_cache[info.src_block_id].data_ptr();
+      dsts[curr_index] = dst_v_cache[layer_id].data_ptr();
       copy_size[curr_index] = value_cache_size_per_layer_;
+
       curr_index++;
     }
   }
@@ -960,6 +931,7 @@ bool WorkerImpl::h2d_batch_copy(Slice<BlockTransferInfo>& block_transfer_info) {
 #if defined(USE_NPU)
   CHECK(copy_stream_.count(std::this_thread::get_id()) != 0)
       << "WorkerImpl::h2d_batch_copy can only be called in copy_threadpool_.";
+
   const int64_t num_layers = context_.get_model_args().n_layers();
   uint32_t num_batches = block_transfer_info.size() * num_layers * 2;
   void** srcs = new void*[num_batches];
@@ -970,24 +942,21 @@ bool WorkerImpl::h2d_batch_copy(Slice<BlockTransferInfo>& block_transfer_info) {
   size_t fail_index;
   uint32_t curr_index = 0;

-  for (int layer_id = 0; layer_id < num_layers; layer_id++) {
-    auto src_k_cache = host_kv_caches_.at(layer_id).get_k_cache();
-    auto dst_k_cache = kv_caches_.at(layer_id).get_k_cache();
-    auto src_v_cache = host_kv_caches_.at(layer_id).get_v_cache();
-    auto dst_v_cache = kv_caches_.at(layer_id).get_v_cache();
-
-    for (int idx = 0; idx < block_transfer_info.size(); idx++) {
-      srcs[curr_index] =
-          src_k_cache[block_transfer_info[idx].src_block_id].data_ptr();
-      dsts[curr_index] =
-          dst_k_cache[block_transfer_info[idx].dst_block_id].data_ptr();
+  for (const auto& info : block_transfer_info) {
+    auto src_k_cache = host_kv_caches_.at(info.src_block_id).get_k_cache();
+    auto src_v_cache = host_kv_caches_.at(info.src_block_id).get_v_cache();
+
+    for (int layer_id = 0; layer_id < num_layers; layer_id++) {
+      auto dst_k_cache = kv_caches_.at(layer_id).get_k_cache();
+      auto dst_v_cache = kv_caches_.at(layer_id).get_v_cache();
+
+      srcs[curr_index] = src_k_cache[layer_id].data_ptr();
+      dsts[curr_index] = dst_k_cache[info.dst_block_id].data_ptr();
       copy_size[curr_index] = key_cache_size_per_layer_;
       curr_index++;

-      srcs[curr_index] =
-          src_v_cache[block_transfer_info[idx].src_block_id].data_ptr();
-      dsts[curr_index] =
-          dst_v_cache[block_transfer_info[idx].dst_block_id].data_ptr();
+      srcs[curr_index] = src_v_cache[layer_id].data_ptr();
+      dsts[curr_index] = dst_v_cache[info.dst_block_id].data_ptr();
       copy_size[curr_index] = value_cache_size_per_layer_;
       curr_index++;
     }
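
Design note: d2h_batch_copy and h2d_batch_copy still build block_transfer_info.size() * num_layers * 2 copy descriptors; only the loop nesting is inverted. Blocks become the outer loop because the host cache is now block-major (host_kv_caches_[block] indexed by layer) while the device cache stays layer-major (kv_caches_[layer] indexed by block), so each descriptor pairs kv_caches_[layer][block_id] with host_kv_caches_[block_id][layer].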
@@ -1021,4 +990,64 @@ bool WorkerImpl::h2d_batch_copy(Slice<BlockTransferInfo>& block_transfer_info) {
   return false;
 }

+uint32_t WorkerImpl::offload_to_store(
+    Slice<BlockTransferInfo>& block_transfer_info) {
+  if (!options_.enable_kvcache_store()) {
+    return block_transfer_info.size();
+  }
+
+  folly::Promise<uint32_t> promise;
+  auto future = promise.getSemiFuture();
+
+  batchput_threadpool_.schedule(
+      [this, &block_transfer_info, promise = std::move(promise)]() mutable {
+        promise.setValue(
+            KVCacheStore::get_instance().batch_put(block_transfer_info));
+      });
+
+  auto timeout = std::chrono::seconds(KVSTORE_TIMEOUT);
+  return std::move(future)
+      .via(folly::getGlobalCPUExecutor())
+      .within(timeout)
+      .thenTry([](folly::Try<uint32_t>&& t) -> uint32_t {
+        if (t.hasValue()) {
+          return t.value();
+        } else {
+          LOG(WARNING) << "BatchPut operation timed out";
+          return 0u;
+        }
+      })
+      .get();
+}
+
+uint32_t WorkerImpl::load_from_store(
+    Slice<BlockTransferInfo>& block_transfer_info) {
+  if (!options_.enable_kvcache_store()) {
+    return 0;
+  }
+
+  folly::Promise<uint32_t> promise;
+  auto future = promise.getSemiFuture();
+
+  batchget_threadpool_.schedule(
+      [this, &block_transfer_info, promise = std::move(promise)]() mutable {
+        promise.setValue(
+            KVCacheStore::get_instance().batch_get(block_transfer_info));
+      });
+
+  auto timeout = std::chrono::seconds(KVSTORE_TIMEOUT);
+  return std::move(future)
+      .via(folly::getGlobalCPUExecutor())
+      .within(timeout)
+      .thenTry([](folly::Try<uint32_t>&& t) -> uint32_t {
+        if (t.hasValue()) {
+          return t.value();
+        } else {
+          LOG(WARNING) << "BatchGet operation timed out";
+          return 0u;
+        }
+      })
+      .get();
+}
+
 } // namespace xllm
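
The two new helpers centralize the store-I/O timeout handling that was previously inlined at both call sites, and they turn a timeout into an empty folly::Try instead of relying on an exception around a blocking wait: on timeout they log and return 0, so callers can compare the returned count against the number of requested blocks. A self-contained sketch of the same folly pattern (the local thread pool and the 5-second timeout are assumptions for illustration):

#include <folly/executors/CPUThreadPoolExecutor.h>
#include <folly/executors/GlobalExecutor.h>
#include <folly/futures/Future.h>
#include <folly/init/Init.h>
#include <chrono>
#include <iostream>

constexpr int kStoreTimeoutSec = 5;  // stand-in for KVSTORE_TIMEOUT

int main(int argc, char** argv) {
  folly::Init init(&argc, &argv);
  folly::CPUThreadPoolExecutor pool(1);  // stand-in for batchput_threadpool_

  folly::Promise<uint32_t> promise;
  auto future = promise.getSemiFuture();

  // The worker fulfills the promise when the batch operation finishes.
  pool.add([promise = std::move(promise)]() mutable { promise.setValue(42); });

  uint32_t result =
      std::move(future)
          .via(folly::getGlobalCPUExecutor())
          .within(std::chrono::seconds(kStoreTimeoutSec))
          .thenTry([](folly::Try<uint32_t>&& t) -> uint32_t {
            if (t.hasValue()) return t.value();
            std::cerr << "batch operation timed out\n";  // LOG(WARNING) in xllm
            return 0u;
          })
          .get();
  std::cout << "completed: " << result << "\n";
}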
