Skip to content

Commit 9a1a416

Browse files
committed
fix release other shm_reqs
1 parent 93cb841 commit 9a1a416

File tree

1 file changed

+7
-3
lines changed
  • lightllm/server/router/model_infer/mode_backend/dp_backend

1 file changed

+7
-3
lines changed

lightllm/server/router/model_infer/mode_backend/dp_backend/impl.py

Lines changed: 7 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -159,10 +159,12 @@ def _fetch_dp_prompt_cache(
159159
my_match.append((shm_req, kv_len, value_tensor))
160160

161161
# match all the reqs in other dp ranks.
162+
other_shm_reqs = []
162163
if self.rank_in_dp == 0:
163164
for r in other_reqs:
164165
_, shm_index, _, _ = r
165166
shm_req = g_infer_context.shm_req_manager.get_req_obj_by_index(shm_index)
167+
other_shm_reqs.append(shm_req)
166168
sampling_param = InferSamplingParams(shm_req, g_infer_context.vocab_size)
167169
if sampling_param.disable_prompt_cache:
168170
continue
@@ -190,9 +192,12 @@ def _fetch_dp_prompt_cache(
190192
shm_req, kv_len, value_tensor = shm_index_to_match[shm_index]
191193
match = (shm_req, kv_len, value_tensor, suggested_dp_index)
192194

195+
# 需要传输的
193196
if suggested_dp_index != shm_req.dp_max_kv_rank:
197+
# 需要获取的
194198
if suggested_dp_index == self.dp_rank_in_node:
195199
my_trans_match.append((match, transfer_count))
200+
# 需要给其他dp的
196201
else:
197202
other_trans_match.append((match, transfer_count))
198203
transfer_count += 1
@@ -206,16 +211,15 @@ def _fetch_dp_prompt_cache(
206211
if transfer_count > 0:
207212
self._transfer_dp_kv_cache(my_trans_match, other_trans_match)
208213

214+
self.release_all_shm_reqs(other_shm_reqs)
215+
209216
def _transfer_dp_kv_cache(self, my_match: List[Tuple], other_match: List[Tuple]):
210-
other_shm_reqs = []
211217
for match, index in other_match:
212218
shm_req, kv_len, value_tensor, _ = match
213219
trans_len = kv_len - shm_req.dp_origin_kv_len
214220
if shm_req.dp_max_kv_rank == self.dp_rank_in_node:
215221
self.shared_kv_indexes.arr[index, 0:trans_len] = value_tensor[shm_req.dp_origin_kv_len : kv_len]
216-
other_shm_reqs.append(shm_req)
217222

218-
self.release_all_shm_reqs(other_shm_reqs)
219223
dist.barrier(group=self.node_nccl_group)
220224

221225
if not my_match:

0 commit comments

Comments
 (0)