@@ -32,9 +32,9 @@ class ModelRpcServer:
3232 def __init__ (
3333 self ,
3434 args ,
35- tp_rank : int ,
36- local_tp_rank : int ,
37- world_size : int ,
35+ tp_rank : int ,
36+ local_tp_rank : int ,
37+ world_size : int ,
3838 local_world_size : int ,
3939 rpc_event : multiprocessing .Event ,
4040 rpc_finished_event : multiprocessing .Event ,
@@ -286,7 +286,9 @@ def _init_env(
286286
287287 g_router_lock .obj = router_lock
288288
289- model_rpc_server = ModelRpcServer (args , tp_rank , local_tp_rank , world_size , local_world_size , rpc_event , rpc_finished_event , info_queue , mem_queue )
289+ model_rpc_server = ModelRpcServer (
290+ args , tp_rank , local_tp_rank , world_size , local_world_size , rpc_event , rpc_finished_event , info_queue , mem_queue
291+ )
290292 success_event .set ()
291293
292294 model_rpc_server .loop_thread .join ()
@@ -309,12 +311,34 @@ async def start_model_process(
309311
310312 # 单卡时不使用 rpc
311313 if world_size == 1 :
312- return ModelRpcServer (args , tp_rank , local_tp_rank , world_size , local_world_size , rpc_event , rpc_finished_event , info_queue , mem_queue )
314+ return ModelRpcServer (
315+ args ,
316+ tp_rank ,
317+ local_tp_rank ,
318+ world_size ,
319+ local_world_size ,
320+ rpc_event ,
321+ rpc_finished_event ,
322+ info_queue ,
323+ mem_queue ,
324+ )
313325
314326 success_event = mp .Event ()
315327 proc = mp .Process (
316328 target = _init_env ,
317- args = (args , tp_rank , local_tp_rank , world_size , local_world_size , info_queue , mem_queue , router_lock , rpc_event , rpc_finished_event , success_event ),
329+ args = (
330+ args ,
331+ tp_rank ,
332+ local_tp_rank ,
333+ world_size ,
334+ local_world_size ,
335+ info_queue ,
336+ mem_queue ,
337+ router_lock ,
338+ rpc_event ,
339+ rpc_finished_event ,
340+ success_event ,
341+ ),
318342 )
319343 proc .start ()
320344 success_event .wait (timeout = 40 )
0 commit comments