@@ -51,7 +51,6 @@ def __init__(self, args: StartArgs, router_port, detokenization_port, metric_por
5151 self .schedule_time_interval = args .schedule_time_interval # 默认30ms 的调度周期
5252 # Support the multi-node pure-TP mode: there dp // nnodes is e.g. 1 // 2 == 0, so the value is clamped to at least 1
5353 self .dp_size_in_node = max (1 , args .dp // self .nnodes )
54- self .dp_world_size = self .world_size // self .dp_size
5554 self .is_multinode_tp = args .nnodes > 1 and args .dp == 1
5655 self .is_multinode_tp_master = self .is_multinode_tp and args .node_rank == 0
5756 self .is_multinode_tp_slave = self .is_multinode_tp and args .node_rank != 0
@@ -114,14 +113,12 @@ async def wait_to_model_ready(self):
114113 self .mem_queues : List [torch .multiprocessing .Queue ] = [
115114 torch .multiprocessing .Queue () for _ in range (self .node_world_size )
116115 ]
117- self .result_queues : List [mp .Queue ] = [mp .Queue () for _ in range (self .node_world_size )]
118116 self .rpc_event = multiprocessing .Event ()
119117 self .rpc_finished_event = multiprocessing .Event ()
120118
121119 assert (self .world_size % self .nnodes ) == 0
122120 node_world_size = self .world_size // self .nnodes
123121 for rank_id in range (self .node_rank * node_world_size , (self .node_rank + 1 ) * node_world_size ):
124-
125122 rpc_model = await start_model_process (
126123 args = self .args ,
127124 rank = rank_id ,
@@ -130,8 +127,7 @@ async def wait_to_model_ready(self):
130127 rpc_event = self .rpc_event ,
131128 rpc_finished_event = self .rpc_finished_event ,
132129 info_queue = self .info_queue ,
133- result_queue = self .result_queues [rank_id % node_world_size ],
134- mem_queue = self .mem_queues [rank_id % node_world_size ],
130+ mem_queue = self .mem_queues [(rank_id % node_world_size )],
135131 router_lock = self .router_lock ,
136132 )
137133 self .model_rpc_servers .append (rpc_model )
@@ -184,7 +180,7 @@ async def wait_to_model_ready(self):
184180 get_unique_server_name (),
185181 self .max_total_token_num ,
186182 node_world_size = self .node_world_size ,
187- dp_world_size = self .dp_world_size ,
183+ dp_world_size = self .world_size // self . dp_size ,
188184 )
189185 self .req_queue = build_req_queue (self .args , self , self .dp_size_in_node )
190186 logger .info (f"use req queue { self .req_queue .__class__ .__name__ } " )
@@ -197,30 +193,6 @@ async def wait_to_model_ready(self):
197193
198194 start_prefill_kv_move_manager_process (self .args , self .info_queue , self .mem_queues )
199195
200- if self .args .run_mode == "nixl_prefill" :
201- from lightllm .server .router .model_infer .mode_backend .pd_nixl .pd_remote_prefill import (
202- start_pd_remote_prefill_server_process ,
203- )
204-
205- dist_info = DistInfo (
206- self .world_size ,
207- self .nnodes ,
208- self .dp_size ,
209- self .dp_world_size ,
210- self .dp_size_in_node ,
211- self .node_world_size ,
212- )
213-
214- start_pd_remote_prefill_server_process (
215- self .args .pd_node_id ,
216- dist_info = dist_info ,
217- http_server_port = self .args .pd_nixl_remote_prefill_http_port ,
218- server_port = self .args .pd_nixl_remote_prefill_port ,
219- from_backend_queue = self .info_queue ,
220- to_backend_queues = self .result_queues ,
221- agent_meta_queues = self .mem_queues ,
222- )
223-
224196 if self .args .run_mode == "decode" :
225197 # Start the decode kv move manager process
226198 from lightllm .server .router .model_infer .mode_backend .continues_batch .pd_mode .decode_node_impl import (
@@ -229,28 +201,6 @@ async def wait_to_model_ready(self):
229201
230202 start_decode_kv_move_manager_process (self .args , self .info_queue , self .mem_queues )
231203
232- if self .args .run_mode == "nixl_decode" :
233- from lightllm .server .router .model_infer .mode_backend .pd_nixl .pd_remote_prefill import (
234- start_pd_remote_prefill_client_process ,
235- )
236-
237- dist_info = DistInfo (
238- self .world_size ,
239- self .nnodes ,
240- self .dp_size ,
241- self .dp_world_size ,
242- self .dp_size_in_node ,
243- self .node_world_size ,
244- )
245-
246- start_pd_remote_prefill_client_process (
247- self .args .pd_node_id ,
248- dist_info ,
249- from_backend_queue = self .info_queue ,
250- to_backend_queues = self .result_queues ,
251- agent_meta_queues = self .mem_queues ,
252- )
253-
254204 return
255205
256206 def _get_schedule_time_interval (self ):
@@ -459,7 +409,8 @@ def _add_req(self, group_req_indexes: GroupReqIndexes):
459409 req ._router_stop_str_matched = False
460410
461411 if isinstance (req , PDNIXLChunkedPrefillReq ):
462- req .set_dp_world_size (self .dp_world_size )
412+ dp_world_size = self .world_size // self .dp_size
413+ req .set_dp_world_size (dp_world_size )
463414
464415 req_group .append (req )
465416
0 commit comments