@@ -225,17 +225,7 @@ def get_max_total_token_num(self):
225225
226226
227227class ModelRpcClient :
228- def __init__ (self , model_infer_servers : List [ModelRpcServer ], world_size , rpc_event , rpc_finished_event ):
229- # model_infer_servers 是传入的推理服务对象,但是在重构后,
230- # 单卡不使用rpc 通信的时候,里面才有真实对象,当多卡使用rpc
231- # 以后,model_infer_servers 传入的是 None 数组
232- if world_size == 1 :
233- self .model_infer_server : ModelRpcServer = model_infer_servers [0 ]
234- else :
235- self .model_infer_server : ModelRpcServer = None
236-
237- self .world_size = world_size
238- self .use_rpc = self .world_size != 1
228+ def __init__ (self , rpc_event , rpc_finished_event ):
239229 self .rpc_shm_params = RpcShmParams ()
240230 self .rpc_shm_params .create_or_link_shm ()
241231 self .rpc_shm_results = RpcShmResults ()
@@ -246,65 +236,46 @@ def __init__(self, model_infer_servers: List[ModelRpcServer], world_size, rpc_ev
246236 return
247237
248238 async def init_model (self , kvargs ):
249- if self .use_rpc :
250- self .rpc_shm_params .write_func_params ("init_model" , (kvargs ,))
251- self .rpc_event .set ()
239+ self .rpc_shm_params .write_func_params ("init_model" , (kvargs ,))
240+ self .rpc_event .set ()
252241
253- self .rpc_finished_event .wait ()
254- self .rpc_finished_event .clear ()
255- return
256- else :
257- self .model_infer_server .init_model (kvargs )
258- return
242+ self .rpc_finished_event .wait ()
243+ self .rpc_finished_event .clear ()
244+ return
259245
260246 async def prefill (self , reqs ):
261- if self .use_rpc :
262- self .rpc_shm_params .write_func_params ("prefill" , (reqs ,))
263- self .rpc_event .set ()
247+ self .rpc_shm_params .write_func_params ("prefill" , (reqs ,))
248+ self .rpc_event .set ()
264249
265- await asyncio .to_thread (self .rpc_finished_event .wait )
266- self .rpc_finished_event .clear ()
267- return
268- else :
269- self .model_infer_server .prefill (reqs )
270- return
250+ await asyncio .to_thread (self .rpc_finished_event .wait )
251+ self .rpc_finished_event .clear ()
252+ return
271253
272254 async def decode (self ):
273- if self .use_rpc :
274- self .rpc_shm_params .write_func_params ("decode" , ())
275- self .rpc_event .set ()
255+ self .rpc_shm_params .write_func_params ("decode" , ())
256+ self .rpc_event .set ()
276257
277- await asyncio .to_thread (self .rpc_finished_event .wait )
278- self .rpc_finished_event .clear ()
279- return
280- else :
281- self .model_infer_server .decode ()
282- return
258+ await asyncio .to_thread (self .rpc_finished_event .wait )
259+ self .rpc_finished_event .clear ()
260+ return
283261
284262 async def pause_reqs (self , req_ids ):
285- if self .use_rpc :
286- self .rpc_shm_params .write_func_params ("pause_reqs" , (req_ids ,))
287- self .rpc_event .set ()
263+ self .rpc_shm_params .write_func_params ("pause_reqs" , (req_ids ,))
264+ self .rpc_event .set ()
288265
289- self .rpc_finished_event .wait ()
290- self .rpc_finished_event .clear ()
291- return
292- else :
293- self .model_infer_server .pause_reqs (req_ids )
294- return
266+ self .rpc_finished_event .wait ()
267+ self .rpc_finished_event .clear ()
268+ return
295269
296270 async def get_max_total_token_num (self ):
297- if self .use_rpc :
298- self .rpc_shm_params .write_func_params ("get_max_total_token_num" , ())
299- self .rpc_event .set ()
300-
301- self .rpc_finished_event .wait ()
302- self .rpc_finished_event .clear ()
303- func_name , ret = self .rpc_shm_results .read_func_result ()
304- assert func_name == "get_max_total_token_num"
305- return ret
306- else :
307- return self .model_infer_server .get_max_total_token_num ()
271+ self .rpc_shm_params .write_func_params ("get_max_total_token_num" , ())
272+ self .rpc_event .set ()
273+
274+ self .rpc_finished_event .wait ()
275+ self .rpc_finished_event .clear ()
276+ func_name , ret = self .rpc_shm_results .read_func_result ()
277+ assert func_name == "get_max_total_token_num"
278+ return ret
308279
309280
310281def _init_env (
@@ -352,19 +323,6 @@ async def start_model_process(
352323):
353324 import lightllm .utils .rpyc_fix_utils as _
354325
355- # 单卡单机时不使用 rpc
356- if node_world_size == 1 and args .nnodes == 1 :
357- return ModelRpcServer (
358- args ,
359- rank ,
360- rank_in_node ,
361- node_world_size ,
362- rpc_event ,
363- rpc_finished_event ,
364- info_queue ,
365- mem_queue ,
366- )
367-
368326 success_event = mp .Event ()
369327 proc = mp .Process (
370328 target = _init_env ,
0 commit comments