Skip to content

Commit 3d879c6

Browse files
sywangyi, gemini-code-assist[bot], hnyls2002, cursoragent
authored
refactor: extract device-to-backend mapping into get_default_distributed_backend (sgl-project#19202)
Signed-off-by: Wang, Yi <yi.a.wang@intel.com> Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> Co-authored-by: Liangsheng Yin <lsyincs@gmail.com> Co-authored-by: Cursor <cursoragent@cursor.com> Co-authored-by: Liangsheng Yin <hnyls2002@gmail.com>
1 parent d0bb140 commit 3d879c6

File tree

3 files changed

+28
-23
lines changed

3 files changed

+28
-23
lines changed

python/sglang/srt/disaggregation/encode_server.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
from sglang.srt.configs.model_config import ModelConfig
2626
from sglang.srt.disaggregation.encode_receiver import EmbeddingData
2727
from sglang.srt.distributed.parallel_state import (
28+
get_default_distributed_backend,
2829
get_mooncake_transfer_engine,
2930
get_tp_group,
3031
init_distributed_environment,
@@ -176,6 +177,7 @@ def __init__(
176177
self.use_image_processor_gpu = use_image_processor_gpu
177178

178179
init_distributed_environment(
180+
backend=get_default_distributed_backend(self.device),
179181
world_size=server_args.tp_size,
180182
rank=rank,
181183
distributed_init_method=dist_init_method,

python/sglang/srt/distributed/parallel_state.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1578,6 +1578,20 @@ def set_torch_symm_mem_all_reduce(enable: bool):
15781578
_ENABLE_TORCH_SYMM_MEM_ALL_REDUCE = enable
15791579

15801580

1581+
_DEVICE_TO_DISTRIBUTED_BACKEND = {
1582+
"cuda": "nccl",
1583+
"xpu": "xccl",
1584+
"hpu": "hccl",
1585+
"cpu": "gloo",
1586+
"npu": "hccl",
1587+
"musa": "mccl",
1588+
}
1589+
1590+
1591+
def get_default_distributed_backend(device: str) -> str:
1592+
return _DEVICE_TO_DISTRIBUTED_BACKEND.get(device, "gloo")
1593+
1594+
15811595
def init_distributed_environment(
15821596
world_size: int = -1,
15831597
rank: int = -1,

python/sglang/srt/model_executor/model_runner.py

Lines changed: 12 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@
5555
register_forward_hook_for_model,
5656
)
5757
from sglang.srt.distributed import (
58+
get_default_distributed_backend,
5859
get_pp_group,
5960
get_tp_group,
6061
get_world_group,
@@ -734,29 +735,17 @@ def init_torch_distributed(self):
734735
)
735736
raise
736737

737-
if self.device == "cuda":
738-
if self.server_args.elastic_ep_backend == "mooncake":
739-
backend = "mooncake"
740-
if self.server_args.mooncake_ib_device:
741-
mooncake_ib_device = self.server_args.mooncake_ib_device.split(",")
742-
try:
743-
from mooncake import ep as mooncake_ep
744-
745-
mooncake_ep.set_device_filter(mooncake_ib_device)
746-
except:
747-
pass # A warning will be raised in `init_distributed_environment`
748-
else:
749-
backend = "nccl"
750-
elif self.device == "xpu":
751-
backend = "xccl"
752-
elif self.device == "hpu":
753-
backend = "hccl"
754-
elif self.device == "cpu":
755-
backend = "gloo"
756-
elif self.device == "npu":
757-
backend = "hccl"
758-
elif self.device == "musa":
759-
backend = "mccl"
738+
backend = get_default_distributed_backend(self.device)
739+
if self.device == "cuda" and self.server_args.elastic_ep_backend == "mooncake":
740+
backend = "mooncake"
741+
if self.server_args.mooncake_ib_device:
742+
mooncake_ib_device = self.server_args.mooncake_ib_device.split(",")
743+
try:
744+
from mooncake import ep as mooncake_ep
745+
746+
mooncake_ep.set_device_filter(mooncake_ib_device)
747+
except:
748+
pass # A warning will be raised in `init_distributed_environment`
760749

761750
before_avail_memory = get_available_gpu_memory(self.device, self.gpu_id)
762751
if not self.server_args.enable_p2p_check:

0 commit comments

Comments
 (0)