@@ -26,11 +26,16 @@ def __init__(self, backend):
         self.init_sync_group = create_new_group_for_current_dp("nccl")
         dist.barrier(group=self.init_sync_group)
 
+        self.page_index_buffer = torch.empty((1024 * 1024 * 4,), dtype=torch.int32, device="cuda")
+        self.page_ready_buffer = torch.empty((1024 * 1024 * 4,), dtype=torch.bool, device="cuda")
+
         self.cpu_cache_handle_queue: Deque[TransTask] = deque()
         self.cpu_cache_client = CpuKvCacheClient(only_create_meta_data=False, init_shm_data=False)
 
         # Some operator modes require synchronizing compute with the cpu cache load and offload operations
-        self.need_sync_compute_stream: bool = True
+        self.need_sync_compute_stream: bool = (
+            "fa3" in self.args.llm_decode_att_backend or "fa3" in self.args.llm_prefill_att_backend
+        )
 
     def wait(self):
         """
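The two buffers added in `__init__` form a preallocated device-side staging area: `_start_kv_cache_offload_task` (further down in this diff) slices them to the needed length and `copy_`s pinned host tensors into the slices, instead of allocating a fresh CUDA tensor on every offload. A minimal sketch of that pattern, with hypothetical names (`STAGING_ELEMS`, `stage_indexes`) and assuming a CUDA-capable PyTorch build:

    import torch

    STAGING_ELEMS = 1024 * 1024 * 4  # same capacity the diff reserves

    # Allocate once; reuse via slicing to avoid per-call cudaMalloc/cudaFree churn.
    index_staging = torch.empty((STAGING_ELEMS,), dtype=torch.int32, device="cuda")

    def stage_indexes(values):
        # Pinned host memory is what lets the H2D copy run asynchronously.
        host = torch.tensor(values, dtype=torch.int32, device="cpu", pin_memory=True)
        assert host.numel() <= STAGING_ELEMS
        dev = index_staging[: host.numel()]
        dev.copy_(host, non_blocking=True)  # enqueued on the current CUDA stream
        return dev

Because the staging slice is handed to a kernel launched on the same stream, stream ordering alone guarantees the copy completes before the kernel reads it.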
@@ -89,14 +94,18 @@ def load_cpu_cache_to_reqs(self, reqs: List[InferReq]):
         cpu_kv_cache_scale = None
         gpu_kv_cache_scale = None
 
+        mem_indexes_cuda = mem_indexes.cuda(non_blocking=True)
+        page_indexes_cuda = torch.tensor(need_pages, dtype=torch.int32, device="cpu").cuda(
+            non_blocking=True
+        )
         # Copy the contents of the cpu pages into gpu pages
         load_cpu_kv_to_gpu(
-            gpu_mem_indexes=mem_indexes.cuda(non_blocking=True),
+            gpu_mem_indexes=mem_indexes_cuda,
             gpu_kv_cache=mem_manager.kv_buffer,
             gpu_kv_cache_scale=gpu_kv_cache_scale,
             cpu_kv_cache=cpu_kv_cache,
             cpu_kv_cache_scale=cpu_kv_cache_scale,
-            page_indexes=torch.tensor(need_pages, dtype=torch.int32, device="cpu").cuda(non_blocking=True),
+            page_indexes=page_indexes_cuda,
             tp_index=self.backend.rank_in_dp,
             tp_world_size=self.backend.dp_world_size,
             grid_num=grid_num,
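One caveat worth keeping in mind here (general PyTorch behavior, not something this diff changes): `non_blocking=True` only makes a host-to-device copy asynchronous when the source tensor lives in pinned memory. `mem_indexes_cuda` and `page_indexes_cuda` above are built from pageable host tensors, so the flag is effectively ignored and the copies still synchronize with the host; hoisting them before the kernel call tidies the code but does not by itself overlap them with compute. A small illustration of the difference:

    import torch

    pageable = torch.arange(4096, dtype=torch.int32)             # ordinary pageable host memory
    pinned = torch.arange(4096, dtype=torch.int32).pin_memory()  # page-locked host memory

    d1 = pageable.cuda(non_blocking=True)  # flag has no effect: copy stays synchronous
    d2 = pinned.cuda(non_blocking=True)    # truly asynchronous on the current stream
    torch.cuda.current_stream().synchronize()  # wait for both before host-side use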
@@ -221,6 +230,12 @@ def _start_kv_cache_offload_task(
 
         page_indexes = torch.tensor(page_list, dtype=torch.int32, device="cpu", pin_memory=True)
         page_readies = torch.tensor(ready_list, dtype=torch.bool, device="cpu", pin_memory=True)
+        assert len(page_indexes) <= self.page_index_buffer.shape[0]
+        cuda_page_indexes = self.page_index_buffer[: len(page_indexes)]
+        cuda_page_readies = self.page_ready_buffer[: len(page_readies)]
+        cuda_page_indexes.copy_(page_indexes, non_blocking=True)
+        cuda_page_readies.copy_(page_readies, non_blocking=True)
+
         move_token_num = item_size * self.args.cpu_cache_token_page_size
         assert req.cur_kv_len >= item_size * self.args.cpu_cache_token_page_size
         token_indexes = self.backend.model.req_manager.req_to_token_indexs[req.req_idx, 0:move_token_num]
@@ -248,8 +263,8 @@ def _start_kv_cache_offload_task(
             gpu_kv_cache_scale=gpu_kv_cache_scale,
             cpu_kv_cache=cpu_kv_cache,
             cpu_kv_cache_scale=cpu_kv_cache_scale,
-            page_indexes=page_indexes,
-            page_readies=page_readies,
+            page_indexes=cuda_page_indexes,
+            page_readies=cuda_page_readies,
             tp_index=self.backend.rank_in_dp,
             tp_world_size=self.backend.dp_world_size,
             grid_num=grid_num,
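Since `page_indexes` / `page_readies` are pinned host tensors filled into the device slices with `non_blocking=True`, the usual CUDA lifetime rule applies: the consuming kernel on the same stream is automatically ordered after the copies, but the host must not recycle the pinned tensors until the copies have completed. A generic sketch of that fencing with a CUDA event (illustrative names, not code from this PR):

    import torch

    stream = torch.cuda.current_stream()
    copy_done = torch.cuda.Event()

    host = torch.tensor([1, 2, 3], dtype=torch.int32, pin_memory=True)
    dev = torch.empty(3, dtype=torch.int32, device="cuda")

    dev.copy_(host, non_blocking=True)  # async H2D, enqueued on `stream`
    copy_done.record(stream)            # fence placed right after the copy

    out = dev * 2                       # same-stream kernel: ordered after the copy

    copy_done.synchronize()             # host waits here before reusing `host`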