Commit 844f089

feat: add page-aligned tensor creator for host KV cache.
1 parent d4446aa commit 844f089

File tree

xllm/core/framework/kv_cache/kv_cache_store.cpp
xllm/core/framework/kv_cache/kv_cache_store.h
xllm/core/runtime/worker_impl.cpp
xllm/core/runtime/worker_impl.h

4 files changed: +106 −33 lines

xllm/core/framework/kv_cache/kv_cache_store.cpp

Lines changed: 10 additions & 22 deletions
@@ -55,30 +55,18 @@ bool KVCacheStore::init(const StoreConfig& config,
   LOG(INFO) << "v_cache_size_per_block: " << v_cache_size_per_block_;
 
   if (config_.protocol == "rdma") {
-    for (int block = 0; block < host_kv_caches_->size(); block++) {
-      void* key_cache = static_cast<char*>(
-          host_kv_caches_->at(block).get_k_cache().data_ptr());
-
-      auto register_k_result = client_ptr_->RegisterLocalMemory(
-          key_cache, k_cache_size_per_block_, "cpu:0", false, false);
-
-      if (!register_k_result.has_value()) {
-        LOG(ERROR) << "Failed to register local memory for key cache: "
-                   << toString(register_k_result.error());
-        return false;
-      }
-
-      void* value_cache = static_cast<char*>(
-          host_kv_caches_->at(block).get_v_cache().data_ptr());
-
-      auto register_v_result = client_ptr_->RegisterLocalMemory(
-          value_cache, v_cache_size_per_block_, "cpu:0", false, false);
-
-      if (!register_v_result.has_value()) {
-        LOG(ERROR) << "Failed to register local memory for value cache: "
-                   << toString(register_v_result.error());
+    if (config_.total_size > 0 && config_.tensor_data != nullptr) {
+      auto result = client_ptr_->RegisterLocalMemory(
+          config_.tensor_data, config_.total_size, "cpu:0", false, false);
+      if (!result.has_value()) {
+        LOG(ERROR) << "Failed to register local memory: "
+                   << toString(result.error());
         return false;
       }
+    } else {
+      LOG(FATAL) << "rdma must RegisterLocalMemory, but got register size: "
+                 << config_.total_size
+                 << ", and data ptr: " << uint64_t(config_.tensor_data);
     }
   }
   is_initialized_ = true;
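
Note on the change above: with N host blocks, the old loop issued one RegisterLocalMemory call per key cache and one per value cache, i.e. 2·N registrations, while the page-aligned pool needs exactly one call covering config_.total_size bytes. A back-of-the-envelope sketch of that bookkeeping (all constants hypothetical, not from this commit):

#include <cstddef>
#include <iostream>

int main() {
  // Hypothetical sizes for illustration only.
  const size_t num_blocks = 4096;     // host KV cache blocks
  const size_t k_bytes = 256 * 1024;  // k_cache_size_per_block_
  const size_t v_bytes = 256 * 1024;  // v_cache_size_per_block_

  // Old path: one registration per key cache and per value cache.
  const size_t calls_before = 2 * num_blocks;

  // New path: a single registration over the whole contiguous pool.
  const size_t calls_after = 1;
  const size_t pool_bytes = num_blocks * (k_bytes + v_bytes);

  std::cout << "registrations: " << calls_before << " -> " << calls_after
            << ", pool: " << pool_bytes << " bytes\n";
  return 0;
}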

xllm/core/framework/kv_cache/kv_cache_store.h

Lines changed: 2 additions & 0 deletions
@@ -19,6 +19,8 @@ struct StoreConfig {
   std::string master_server_address = "";
   int replica_num = 1;
   uint32_t tp_rank = 0;
+  size_t total_size = 0;
+  void* tensor_data = nullptr;
 };
 
 class KVCacheStore {

xllm/core/runtime/worker_impl.cpp

Lines changed: 68 additions & 11 deletions
@@ -151,18 +151,9 @@ bool WorkerImpl::allocate_host_kv_cache(
   host_kv_cache_shape[1][0] = num_layers;
 
   // create a KVCache shape: block_size * [layers, token, head, dim]
-  host_kv_caches_.reserve(host_bolck_size);
+  aligned_tensor_creater_ = std::make_unique<AlignedTensorCreater>(
+      host_kv_cache_shape, dtype_, host_bolck_size, &host_kv_caches_);
 
-  for (int64_t i = 0; i < host_bolck_size; ++i) {
-    torch::Tensor key_cache, value_cache;
-    key_cache = torch::empty(host_kv_cache_shape[0],
-                             torch::dtype(dtype_).device(torch::kCPU))
-                    .pin_memory();
-    value_cache = torch::empty(host_kv_cache_shape[1],
-                               torch::dtype(dtype_).device(torch::kCPU))
-                      .pin_memory();
-    host_kv_caches_.emplace_back(key_cache, value_cache);
-  }
   LOG(INFO) << "Initializing host kv block size: " << host_bolck_size;
 
   int32_t device_id = device_.index();
@@ -187,6 +178,8 @@ bool WorkerImpl::allocate_host_kv_cache(
   config.tp_rank = options_.dp_size() > 1
                        ? options_.node_rank() % options_.dp_size()
                        : options_.node_rank();
+  config.total_size = aligned_tensor_creater_->get_total_size();
+  config.tensor_data = aligned_tensor_creater_->get_base_ptr();
 
   if (!KVCacheStore::get_instance().init(config, &host_kv_caches_)) {
     LOG(ERROR) << "Init KVCacheStore fail!";
@@ -1025,4 +1018,68 @@ uint32_t WorkerImpl::prefetch_from_storage(
       .get();
 }
 
+AlignedTensorCreater::AlignedTensorCreater(
+    const std::vector<std::vector<int64_t>>& tensor_shapes,
+    const torch::ScalarType dtype,
+    const uint32_t num_tensors,
+    std::vector<xllm::KVCache>* tensors) {
+  CHECK(tensor_shapes.size() == 2)
+      << "tensor_shapes.size() must equal to 2, but got "
+      << tensor_shapes.size();
+
+  int64_t elements_per_k_tensor = 1;
+  int64_t elements_per_v_tensor = 1;
+
+  for (auto dim : tensor_shapes[0]) {
+    elements_per_k_tensor *= dim;
+  }
+  for (auto dim : tensor_shapes[1]) {
+    elements_per_v_tensor *= dim;
+  }
+
+  size_t element_size = torch::elementSize(dtype);
+  size_t bytes_per_k_tensor = elements_per_k_tensor * element_size;
+  size_t bytes_per_v_tensor = elements_per_v_tensor * element_size;
+  size_t page_size = sysconf(_SC_PAGESIZE);
+  total_size_ = num_tensors * (bytes_per_k_tensor + bytes_per_v_tensor);
+  total_size_ = ((total_size_ + page_size - 1) / page_size) * page_size;
+
+  base_ptr_ = mmap(nullptr,
+                   total_size_,
+                   PROT_READ | PROT_WRITE,
+                   MAP_PRIVATE | MAP_ANONYMOUS,
+                   -1,
+                   0);
+
+  if (base_ptr_ == MAP_FAILED) {
+    LOG(FATAL) << "Failed to allocate aligned memory pool!";
+  }
+
+  if (mlock(base_ptr_, total_size_) != 0) {
+    munmap(base_ptr_, total_size_);
+    LOG(FATAL) << "Failed to lock memory pool!";
+  }
+
+  size_t current_offset = 0;
+  auto options = torch::TensorOptions().dtype(dtype).device(torch::kCPU);
+  tensors->reserve(num_tensors);
+
+  for (size_t i = 0; i < num_tensors; ++i) {
+    void* k_tensor_ptr = static_cast<char*>(base_ptr_) + current_offset;
+    torch::Tensor k_tensor =
+        torch::from_blob(k_tensor_ptr, tensor_shapes[0], options);
+    current_offset += bytes_per_k_tensor;
+
+    void* v_tensor_ptr = static_cast<char*>(base_ptr_) + current_offset;
+    torch::Tensor v_tensor =
+        torch::from_blob(v_tensor_ptr, tensor_shapes[1], options);
+    current_offset += bytes_per_v_tensor;
+
+    tensors->emplace_back(k_tensor, v_tensor);
+  }
+
+  LOG(INFO) << "Page aligned: "
+            << ((uintptr_t)base_ptr_ % page_size == 0 ? "YES" : "NO");
+}
+
 } // namespace xllm
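
For readers unfamiliar with the pattern, here is a minimal standalone sketch of the technique the constructor uses: reserve one anonymous mmap region rounded up to a whole number of pages, mlock it, then carve non-owning tensors out of it with torch::from_blob. Sizes and shapes are illustrative, not from this commit:

#include <sys/mman.h>
#include <unistd.h>

#include <torch/torch.h>

#include <cstdint>
#include <iostream>

int main() {
  // Two float32 tensors of 4 elements each, packed back to back in one pool.
  const size_t bytes_per_tensor = 4 * sizeof(float);
  const size_t page_size = sysconf(_SC_PAGESIZE);

  // Round the pool up to whole pages, as the constructor above does.
  size_t total = 2 * bytes_per_tensor;
  total = ((total + page_size - 1) / page_size) * page_size;

  // Anonymous private mapping: page-aligned by construction.
  void* base = mmap(nullptr, total, PROT_READ | PROT_WRITE,
                    MAP_PRIVATE | MAP_ANONYMOUS, -1, 0);
  if (base == MAP_FAILED) return 1;

  // Pin the pages so they cannot be swapped out (a prerequisite for RDMA).
  if (mlock(base, total) != 0) {
    munmap(base, total);
    return 1;
  }

  {
    // Carve non-owning views out of the pool; from_blob copies nothing.
    auto opts = torch::TensorOptions().dtype(torch::kFloat32);
    torch::Tensor k = torch::from_blob(base, {4}, opts);
    torch::Tensor v = torch::from_blob(
        static_cast<char*>(base) + bytes_per_tensor, {4}, opts);
    k.fill_(1.0f);
    v.fill_(2.0f);
    std::cout << "aligned: "
              << ((reinterpret_cast<uintptr_t>(base) % page_size) == 0)
              << ", k[0]=" << k[0].item<float>()
              << ", v[0]=" << v[0].item<float>() << '\n';
  }  // the views are destroyed before the pool they point into

  munlock(base, total);
  munmap(base, total);
  return 0;
}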

xllm/core/runtime/worker_impl.h

Lines changed: 26 additions & 0 deletions
@@ -16,6 +16,7 @@ limitations under the License.
 #pragma once
 
 #include <folly/futures/Future.h>
+#include <sys/mman.h>
 #include <torch/torch.h>
 
 #include <memory>
@@ -45,6 +46,8 @@ limitations under the License.
 
 namespace xllm {
 
+class AlignedTensorCreater;
+
 class WorkerImpl {
  public:
   enum Status : int8_t {
@@ -237,6 +240,7 @@ class WorkerImpl {
   // kv caches
   std::vector<xllm::KVCache> kv_caches_;
   std::vector<xllm::KVCache> host_kv_caches_;
+  std::unique_ptr<AlignedTensorCreater> aligned_tensor_creater_;
 
   // causal LM model
   std::unique_ptr<CausalLM> model_;
@@ -277,4 +281,26 @@ class WorkerImpl {
       layer_wise_load_synchronizer_;
 };
 
+class AlignedTensorCreater {
+ private:
+  void* base_ptr_;
+  size_t total_size_;
+
+ public:
+  AlignedTensorCreater(const std::vector<std::vector<int64_t>>& tensor_shapes,
+                       const torch::ScalarType dtype,
+                       const uint32_t num_tensors,
+                       std::vector<xllm::KVCache>* tensors);
+
+  ~AlignedTensorCreater() {
+    if (base_ptr_ != nullptr) {
+      munlock(base_ptr_, total_size_);
+      munmap(base_ptr_, total_size_);
+    }
+  }
+
+  void* get_base_ptr() const { return base_ptr_; }
+  size_t get_total_size() const { return total_size_; }
+};
+
 } // namespace xllm
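
Putting the pieces together, a hedged usage sketch of the class as declared; the shapes, dtype, and block count are illustrative, and StoreConfig and xllm::KVCache come from kv_cache_store.h above. One caveat worth spelling out: the tensors are torch::from_blob views into the pool, so the AlignedTensorCreater must outlive every KVCache carved from it, because its destructor munlock/munmaps the backing memory:

#include <memory>
#include <vector>

#include <torch/torch.h>
// Repo-internal headers, assumed: worker_impl.h (AlignedTensorCreater),
// kv_cache_store.h (StoreConfig, xllm::KVCache).

void example_allocate_host_pool() {
  // Illustrative shapes: [layers, tokens_per_block, heads, head_dim].
  std::vector<std::vector<int64_t>> shapes = {
      {32, 16, 8, 128},   // key cache shape (example values)
      {32, 16, 8, 128}};  // value cache shape (example values)

  std::vector<xllm::KVCache> host_kv_caches;
  auto creater = std::make_unique<xllm::AlignedTensorCreater>(
      shapes, torch::kBFloat16, /*num_tensors=*/1024, &host_kv_caches);

  // Hand the whole pool to the store for a single RDMA registration.
  xllm::StoreConfig config;
  config.total_size = creater->get_total_size();
  config.tensor_data = creater->get_base_ptr();

  // `creater` must stay alive for as long as host_kv_caches is in use.
}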
