Skip to content

Commit 6418a1e

Browse files
committed
fix bug for pd reformater.
1 parent 16e67e3 commit 6418a1e

File tree

3 files changed

+9
-4
lines changed

3 files changed

+9
-4
lines changed

lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_impl.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ def init_custom(self):
3535
ranks.append(i + self.global_dp_rank * self.dp_world_size)
3636

3737
self.lock_nccl_group = dist.new_group(ranks=ranks, backend="gloo")
38-
logger.info(f"lock_nccl_group ranks {self.lock_nccl_group.get_rank()}")
38+
logger.info(f"lock_nccl_group ranks {dist.get_rank(self.lock_nccl_group)}")
3939

4040
from .decode_infer_rpyc import PDDecodeInferRpcServer
4141

lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_kv_move_manager.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -416,8 +416,13 @@ def exposed_request_data_transfer(self, tasks: List[KVMoveTask]) -> List[Optiona
416416

417417
trans_obj = self.get_trans_obj(tasks[0])
418418
assert trans_obj is not None
419-
420-
id_to_test_range = {task.group_request_id: random.shuffle(list(range(self.dp_size_in_node))) for task in tasks}
419+
420+
id_to_test_range = {}
421+
for task in tasks:
422+
test_dp_indexes = list(range(self.dp_size_in_node))
423+
random.shuffle(test_dp_indexes)
424+
id_to_test_range[task.group_request_id] = test_dp_indexes
425+
421426
id_has_result = {}
422427
for test_index in range(self.dp_size_in_node):
423428
dp_tasks = [[] for _ in range(self.dp_size_in_node)]

lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_impl.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ def init_custom(self):
3535
ranks.append(i + self.global_dp_rank * self.dp_world_size)
3636

3737
self.lock_nccl_group = dist.new_group(ranks=ranks, backend="gloo")
38-
logger.info(f"lock_nccl_group ranks {self.lock_nccl_group.get_rank()}")
38+
logger.info(f"lock_nccl_group ranks {dist.get_rank(self.lock_nccl_group)}")
3939

4040
from .prefill_infer_rpyc import PDPrefillInferRpcServer
4141

0 commit comments

Comments
 (0)