@@ -131,17 +131,19 @@ bool WorkerImpl::allocate_host_kv_cache(
 
   CHECK(model_ != nullptr) << "Model is not initialized.";
   CHECK(host_kv_caches_.empty()) << "KV caches are already initialized.";
+  CHECK(device_kv_cache_shape[0][0] == device_kv_cache_shape[1][0]);
 
   std::vector<std::vector<int64_t>> host_kv_cache_shape = device_kv_cache_shape;
-  host_kv_cache_shape[0][0] =
+  const int64_t num_layers = context_.get_model_args().n_layers();
+  int64_t host_block_size =
       device_kv_cache_shape[0][0] * options_.host_blocks_factor();
-  host_kv_cache_shape[1][0] =
-      device_kv_cache_shape[1][0] * options_.host_blocks_factor();
+  host_kv_cache_shape[0][0] = num_layers;
+  host_kv_cache_shape[1][0] = num_layers;
 
-  // create a KVCache for each layer
-  const int64_t num_layers = context_.get_model_args().n_layers();
-  host_kv_caches_.reserve(num_layers);
-  for (int64_t i = 0; i < num_layers; ++i) {
+  // create host KV caches: host_block_size * [layers, token, head, dim]
+  host_kv_caches_.reserve(host_block_size);
+
+  for (int64_t i = 0; i < host_block_size; ++i) {
     torch::Tensor key_cache, value_cache;
     key_cache = torch::empty(host_kv_cache_shape[0],
                              torch::dtype(dtype_).device(torch::kCPU))
@@ -151,8 +153,7 @@ bool WorkerImpl::allocate_host_kv_cache(
                     .pin_memory();
     host_kv_caches_.emplace_back(key_cache, value_cache);
   }
-  LOG(INFO) << "Initializing host k cache size: " << host_kv_cache_shape[0][0];
-  LOG(INFO) << "Initializing host v cache size: " << host_kv_cache_shape[1][0];
+  LOG(INFO) << "Initializing host kv block size: " << host_block_size;
 
   int32_t device_id = device_.index();
   h2d_attrs_.dstLoc.id = device_id;
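
Note on the change above (editorial aside, not part of the commit): the host KV cache is no longer organized per layer. allocate_host_kv_cache() now creates one pinned (K, V) tensor pair per host block, each shaped [layers, token, head, dim], and there are device_kv_cache_shape[0][0] * host_blocks_factor() of them. The added CHECK guards the derivation of host_block_size from device_kv_cache_shape[0][0] alone: it only holds if the K and V device caches agree on their block count. The following standalone sketch illustrates the host-block-major indexing; the shapes and variable names are illustrative assumptions, not the actual xllm definitions.

// Standalone sketch (not part of the commit): how the new host-block-major
// layout can be indexed. Shapes and names here are illustrative assumptions.
#include <torch/torch.h>
#include <vector>

int main() {
  const int64_t n_layers = 4, block_tokens = 16, heads = 2, head_dim = 8;
  const int64_t host_block_count = 3;

  // New layout: one host tensor per block, shaped [layers, token, head, dim],
  // instead of one tensor per layer sized for all host blocks.
  std::vector<torch::Tensor> host_k_blocks;
  host_k_blocks.reserve(host_block_count);
  for (int64_t b = 0; b < host_block_count; ++b) {
    host_k_blocks.push_back(
        torch::empty({n_layers, block_tokens, heads, head_dim}));
  }

  // Copying one layer of one device block now targets host_block[layer],
  // which is what the rewritten d2h/h2d loops below index into.
  torch::Tensor device_layer_block =
      torch::rand({block_tokens, heads, head_dim});
  const int64_t dst_block_id = 1, layer_id = 2;
  host_k_blocks[dst_block_id][layer_id].copy_(device_layer_block);
  return 0;
}
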
@@ -687,22 +688,8 @@ uint32_t WorkerImpl::transfer_kv_blocks(
 
   switch (block_transfer_info[0].transfer_type) {
     case TransferType::G2H: {
-      folly::Promise<uint32_t> promise;
-      auto future = promise.getSemiFuture();
-
-      batchget_threadpool_.schedule(
-          [this, &block_transfer_info, promise = std::move(promise)]() mutable {
-            promise.setValue(
-                KVCacheStore::get_instance().batch_get(block_transfer_info));
-          });
-
-      try {
-        auto timeout = std::chrono::seconds(KVSTORE_TIMEOUT);
-        return std::move(future).wait(timeout);
-      } catch (const folly::FutureTimeout& e) {
-        LOG(WARNING) << "BatchGet operation timed out";
-        return 0;
-      }
+      Slice<BlockTransferInfo> info_slice{block_transfer_info};
+      return load_from_store(info_slice);
     }
     case TransferType::D2G:
       return offload_kv_blocks(block_transfer_info);
@@ -792,23 +779,7 @@ uint32_t WorkerImpl::offload_kv_blocks(
                              promise = std::move(promise),
                              slice = std::move(slice)]() mutable {
         bool ret = d2h_batch_copy(slice);
-        uint32_t success_cnt = 0;
-
-        folly::Promise<uint32_t> store_promise;
-        auto future = store_promise.getSemiFuture();
-
-        batchput_threadpool_.schedule(
-            [this, &slice, promise = std::move(store_promise)]() mutable {
-              promise.setValue(KVCacheStore::get_instance().batch_put(slice));
-            });
-
-        try {
-          auto timeout = std::chrono::seconds(KVSTORE_TIMEOUT);
-          success_cnt = std::move(future).wait(timeout);
-        } catch (const folly::FutureTimeout& e) {
-          LOG(WARNING) << "BatchPut operation timed out";
-        }
-
+        auto success_cnt = offload_to_store(slice);
         if (success_cnt != slice.size()) {
           LOG(WARNING) << "KVCacheStore not all put success: " << success_cnt
                        << "/" << slice.size();
@@ -894,6 +865,7 @@ bool WorkerImpl::d2h_batch_copy(Slice<BlockTransferInfo>& block_transfer_info) {
 #if defined(USE_NPU)
   CHECK(copy_stream_.count(std::this_thread::get_id()) != 0)
       << "WorkerImpl::d2h_batch_copy can only be called in copy_threadpool_.";
+
   const int64_t num_layers = context_.get_model_args().n_layers();
   uint32_t num_batches = block_transfer_info.size() * num_layers * 2;
   void** srcs = new void*[num_batches];
@@ -903,26 +875,25 @@ bool WorkerImpl::d2h_batch_copy(Slice<BlockTransferInfo>& block_transfer_info) {
   size_t attrs_indexes[1] = {0};
   size_t fail_index;
   uint32_t curr_index = 0;
-  for (int layer_id = 0; layer_id < num_layers; layer_id++) {
-    auto src_k_cache = kv_caches_.at(layer_id).get_k_cache();
-    auto dst_k_cache = host_kv_caches_.at(layer_id).get_k_cache();
-    auto src_v_cache = kv_caches_.at(layer_id).get_v_cache();
-    auto dst_v_cache = host_kv_caches_.at(layer_id).get_v_cache();
-
-    for (int idx = 0; idx < block_transfer_info.size(); idx++) {
-      srcs[curr_index] =
-          src_k_cache[block_transfer_info[idx].src_block_id].data_ptr();
-      dsts[curr_index] =
-          dst_k_cache[block_transfer_info[idx].dst_block_id].data_ptr();
 
+  for (const auto& info : block_transfer_info) {
+    auto dst_k_cache = host_kv_caches_.at(info.dst_block_id).get_k_cache();
+    auto dst_v_cache = host_kv_caches_.at(info.dst_block_id).get_v_cache();
+
+    for (int layer_id = 0; layer_id < num_layers; layer_id++) {
+      auto src_k_cache = kv_caches_.at(layer_id).get_k_cache();
+      auto src_v_cache = kv_caches_.at(layer_id).get_v_cache();
+
+      srcs[curr_index] = src_k_cache[info.src_block_id].data_ptr();
+      dsts[curr_index] = dst_k_cache[layer_id].data_ptr();
       copy_size[curr_index] = key_cache_size_per_layer_;
+
       curr_index++;
 
-      srcs[curr_index] =
-          src_v_cache[block_transfer_info[idx].src_block_id].data_ptr();
-      dsts[curr_index] =
-          dst_v_cache[block_transfer_info[idx].dst_block_id].data_ptr();
+      srcs[curr_index] = src_v_cache[info.src_block_id].data_ptr();
+      dsts[curr_index] = dst_v_cache[layer_id].data_ptr();
       copy_size[curr_index] = value_cache_size_per_layer_;
+
       curr_index++;
     }
   }
@@ -960,6 +931,7 @@ bool WorkerImpl::h2d_batch_copy(Slice<BlockTransferInfo>& block_transfer_info) {
 #if defined(USE_NPU)
   CHECK(copy_stream_.count(std::this_thread::get_id()) != 0)
       << "WorkerImpl::h2d_batch_copy can only be called in copy_threadpool_.";
+
   const int64_t num_layers = context_.get_model_args().n_layers();
   uint32_t num_batches = block_transfer_info.size() * num_layers * 2;
   void** srcs = new void*[num_batches];
@@ -970,24 +942,21 @@ bool WorkerImpl::h2d_batch_copy(Slice<BlockTransferInfo>& block_transfer_info) {
   size_t fail_index;
   uint32_t curr_index = 0;
 
-  for (int layer_id = 0; layer_id < num_layers; layer_id++) {
-    auto src_k_cache = host_kv_caches_.at(layer_id).get_k_cache();
-    auto dst_k_cache = kv_caches_.at(layer_id).get_k_cache();
-    auto src_v_cache = host_kv_caches_.at(layer_id).get_v_cache();
-    auto dst_v_cache = kv_caches_.at(layer_id).get_v_cache();
-
-    for (int idx = 0; idx < block_transfer_info.size(); idx++) {
-      srcs[curr_index] =
-          src_k_cache[block_transfer_info[idx].src_block_id].data_ptr();
-      dsts[curr_index] =
-          dst_k_cache[block_transfer_info[idx].dst_block_id].data_ptr();
+  for (const auto& info : block_transfer_info) {
+    auto src_k_cache = host_kv_caches_.at(info.src_block_id).get_k_cache();
+    auto src_v_cache = host_kv_caches_.at(info.src_block_id).get_v_cache();
+
+    for (int layer_id = 0; layer_id < num_layers; layer_id++) {
+      auto dst_k_cache = kv_caches_.at(layer_id).get_k_cache();
+      auto dst_v_cache = kv_caches_.at(layer_id).get_v_cache();
+
+      srcs[curr_index] = src_k_cache[layer_id].data_ptr();
+      dsts[curr_index] = dst_k_cache[info.dst_block_id].data_ptr();
       copy_size[curr_index] = key_cache_size_per_layer_;
       curr_index++;
 
-      srcs[curr_index] =
-          src_v_cache[block_transfer_info[idx].src_block_id].data_ptr();
-      dsts[curr_index] =
-          dst_v_cache[block_transfer_info[idx].dst_block_id].data_ptr();
+      srcs[curr_index] = src_v_cache[layer_id].data_ptr();
+      dsts[curr_index] = dst_v_cache[info.dst_block_id].data_ptr();
       copy_size[curr_index] = value_cache_size_per_layer_;
       curr_index++;
     }
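
Aside (not part of the commit): both copy routines now walk blocks in the outer loop and layers in the inner loop, so the srcs/dsts/copy_size arrays interleave K and V entries per (block, layer) pair. A small standalone sketch of that index arithmetic, using assumed counts:

// Standalone sketch of how curr_index advances in the rewritten loops:
// for block i and layer l, the K entry lands at (i * num_layers + l) * 2 and
// the V entry immediately after it, so the batch count is
// block_transfer_info.size() * num_layers * 2.
#include <cassert>
#include <cstdint>

int main() {
  const uint32_t num_blocks = 3;  // stand-in for block_transfer_info.size()
  const uint32_t num_layers = 4;  // stand-in for n_layers()
  uint32_t curr_index = 0;
  for (uint32_t i = 0; i < num_blocks; ++i) {
    for (uint32_t l = 0; l < num_layers; ++l) {
      assert(curr_index == (i * num_layers + l) * 2);  // K slot
      ++curr_index;                                    // V slot follows
      ++curr_index;
    }
  }
  assert(curr_index == num_blocks * num_layers * 2);
  return 0;
}

This ordering keeps each block's layers adjacent in the batch, matching the per-block host tensors allocated above.
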
@@ -1021,4 +990,64 @@ bool WorkerImpl::h2d_batch_copy(Slice<BlockTransferInfo>& block_transfer_info) {
   return false;
 }
 
+uint32_t WorkerImpl::offload_to_store(
+    Slice<BlockTransferInfo>& block_transfer_info) {
+  if (!options_.enable_kvcache_store()) {
+    return block_transfer_info.size();
+  }
+
+  folly::Promise<uint32_t> promise;
+  auto future = promise.getSemiFuture();
+
+  batchput_threadpool_.schedule(
+      [this, &block_transfer_info, promise = std::move(promise)]() mutable {
+        promise.setValue(
+            KVCacheStore::get_instance().batch_put(block_transfer_info));
+      });
+
+  auto timeout = std::chrono::seconds(KVSTORE_TIMEOUT);
+  return std::move(future)
+      .via(folly::getGlobalCPUExecutor())
+      .within(timeout)
+      .thenTry([](folly::Try<uint32_t>&& t) -> uint32_t {
+        if (t.hasValue()) {
+          return t.value();
+        } else {
+          LOG(WARNING) << "BatchPut operation timed out";
+          return 0u;
+        }
+      })
+      .get();
+}
+
+uint32_t WorkerImpl::load_from_store(
+    Slice<BlockTransferInfo>& block_transfer_info) {
+  if (!options_.enable_kvcache_store()) {
+    return 0;
+  }
+
+  folly::Promise<uint32_t> promise;
+  auto future = promise.getSemiFuture();
+
+  batchget_threadpool_.schedule(
+      [this, &block_transfer_info, promise = std::move(promise)]() mutable {
+        promise.setValue(
+            KVCacheStore::get_instance().batch_get(block_transfer_info));
+      });
+
+  auto timeout = std::chrono::seconds(KVSTORE_TIMEOUT);
+  return std::move(future)
+      .via(folly::getGlobalCPUExecutor())
+      .within(timeout)
+      .thenTry([](folly::Try<uint32_t>&& t) -> uint32_t {
+        if (t.hasValue()) {
+          return t.value();
+        } else {
+          LOG(WARNING) << "BatchGet operation timed out";
+          return 0u;
+        }
+      })
+      .get();
+}
+
 }  // namespace xllm
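
Aside (not part of the commit): the two new helpers replace the earlier try/catch around SemiFuture::wait with folly's via()/within()/thenTry() chain, which folds a timeout into a zero success count instead of surfacing an exception at the call site. A stripped-down, self-contained sketch of that pattern, with illustrative names standing in for the xllm members:

// Standalone sketch of the timeout pattern used by offload_to_store /
// load_from_store. Names here are illustrative assumptions, not xllm code.
#include <folly/executors/CPUThreadPoolExecutor.h>
#include <folly/executors/GlobalExecutor.h>
#include <folly/futures/Future.h>

#include <chrono>
#include <cstdint>
#include <iostream>

// Schedules a stand-in for batch_put/batch_get on a worker pool, then bounds
// the wait: within() completes the future exceptionally on expiry, and
// thenTry() turns that outcome into a default value of 0.
uint32_t run_with_timeout(folly::CPUThreadPoolExecutor& pool,
                          std::chrono::seconds timeout) {
  folly::Promise<uint32_t> promise;
  auto future = promise.getSemiFuture();

  pool.add([promise = std::move(promise)]() mutable {
    promise.setValue(42);  // pretend 42 blocks were stored/loaded
  });

  return std::move(future)
      .via(folly::getGlobalCPUExecutor())
      .within(timeout)
      .thenTry([](folly::Try<uint32_t>&& t) -> uint32_t {
        if (t.hasValue()) {
          return t.value();
        }
        std::cerr << "store operation timed out\n";
        return 0u;
      })
      .get();
}

int main() {
  folly::CPUThreadPoolExecutor pool(1);
  std::cout << run_with_timeout(pool, std::chrono::seconds(5)) << std::endl;
  return 0;
}

The practical difference from the removed code is that a timeout no longer propagates as a folly::FutureTimeout exception; it becomes a partial success count, which offload_kv_blocks then logs.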