 from ..infer_batch import InferReq
 from lightllm.utils.dist_utils import create_new_group_for_current_dp
 from lightllm.common.basemodel.triton_kernel.kv_cache_offload import offload_gpu_kv_to_cpu
+from lightllm.server.router.model_infer.infer_batch import g_infer_context
 
 
 class MultiLevelCacheManager(object):
@@ -20,11 +21,14 @@ def __init__(self, backend):
         self.gloo_group = create_new_group_for_current_dp("gloo")
         self.filter_group = create_new_group_for_current_dp("gloo")
         self.sync_group = create_new_group_for_current_dp("nccl")
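+        # Extra NCCL group dedicated to initialization-time synchronization, kept
+        # apart from the per-step sync_group traffic.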
+        self.init_sync_group = create_new_group_for_current_dp("nccl")
 
         self.cpu_cache_handle_queue = deque()
         self.cpu_cache_client = CpuKvCacheClient(init_shm_data=False)
 
-    def req_to_cpu_cache_task(self, req: InferReq, cpu_kv_cache_stream: torch.cuda.Stream) -> Optional["TransTask"]:
+    def start_kv_cache_offload_task(
+        self, req: InferReq, cpu_kv_cache_stream: torch.cuda.Stream
+    ) -> Optional["TransTask"]:
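+        # Copies this request's settled KV chunks to CPU cache pages on a dedicated
+        # CUDA stream; returns a TransTask for completion tracking, or None when
+        # there is nothing to offload.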
         with torch.cuda.stream(cpu_kv_cache_stream):
             all_token_hash_list = req.shm_req.token_hash_list.get_all()
             block_size = req.cur_kv_len // self.args.cpu_cache_token_chuncked_size
@@ -79,7 +83,7 @@ def req_to_cpu_cache_task(self, req: InferReq, cpu_kv_cache_stream: torch.cuda.S
 
         return trans_task
 
-    def handle_task_queue(self):
+    def update_kv_cache_offload_task_states(self):
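+        # Polls previously started offload tasks and marks the finished ones; the
+        # task queue is drained by the master rank of the current dp group.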
         if self.backend.is_master_in_dp:
             trans_ok_reqs = []
             while len(self.cpu_cache_handle_queue) != 0:
@@ -110,6 +114,37 @@ def handle_task_queue(self):
             req.req_obj.cpu_cache_task_finished = True
         return
 
+    def fill_cpu_cache_to_reqs(self, reqs: List[InferReq]):
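+        # Turns CPU-cache page matches into usable GPU KV state: allocate token
+        # slots for the matched tokens beyond cur_kv_len, extend cur_kv_len, then
+        # drop the references held on the matched CPU pages.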
+        idle_token_num = g_infer_context.get_can_alloc_token_num()
+        token_chuncked_size = self.args.cpu_cache_token_chuncked_size
+        all_page_list = []
+        for req in reqs:
+            page_list = req.shm_req.cpu_cache_match_page_indexes.get_all()
+            match_tokens = len(page_list) * token_chuncked_size
+            need_token_num = match_tokens - req.cur_kv_len
+            # Only copy when the CPU cache matches a sufficient number of extra
+            # tokens; otherwise the copy is not worth its overhead.
+            if need_token_num > 256:
+                if need_token_num <= idle_token_num:
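+                    # Evict radix-cache entries as needed so the allocation below
+                    # can hand out need_token_num slots.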
+                    if self.backend.radix_cache is not None:
+                        g_infer_context.radix_cache.free_radix_cache_to_get_enough_token(need_token_num=need_token_num)
+
+                    mem_indexes = g_infer_context.req_manager.mem_manager.alloc(need_size=need_token_num)
+                    idle_token_num -= need_token_num
+                    g_infer_context.req_manager.req_to_token_indexs[
+                        req.req_idx, req.cur_kv_len : (req.cur_kv_len + need_token_num)
+                    ] = mem_indexes
+                    req.cur_kv_len = req.cur_kv_len + need_token_num
+                    if self.backend.is_master_in_dp:
+                        req.shm_req.shm_cur_kv_len = req.cur_kv_len
+
+            all_page_list.extend(page_list)
+
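+        # Release the references held on the matched CPU pages (master rank only),
+        # whether or not their tokens were filled above.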
+        if self.backend.is_master_in_dp:
+            self.cpu_cache_client.lock.acquire_sleep1ms()
+            self.cpu_cache_client.deref_pages(page_list=all_page_list)
+            self.cpu_cache_client.lock.release()
+        return
+
 
 @dataclasses.dataclass
 class TransTask:
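
For orientation, a rough sketch of how these entry points could be driven from the
inference loop; `manager`, `offload_stream`, `finished_reqs`, and `matched_reqs` are
illustrative names, and the assumption that the caller enqueues the returned
TransTask is not confirmed by this commit:

    # Hypothetical driver loop around a MultiLevelCacheManager instance `manager`.
    offload_stream = torch.cuda.Stream()
    # Start offloads for requests whose KV should migrate to the CPU cache.
    for req in finished_reqs:
        task = manager.start_kv_cache_offload_task(req, offload_stream)
        if task is not None:
            manager.cpu_cache_handle_queue.append(task)
    # Periodically reap transfers that have completed.
    manager.update_kv_cache_offload_task_states()
    # Before prefill, reuse the KV pages already matched in the CPU cache.
    manager.fill_cpu_cache_to_reqs(matched_reqs)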