Commit 95d7dda

fix
1 parent 7d015dd commit 95d7dda

File tree: 2 files changed, +29 -73 lines changed


lightllm/server/router/manager.py

Lines changed: 0 additions & 2 deletions
@@ -141,8 +141,6 @@ async def wait_to_model_ready(self):
             self.model_rpc_servers.append(rpc_model)
 
         self.model_rpc_client = ModelRpcClient(
-            model_infer_servers=self.model_rpc_servers,
-            world_size=self.world_size,
             rpc_event=self.rpc_event,
             rpc_finished_event=self.rpc_finished_event,
         )
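
Note on the call site above: after this commit the router no longer hands the client a list of in-process server objects or the world size; it only passes the two synchronization events. A minimal sketch of the new construction, assuming (this diff does not show it) that the events are plain multiprocessing.Event instances created earlier by the router:

# Sketch of the simplified call site; the event objects are an assumption,
# taken to be multiprocessing.Event instances owned by the router manager.
import multiprocessing as mp

from lightllm.server.router.model_infer.model_rpc import ModelRpcClient

rpc_event = mp.Event()            # raised by the client when a request is ready
rpc_finished_event = mp.Event()   # raised by the server when the request is done

model_rpc_client = ModelRpcClient(
    rpc_event=rpc_event,
    rpc_finished_event=rpc_finished_event,
)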

lightllm/server/router/model_infer/model_rpc.py

Lines changed: 29 additions & 71 deletions
@@ -225,17 +225,7 @@ def get_max_total_token_num(self):
 
 
 class ModelRpcClient:
-    def __init__(self, model_infer_servers: List[ModelRpcServer], world_size, rpc_event, rpc_finished_event):
-        # model_infer_servers is the list of inference service objects passed in. After the
-        # refactor it only holds real objects when a single card runs without rpc communication;
-        # once multiple cards use rpc, model_infer_servers is passed in as an array of None.
-        if world_size == 1:
-            self.model_infer_server: ModelRpcServer = model_infer_servers[0]
-        else:
-            self.model_infer_server: ModelRpcServer = None
-
-        self.world_size = world_size
-        self.use_rpc = self.world_size != 1
+    def __init__(self, rpc_event, rpc_finished_event):
         self.rpc_shm_params = RpcShmParams()
         self.rpc_shm_params.create_or_link_shm()
         self.rpc_shm_results = RpcShmResults()
@@ -246,65 +236,46 @@ def __init__(self, model_infer_servers: List[ModelRpcServer], world_size, rpc_ev
         return
 
     async def init_model(self, kvargs):
-        if self.use_rpc:
-            self.rpc_shm_params.write_func_params("init_model", (kvargs,))
-            self.rpc_event.set()
+        self.rpc_shm_params.write_func_params("init_model", (kvargs,))
+        self.rpc_event.set()
 
-            self.rpc_finished_event.wait()
-            self.rpc_finished_event.clear()
-            return
-        else:
-            self.model_infer_server.init_model(kvargs)
-            return
+        self.rpc_finished_event.wait()
+        self.rpc_finished_event.clear()
+        return
 
     async def prefill(self, reqs):
-        if self.use_rpc:
-            self.rpc_shm_params.write_func_params("prefill", (reqs,))
-            self.rpc_event.set()
+        self.rpc_shm_params.write_func_params("prefill", (reqs,))
+        self.rpc_event.set()
 
-            await asyncio.to_thread(self.rpc_finished_event.wait)
-            self.rpc_finished_event.clear()
-            return
-        else:
-            self.model_infer_server.prefill(reqs)
-            return
+        await asyncio.to_thread(self.rpc_finished_event.wait)
+        self.rpc_finished_event.clear()
+        return
 
     async def decode(self):
-        if self.use_rpc:
-            self.rpc_shm_params.write_func_params("decode", ())
-            self.rpc_event.set()
+        self.rpc_shm_params.write_func_params("decode", ())
+        self.rpc_event.set()
 
-            await asyncio.to_thread(self.rpc_finished_event.wait)
-            self.rpc_finished_event.clear()
-            return
-        else:
-            self.model_infer_server.decode()
-            return
+        await asyncio.to_thread(self.rpc_finished_event.wait)
+        self.rpc_finished_event.clear()
+        return
 
     async def pause_reqs(self, req_ids):
-        if self.use_rpc:
-            self.rpc_shm_params.write_func_params("pause_reqs", (req_ids,))
-            self.rpc_event.set()
+        self.rpc_shm_params.write_func_params("pause_reqs", (req_ids,))
+        self.rpc_event.set()
 
-            self.rpc_finished_event.wait()
-            self.rpc_finished_event.clear()
-            return
-        else:
-            self.model_infer_server.pause_reqs(req_ids)
-            return
+        self.rpc_finished_event.wait()
+        self.rpc_finished_event.clear()
+        return
 
     async def get_max_total_token_num(self):
-        if self.use_rpc:
-            self.rpc_shm_params.write_func_params("get_max_total_token_num", ())
-            self.rpc_event.set()
-
-            self.rpc_finished_event.wait()
-            self.rpc_finished_event.clear()
-            func_name, ret = self.rpc_shm_results.read_func_result()
-            assert func_name == "get_max_total_token_num"
-            return ret
-        else:
-            return self.model_infer_server.get_max_total_token_num()
+        self.rpc_shm_params.write_func_params("get_max_total_token_num", ())
+        self.rpc_event.set()
+
+        self.rpc_finished_event.wait()
+        self.rpc_finished_event.clear()
+        func_name, ret = self.rpc_shm_results.read_func_result()
+        assert func_name == "get_max_total_token_num"
+        return ret
 
 
 def _init_env(
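
With the use_rpc branch gone, every ModelRpcClient method above follows one shared-memory handshake: write the function name and arguments, set rpc_event to wake the inference process, wait on rpc_finished_event, clear it, and, for calls with a return value, read the result back. The following is an illustrative, self-contained sketch of that handshake only; it substitutes a multiprocessing Manager dict for the real RpcShmParams/RpcShmResults shared-memory classes, whose internals this diff does not show.

# Stand-alone sketch of the request/finished event handshake used by ModelRpcClient.
# A Manager dict stands in for the real shared-memory parameter/result buffers.
import multiprocessing as mp


def server_loop(shm, rpc_event, rpc_finished_event):
    """Server side: wait for a request, execute it, then signal completion."""
    while True:
        rpc_event.wait()
        rpc_event.clear()
        func_name, args = shm["params"]
        if func_name == "exit":
            rpc_finished_event.set()
            break
        # Pretend to run the requested function and publish its result.
        shm["result"] = (func_name, f"ran {func_name} with args {args}")
        rpc_finished_event.set()


def call(shm, rpc_event, rpc_finished_event, func_name, args=()):
    """Client side: the same pattern each ModelRpcClient method now uses."""
    shm["params"] = (func_name, args)
    rpc_event.set()
    rpc_finished_event.wait()
    rpc_finished_event.clear()
    return shm.get("result")


if __name__ == "__main__":
    manager = mp.Manager()
    shm = manager.dict()
    rpc_event, rpc_finished_event = mp.Event(), mp.Event()
    server = mp.Process(target=server_loop, args=(shm, rpc_event, rpc_finished_event))
    server.start()
    print(call(shm, rpc_event, rpc_finished_event, "get_max_total_token_num"))
    call(shm, rpc_event, rpc_finished_event, "exit")
    server.join()
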
@@ -352,19 +323,6 @@ async def start_model_process(
 ):
     import lightllm.utils.rpyc_fix_utils as _
 
-    # do not use rpc when running a single card on a single machine
-    if node_world_size == 1 and args.nnodes == 1:
-        return ModelRpcServer(
-            args,
-            rank,
-            rank_in_node,
-            node_world_size,
-            rpc_event,
-            rpc_finished_event,
-            info_queue,
-            mem_queue,
-        )
-
     success_event = mp.Event()
     proc = mp.Process(
         target=_init_env,

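The removed block above was the last place an in-process ModelRpcServer could be returned directly; after this hunk, start_model_process always spawns the server in a child process, and the caller relies on the event/shared-memory handshake sketched earlier. A hedged sketch of the surviving spawn pattern follows; the body of _init_env and its full parameter list are assumptions, since they are not part of this diff.

# Sketch only: the real _init_env receives many more arguments (model args, rank,
# queues, events); here it is reduced to the readiness handshake visible in the diff.
import multiprocessing as mp


def _init_env(success_event):
    # Assumed stand-in: build the ModelRpcServer here, then report readiness.
    success_event.set()


def start_model_process_sketch():
    success_event = mp.Event()
    proc = mp.Process(target=_init_env, args=(success_event,))
    proc.start()
    success_event.wait()  # assumption: the parent waits for the child to signal readiness
    return proc


if __name__ == "__main__":
    start_model_process_sketch().join()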