@@ -159,10 +159,12 @@ def _fetch_dp_prompt_cache(
159159 my_match .append ((shm_req , kv_len , value_tensor ))
160160
161161 # match all the reqs in other dp ranks.
162+ other_shm_reqs = []
162163 if self .rank_in_dp == 0 :
163164 for r in other_reqs :
164165 _ , shm_index , _ , _ = r
165166 shm_req = g_infer_context .shm_req_manager .get_req_obj_by_index (shm_index )
167+ other_shm_reqs .append (shm_req )
166168 sampling_param = InferSamplingParams (shm_req , g_infer_context .vocab_size )
167169 if sampling_param .disable_prompt_cache :
168170 continue
@@ -190,9 +192,12 @@ def _fetch_dp_prompt_cache(
190192 shm_req , kv_len , value_tensor = shm_index_to_match [shm_index ]
191193 match = (shm_req , kv_len , value_tensor , suggested_dp_index )
192194
195+ # this request's KV cache needs to be transferred between dp ranks
193196 if suggested_dp_index != shm_req .dp_max_kv_rank :
197+ # this dp rank is the one that needs to fetch the KV cache
194198 if suggested_dp_index == self .dp_rank_in_node :
195199 my_trans_match .append ((match , transfer_count ))
200+ # the KV cache needs to be provided to another dp rank
196201 else :
197202 other_trans_match .append ((match , transfer_count ))
198203 transfer_count += 1
@@ -206,16 +211,15 @@ def _fetch_dp_prompt_cache(
206211 if transfer_count > 0 :
207212 self ._transfer_dp_kv_cache (my_trans_match , other_trans_match )
208213
214+ self .release_all_shm_reqs (other_shm_reqs )
215+
209216 def _transfer_dp_kv_cache (self , my_match : List [Tuple ], other_match : List [Tuple ]):
210- other_shm_reqs = []
211217 for match , index in other_match :
212218 shm_req , kv_len , value_tensor , _ = match
213219 trans_len = kv_len - shm_req .dp_origin_kv_len
214220 if shm_req .dp_max_kv_rank == self .dp_rank_in_node :
215221 self .shared_kv_indexes .arr [index , 0 :trans_len ] = value_tensor [shm_req .dp_origin_kv_len : kv_len ]
216- other_shm_reqs .append (shm_req )
217222
218- self .release_all_shm_reqs (other_shm_reqs )
219223 dist .barrier (group = self .node_nccl_group )
220224
221225 if not my_match :
0 commit comments