Skip to content

Commit f3ebb9d

Browse files
committed
fix: create a new process group for the current dp group
1 parent 644e802 commit f3ebb9d

File tree

4 files changed

+19
-16
lines changed

4 files changed

+19
-16
lines changed

lightllm/distributed/communication_op.py

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -27,12 +27,11 @@
2727
from lightllm.utils.device_utils import has_nvlink
2828
from lightllm.utils.envs_utils import get_env_start_args, get_deepep_num_max_dispatch_tokens_per_rank
2929
from lightllm.utils.dist_utils import (
30-
get_current_device_id,
31-
get_node_world_size,
3230
get_global_world_size,
3331
get_dp_world_size,
3432
get_global_rank,
3533
get_current_rank_in_dp,
34+
create_new_group_for_current_dp,
3635
)
3736
from lightllm.utils.device_utils import get_device_sm_count
3837
from lightllm.utils.sgl_utils import HAS_SGL_KERNEL
@@ -63,17 +62,15 @@ def __init__(self):
6362
self.custom_reduce = None
6463
self.custom_gather = None
6564
self.dp_world_size = get_dp_world_size()
66-
ranks = list([get_global_rank() - get_current_rank_in_dp() + i for i in range(self.dp_world_size)])
67-
self.device_group = dist.new_group(ranks, backend="nccl")
65+
self.device_group = create_new_group_for_current_dp("nccl")
6866

6967
def init_custom_reduce(self) -> None:
7068
if not HAS_SGL_KERNEL or not has_nvlink() or self.dp_world_size not in [2, 4, 6, 8]:
7169
return
7270
args = get_env_start_args()
7371
if args.disable_custom_allreduce:
7472
return
75-
ranks = list([get_global_rank() - get_current_rank_in_dp() + i for i in range(self.dp_world_size)])
76-
cpu_group = dist.new_group(ranks, backend="gloo")
73+
cpu_group = create_new_group_for_current_dp("gloo")
7774
self.custom_reduce = CustomAllreduce(cpu_group, torch.cuda.current_device())
7875
logger.info("Enable Custom ALLReduce. You can disable it by settting --disable_custom_allreduce.")
7976

@@ -84,8 +81,8 @@ def init_custom_gather(self) -> None:
8481
args = get_env_start_args()
8582
if args.disable_custom_allgather:
8683
return
87-
ranks = list([get_global_rank() - get_current_rank_in_dp() + i for i in range(self.dp_world_size)])
88-
cpu_group = dist.new_group(ranks, backend="gloo")
84+
85+
cpu_group = create_new_group_for_current_dp("gloo")
8986
self.custom_gather = CustomAllgather(cpu_group, torch.cuda.current_device())
9087
logger.info("Enable Custom ALLGather. You can disable it by settting --disable_custom_allgather")
9188

lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_impl.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from .decode_task_cache import g_success_kv_move_task_cache, KVMoveTask
2020
from lightllm.utils.device_utils import kv_trans_use_p2p
2121
from lightllm.utils.envs_utils import get_unique_server_name
22+
from lightllm.utils.dist_utils import create_new_group_for_current_dp
2223

2324
logger = init_logger(__name__)
2425

@@ -30,11 +31,8 @@ def __init__(self, info_queue: mp.Queue, mem_queue: mp.Queue) -> None:
3031
self.mem_queue: mp.Queue = mem_queue
3132

3233
def init_custom(self):
33-
ranks = []
34-
for i in range(self.dp_world_size):
35-
ranks.append(i + self.global_dp_rank * self.dp_world_size)
3634

37-
self.lock_nccl_group = dist.new_group(ranks=ranks, backend="gloo")
35+
self.lock_nccl_group = create_new_group_for_current_dp("gloo")
3836
logger.info(f"lock_nccl_group ranks {dist.get_rank(self.lock_nccl_group)}")
3937

4038
from .decode_infer_rpyc import PDDecodeInferRpcServer

lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_impl.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from .prefill_task_cache import g_kv_move_task_cache
2020
from lightllm.utils.device_utils import kv_trans_use_p2p
2121
from lightllm.utils.envs_utils import get_unique_server_name
22+
from lightllm.utils.dist_utils import create_new_group_for_current_dp
2223

2324
logger = init_logger(__name__)
2425

@@ -30,11 +31,8 @@ def __init__(self, info_queue: mp.Queue, mem_queue: mp.Queue) -> None:
3031
self.mem_queue: mp.Queue = mem_queue
3132

3233
def init_custom(self):
33-
ranks = []
34-
for i in range(self.dp_world_size):
35-
ranks.append(i + self.global_dp_rank * self.dp_world_size)
3634

37-
self.lock_nccl_group = dist.new_group(ranks=ranks, backend="gloo")
35+
self.lock_nccl_group = create_new_group_for_current_dp("gloo")
3836
logger.info(f"lock_nccl_group ranks {dist.get_rank(self.lock_nccl_group)}")
3937

4038
from .prefill_infer_rpyc import PDPrefillInferRpcServer

lightllm/utils/dist_utils.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -189,3 +189,13 @@ def set_node_world_size(node_world_size: int):
189189

190190
def get_node_world_size():
191191
return int(get_environ("LIGHTLLM_NODE_WORLD_SIZE"))
192+
193+
194+
def create_new_group_for_current_dp(backend):
    """Create and return the process group for the caller's dp (data-parallel) group.

    ``torch.distributed.new_group`` is a collective operation: every process
    in the global group must call it for *every* subgroup, in the same order,
    even for subgroups the process does not belong to.  That is why this loop
    iterates over ALL dp ranks and creates every dp subgroup on every process,
    keeping only the group this process is a member of.  Do not "optimize"
    away the non-matching iterations — doing so would deadlock or corrupt
    group creation.

    Assumes global ranks are laid out contiguously per dp group, i.e. dp group
    ``d`` owns global ranks ``[d * dp_world_size, (d + 1) * dp_world_size)``
    — TODO confirm this matches the rank layout used elsewhere.

    Args:
        backend: torch.distributed backend name, e.g. ``"nccl"`` or ``"gloo"``.

    Returns:
        The ProcessGroup containing the global ranks of the current dp group.
    """
    ans_group = None
    # Hoist loop-invariant lookups out of the loop.
    dp_world_size = get_dp_world_size()
    current_dp_rank = get_global_dp_rank()
    for iter_dp_rank in range(get_dp_size()):
        ranks = [i + iter_dp_rank * dp_world_size for i in range(dp_world_size)]
        device_group = dist.new_group(ranks, backend=backend)
        if current_dp_rank == iter_dp_rank:
            ans_group = device_group
    return ans_group

0 commit comments

Comments
 (0)