Skip to content

Commit a52ddbd

Browse files
authored
[Fix] Fix port error in PD disaggregation setting (#1175)
1 parent f94cd5f commit a52ddbd

File tree

3 files changed: +46 additions, -11 deletions

slime/backends/sglang_utils/sglang_engine.py

Lines changed: 30 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -91,7 +91,7 @@ def __init__(self, args, rank: int, worker_type: str = "regular"):
9191
self.rank = rank
9292
self.worker_type = worker_type
9393

94-
def init(self, dist_init_addr, port, nccl_port, host=None):
94+
def init(self, dist_init_addr, port, nccl_port, host=None, disaggregation_bootstrap_port=None):
9595
self.router_ip = self.args.sglang_router_ip
9696
self.router_port = self.args.sglang_router_port
9797

@@ -108,7 +108,14 @@ def init(self, dist_init_addr, port, nccl_port, host=None):
108108
dist_init_addr = f"[{ipv6_addr}]:{port_str}"
109109

110110
server_args_dict, external_engine_need_check_fields = _compute_server_args(
111-
self.args, self.rank, dist_init_addr, nccl_port, host, port, self.worker_type
111+
self.args,
112+
self.rank,
113+
dist_init_addr,
114+
nccl_port,
115+
host,
116+
port,
117+
self.worker_type,
118+
disaggregation_bootstrap_port,
112119
)
113120

114121
self.node_rank = server_args_dict["node_rank"]
@@ -157,12 +164,15 @@ def _init_normal(self, server_args_dict):
157164
f"http://{self.router_ip}:{self.router_port}/add_worker?url=http://{self.server_host}:{self.server_port}"
158165
)
159166
else:
167+
payload = {
168+
"url": f"http://{self.server_host}:{self.server_port}",
169+
"worker_type": self.worker_type,
170+
}
171+
if self.worker_type == "prefill":
172+
payload["bootstrap_port"] = server_args_dict["disaggregation_bootstrap_port"]
160173
response = requests.post(
161174
f"http://{self.router_ip}:{self.router_port}/workers",
162-
json={
163-
"url": f"http://{self.server_host}:{self.server_port}",
164-
"worker_type": self.worker_type,
165-
},
175+
json=payload,
166176
)
167177
response.raise_for_status()
168178

@@ -381,7 +391,16 @@ def stop_profile(self):
381391
return response
382392

383393

384-
def _compute_server_args(args, rank, dist_init_addr, nccl_port, host, port, worker_type: str = "regular"):
394+
def _compute_server_args(
395+
args,
396+
rank,
397+
dist_init_addr,
398+
nccl_port,
399+
host,
400+
port,
401+
worker_type: str = "regular",
402+
disaggregation_bootstrap_port: int | None = None,
403+
):
385404
nnodes = max(1, args.rollout_num_gpus_per_engine // args.num_gpus_per_node)
386405
node_rank = rank % nnodes
387406
kwargs = {
@@ -411,6 +430,10 @@ def _compute_server_args(args, rank, dist_init_addr, nccl_port, host, port, work
411430
if worker_type == "prefill":
412431
kwargs["disaggregation_mode"] = "prefill"
413432
kwargs["load_balance_method"] = "round_robin"
433+
assert (
434+
disaggregation_bootstrap_port is not None
435+
), "disaggregation_bootstrap_port must be set for prefill worker"
436+
kwargs["disaggregation_bootstrap_port"] = disaggregation_bootstrap_port
414437
elif worker_type == "decode":
415438
kwargs["disaggregation_mode"] = "decode"
416439
kwargs["prefill_round_robin_balance"] = True
@@ -419,7 +442,6 @@ def _compute_server_args(args, rank, dist_init_addr, nccl_port, host, port, work
419442
kwargs["enable_return_routed_experts"] = True
420443
if args.fp16:
421444
kwargs["dtype"] = "float16"
422-
423445
external_engine_need_check_fields = [k for k in kwargs.keys() if k not in _EXTERNAL_ENGINE_SKIP_CHECK_FIELDS]
424446

425447
unused_keys = set(kwargs.keys())

slime/ray/rollout.py

Lines changed: 13 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -435,6 +435,12 @@ def _allocate_rollout_engine_addr_and_ports_normal(*, args, num_engines, rollout
435435
)
436436
addr_and_ports = [{} for _ in range(num_engines)]
437437

438+
# Calculate prefill limit to identify prefill engines
439+
prefill_limit = 0
440+
if args.prefill_num_servers is not None:
441+
num_gpu_per_engine = min(args.rollout_num_gpus_per_engine, args.num_gpus_per_node)
442+
prefill_limit = args.prefill_num_servers * args.rollout_num_gpus_per_engine // num_gpu_per_engine
443+
438444
visited_nodes = set()
439445
for rank, engine in rollout_engines:
440446
if rank // num_engines_per_node in visited_nodes:
@@ -469,9 +475,13 @@ def addr():
469475
get_addr, get_port = get_addr_and_ports(engine)
470476

471477
for i in range(num_engines_on_this_node):
472-
addr_and_ports[rank + i]["host"] = get_addr()
473-
addr_and_ports[rank + i]["port"] = get_port()
474-
addr_and_ports[rank + i]["nccl_port"] = get_port()
478+
current_rank = rank + i
479+
addr_and_ports[current_rank]["host"] = get_addr()
480+
addr_and_ports[current_rank]["port"] = get_port()
481+
addr_and_ports[current_rank]["nccl_port"] = get_port()
482+
483+
if args.prefill_num_servers is not None and current_rank < prefill_limit:
484+
addr_and_ports[current_rank]["disaggregation_bootstrap_port"] = get_port()
475485

476486
if args.rollout_num_gpus_per_engine > args.num_gpus_per_node:
477487
num_node_per_engine = args.rollout_num_gpus_per_engine // args.num_gpus_per_node

slime/utils/arguments.py

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1581,6 +1581,9 @@ def slime_validate_args(args):
15811581

15821582
if args.prefill_num_servers is not None:
15831583
assert not args.use_fault_tolerance, "fault tolerance is not supported when prefill_num_servers is set."
1584+
assert not (
1585+
args.prefill_num_servers is not None and args.rollout_external
1586+
), "prefill_num_servers cannot be set when rollout_external is set."
15841587

15851588

15861589
def hf_validate_args(args, hf_config):

0 commit comments

Comments
 (0)