@@ -91,7 +91,7 @@ def __init__(self, args, rank: int, worker_type: str = "regular"):
9191 self .rank = rank
9292 self .worker_type = worker_type
9393
94- def init (self , dist_init_addr , port , nccl_port , host = None ):
94+ def init (self , dist_init_addr , port , nccl_port , host = None , disaggregation_bootstrap_port = None ):
9595 self .router_ip = self .args .sglang_router_ip
9696 self .router_port = self .args .sglang_router_port
9797
@@ -108,7 +108,14 @@ def init(self, dist_init_addr, port, nccl_port, host=None):
108108 dist_init_addr = f"[{ ipv6_addr } ]:{ port_str } "
109109
110110 server_args_dict , external_engine_need_check_fields = _compute_server_args (
111- self .args , self .rank , dist_init_addr , nccl_port , host , port , self .worker_type
111+ self .args ,
112+ self .rank ,
113+ dist_init_addr ,
114+ nccl_port ,
115+ host ,
116+ port ,
117+ self .worker_type ,
118+ disaggregation_bootstrap_port ,
112119 )
113120
114121 self .node_rank = server_args_dict ["node_rank" ]
@@ -157,12 +164,15 @@ def _init_normal(self, server_args_dict):
157164 f"http://{ self .router_ip } :{ self .router_port } /add_worker?url=http://{ self .server_host } :{ self .server_port } "
158165 )
159166 else :
167+ payload = {
168+ "url" : f"http://{ self .server_host } :{ self .server_port } " ,
169+ "worker_type" : self .worker_type ,
170+ }
171+ if self .worker_type == "prefill" :
172+ payload ["bootstrap_port" ] = server_args_dict ["disaggregation_bootstrap_port" ]
160173 response = requests .post (
161174 f"http://{ self .router_ip } :{ self .router_port } /workers" ,
162- json = {
163- "url" : f"http://{ self .server_host } :{ self .server_port } " ,
164- "worker_type" : self .worker_type ,
165- },
175+ json = payload ,
166176 )
167177 response .raise_for_status ()
168178
@@ -381,7 +391,16 @@ def stop_profile(self):
381391 return response
382392
383393
384- def _compute_server_args (args , rank , dist_init_addr , nccl_port , host , port , worker_type : str = "regular" ):
394+ def _compute_server_args (
395+ args ,
396+ rank ,
397+ dist_init_addr ,
398+ nccl_port ,
399+ host ,
400+ port ,
401+ worker_type : str = "regular" ,
402+ disaggregation_bootstrap_port : int | None = None ,
403+ ):
385404 nnodes = max (1 , args .rollout_num_gpus_per_engine // args .num_gpus_per_node )
386405 node_rank = rank % nnodes
387406 kwargs = {
@@ -411,6 +430,10 @@ def _compute_server_args(args, rank, dist_init_addr, nccl_port, host, port, work
411430 if worker_type == "prefill" :
412431 kwargs ["disaggregation_mode" ] = "prefill"
413432 kwargs ["load_balance_method" ] = "round_robin"
433+ assert (
434+ disaggregation_bootstrap_port is not None
435+ ), "disaggregation_bootstrap_port must be set for prefill worker"
436+ kwargs ["disaggregation_bootstrap_port" ] = disaggregation_bootstrap_port
414437 elif worker_type == "decode" :
415438 kwargs ["disaggregation_mode" ] = "decode"
416439 kwargs ["prefill_round_robin_balance" ] = True
@@ -419,7 +442,6 @@ def _compute_server_args(args, rank, dist_init_addr, nccl_port, host, port, work
419442 kwargs ["enable_return_routed_experts" ] = True
420443 if args .fp16 :
421444 kwargs ["dtype" ] = "float16"
422-
423445 external_engine_need_check_fields = [k for k in kwargs .keys () if k not in _EXTERNAL_ENGINE_SKIP_CHECK_FIELDS ]
424446
425447 unused_keys = set (kwargs .keys ())
0 commit comments