@@ -47,7 +47,7 @@ def __init__(self) -> None:
         self.enable_decode_microbatch_overlap = get_env_start_args().enable_decode_microbatch_overlap
         self.enable_prefill_microbatch_overlap = get_env_start_args().enable_prefill_microbatch_overlap
 
-        # Parameter flags that control request classification
+        # Parameter flags that control how _get_classed_reqs classifies requests; different backends may need different classification conditions.
         self.classed_req_no_decode = False
         self.classed_req_strict_prefill = False
         pass
@@ -73,6 +73,7 @@ def init_model(self, kvargs):
         self.use_dynamic_prompt_cache = not self.args.disable_dynamic_prompt_cache
         self.eos_id: List[int] = kvargs.get("eos_id", [2])
         self.disable_cudagraph = self.args.disable_cudagraph
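+        # Multinode tensor parallel mode: more than one node and data parallel size 1.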
+        self.is_multinode_tp = self.args.nnodes > 1 and self.args.dp == 1
 
         self.logger = init_logger(__name__)
 
@@ -165,17 +166,29 @@ def init_model(self, kvargs):
             [0 for _ in range(self.global_world_size)], dtype=torch.int32, device="cuda", requires_grad=False
         )
 
+        # Communication tensor and group used to cooperatively read request info from the ShmReqsIOBuffer.
         self.node_broadcast_tensor = torch.tensor([0], dtype=torch.int32, device="cuda", requires_grad=False)
         self.node_nccl_group = create_new_group_for_current_node("nccl")
 
+        # Communication tensor and group used to cooperatively read request info from the ShmReqsIOBuffer in multinode tp mode.
+        if self.is_multinode_tp:
+            self.multinode_tp_gather_item_tensor = torch.tensor([0], dtype=torch.int32, device="cuda")
+            self.multinode_tp_all_gather_tensor = torch.tensor(
+                [0 for _ in range(self.global_world_size)], dtype=torch.int32, device="cuda", requires_grad=False
+            )
+            self.multinode_tp_nccl_group = dist.new_group(
+                [rank for rank in range(self.global_world_size)], backend="nccl"
+            )
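+            # Note: unlike node_nccl_group, which only contains the ranks on the current
+            # node, this group spans every rank in the global world.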
+
         self.init_custom()
         self.shm_reqs_io_buffer = ShmReqsIOBuffer()
 
         # When mtp mode is enabled, the mtp draft model must also be initialized.
         if self.args.mtp_mode:
             self.init_mtp_draft_model(kvargs)
 
-        # Start the infer_loop_thread
+        # Start two threads that both run infer_loop. In scenarios where the inference of two
+        # micro batches can be overlapped (folded), this lowers CPU overhead and greatly improves GPU utilization.
         self.infer_loop_thread = threading.Thread(target=self.infer_loop, daemon=True)
         self.infer_loop_thread.start()
         self.infer_loop_thread1 = threading.Thread(target=self.infer_loop, daemon=True)
@@ -238,6 +251,13 @@ def init_mtp_draft_model(self, main_kvargs: dict):
         return
 
     def _try_read_new_reqs(self):
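+        # Multinode tp needs a global handshake across all ranks before the buffer can be
+        # consumed; the normal path only broadcasts readiness within each node.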
+        if self.is_multinode_tp:
+            self._try_read_new_reqs_multinode_tp()
+        else:
+            self._try_read_new_reqs_normal()
+        return
+
+    def _try_read_new_reqs_normal(self):
         if self.is_master_in_node:
             if self.shm_reqs_io_buffer.is_ready():
                 self.node_broadcast_tensor.fill_(1)
@@ -246,16 +266,42 @@ def _try_read_new_reqs(self):
         dist.broadcast(self.node_broadcast_tensor, src=0, group=self.node_nccl_group, async_op=False)
         new_buffer_is_ready = self.node_broadcast_tensor.detach().item()
         if new_buffer_is_ready:
-            cmds: List = self.shm_reqs_io_buffer.read_obj()
-            self.shm_reqs_io_buffer.sub_state()
-            if cmds:
-                if isinstance(cmds[0], AbortedReqCmd):
-                    for obj in cmds:
-                        if obj.req_id in g_infer_context.requests_mapping:
-                            req: InferReq = g_infer_context.requests_mapping[obj.req_id]
-                            req.infer_aborted = True
-                else:
-                    self._init_reqs(reqs=cmds)
+            self._read_reqs_buffer_and_init_reqs()
+        return
+
+    def _try_read_new_reqs_multinode_tp(self):
+        """
+        In multinode tp mode, the behavior of all ranks must be coordinated and kept in sync.
+        """
+        if self.shm_reqs_io_buffer.is_ready():
+            self.multinode_tp_gather_item_tensor.fill_(1)
+        else:
+            self.multinode_tp_gather_item_tensor.fill_(0)
+        dist.all_gather_into_tensor(
+            self.multinode_tp_all_gather_tensor,
+            self.multinode_tp_gather_item_tensor,
+            group=self.multinode_tp_nccl_group,
+            async_op=False,
+        )
+        new_buffer_is_readys = self.multinode_tp_all_gather_tensor.detach().cpu().numpy()
+        new_buffer_is_ready = np.all(new_buffer_is_readys == 1)
+
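+        # Consume the buffer only once every rank has observed it as ready, so that all
+        # ranks read and clear the shared buffer in the same step.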
+        if new_buffer_is_ready:
+            self._read_reqs_buffer_and_init_reqs()
+        return
+
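+    # Drains the ShmReqsIOBuffer: a batch of AbortedReqCmd marks the matching in-flight
+    # requests as aborted, while any other batch is initialized as new requests.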
+    def _read_reqs_buffer_and_init_reqs(self):
+        cmds: List = self.shm_reqs_io_buffer.read_obj()
+        self.shm_reqs_io_buffer.sub_state()
+        if cmds:
+            if isinstance(cmds[0], AbortedReqCmd):
+                for obj in cmds:
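+                    # Re-annotation for type checkers only; the runtime value is unchanged.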
+                    obj: AbortedReqCmd = obj
+                    if obj.req_id in g_infer_context.requests_mapping:
+                        req: InferReq = g_infer_context.requests_mapping[obj.req_id]
+                        req.infer_aborted = True
+            else:
+                self._init_reqs(reqs=cmds)
         return
 
     # Some reusable general-purpose helper functions
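
For reference, the readiness handshake added in _try_read_new_reqs_multinode_tp reduces to a small reusable pattern: every rank contributes a 0/1 flag through an all_gather, and the shared buffer is consumed only when every flag is 1. Below is a minimal standalone sketch of that pattern, assuming torch.distributed has already been initialized; the function name and its arguments are illustrative and not part of this commit.

import numpy as np
import torch
import torch.distributed as dist

def all_ranks_ready(local_ready: bool, group=None) -> bool:
    # Each rank contributes a single 0/1 readiness flag.
    flag = torch.tensor([1 if local_ready else 0], dtype=torch.int32, device="cuda")
    flags = torch.zeros(dist.get_world_size(group=group), dtype=torch.int32, device="cuda")
    dist.all_gather_into_tensor(flags, flag, group=group, async_op=False)
    # Proceed only when every rank reported ready.
    return bool(np.all(flags.cpu().numpy() == 1))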