@@ -151,18 +151,9 @@ bool WorkerImpl::allocate_host_kv_cache(
151151 host_kv_cache_shape[1 ][0 ] = num_layers;
152152
153153 // create a KVCache shape: block_size * [layers, token, head, dim]
154- host_kv_caches_.reserve (host_bolck_size);
154+ aligned_tensor_creater_ = std::make_unique<AlignedTensorCreater>(
155+ host_kv_cache_shape, dtype_, host_bolck_size, &host_kv_caches_);
155156
156- for (int64_t i = 0 ; i < host_bolck_size; ++i) {
157- torch::Tensor key_cache, value_cache;
158- key_cache = torch::empty (host_kv_cache_shape[0 ],
159- torch::dtype (dtype_).device (torch::kCPU ))
160- .pin_memory ();
161- value_cache = torch::empty (host_kv_cache_shape[1 ],
162- torch::dtype (dtype_).device (torch::kCPU ))
163- .pin_memory ();
164- host_kv_caches_.emplace_back (key_cache, value_cache);
165- }
166157 LOG (INFO) << " Initializing host kv block size: " << host_bolck_size;
167158
168159 int32_t device_id = device_.index ();
@@ -187,6 +178,8 @@ bool WorkerImpl::allocate_host_kv_cache(
187178 config.tp_rank = options_.dp_size () > 1
188179 ? options_.node_rank () % options_.dp_size ()
189180 : options_.node_rank ();
181+ config.total_size = aligned_tensor_creater_->get_total_size ();
182+ config.tensor_data = aligned_tensor_creater_->get_base_ptr ();
190183
191184 if (!KVCacheStore::get_instance ().init (config, &host_kv_caches_)) {
192185 LOG (ERROR) << " Init KVCacheStore fail!" ;
@@ -1025,4 +1018,68 @@ uint32_t WorkerImpl::prefetch_from_storage(
10251018 .get ();
10261019}
10271020
1021+ AlignedTensorCreater::AlignedTensorCreater (
1022+ const std::vector<std::vector<int64_t >>& tensor_shapes,
1023+ const torch::ScalarType dtype,
1024+ const uint32_t num_tensors,
1025+ std::vector<xllm::KVCache>* tensors) {
1026+ CHECK (tensor_shapes.size () == 2 )
1027+ << " tensor_shapes.size() must equal to 2, but got "
1028+ << tensor_shapes.size ();
1029+
1030+ int64_t elements_per_k_tensor = 1 ;
1031+ int64_t elements_per_v_tensor = 1 ;
1032+
1033+ for (auto dim : tensor_shapes[0 ]) {
1034+ elements_per_k_tensor *= dim;
1035+ }
1036+ for (auto dim : tensor_shapes[1 ]) {
1037+ elements_per_v_tensor *= dim;
1038+ }
1039+
1040+ size_t element_size = torch::elementSize (dtype);
1041+ size_t bytes_per_k_tensor = elements_per_k_tensor * element_size;
1042+ size_t bytes_per_v_tensor = elements_per_v_tensor * element_size;
1043+ size_t page_size = sysconf (_SC_PAGESIZE);
1044+ total_size_ = num_tensors * (bytes_per_k_tensor + bytes_per_v_tensor);
1045+ total_size_ = ((total_size_ + page_size - 1 ) / page_size) * page_size;
1046+
1047+ base_ptr_ = mmap (nullptr ,
1048+ total_size_,
1049+ PROT_READ | PROT_WRITE,
1050+ MAP_PRIVATE | MAP_ANONYMOUS,
1051+ -1 ,
1052+ 0 );
1053+
1054+ if (base_ptr_ == MAP_FAILED) {
1055+ LOG (FATAL) << " Failed to allocate aligned memory pool!" ;
1056+ }
1057+
1058+ if (mlock (base_ptr_, total_size_) != 0 ) {
1059+ munmap (base_ptr_, total_size_);
1060+ LOG (FATAL) << " Failed to lock memory pool!" ;
1061+ }
1062+
1063+ size_t current_offset = 0 ;
1064+ auto options = torch::TensorOptions ().dtype (dtype).device (torch::kCPU );
1065+ tensors->reserve (num_tensors);
1066+
1067+ for (size_t i = 0 ; i < num_tensors; ++i) {
1068+ void * k_tensor_ptr = static_cast <char *>(base_ptr_) + current_offset;
1069+ torch::Tensor k_tensor =
1070+ torch::from_blob (k_tensor_ptr, tensor_shapes[0 ], options);
1071+ current_offset += bytes_per_k_tensor;
1072+
1073+ void * v_tensor_ptr = static_cast <char *>(base_ptr_) + current_offset;
1074+ torch::Tensor v_tensor =
1075+ torch::from_blob (v_tensor_ptr, tensor_shapes[1 ], options);
1076+ current_offset += bytes_per_v_tensor;
1077+
1078+ tensors->emplace_back (k_tensor, v_tensor);
1079+ }
1080+
1081+ LOG (INFO) << " Page aligned: "
1082+ << ((uintptr_t )base_ptr_ % page_size == 0 ? " YES" : " NO" );
1083+ }
1084+
10281085} // namespace xllm
0 commit comments