
Commit 9f842ac

Commit message: fix
1 parent f5aca73 commit 9f842ac

File tree: 1 file changed (+42, −35 lines)


lightllm/server/router/manager.py

Lines changed: 42 additions & 35 deletions
```diff
@@ -1,26 +1,20 @@
-import copy
 import time
-import uuid
 import uvloop
 import asyncio
 import torch
-import rpyc
 import pickle
-import threading
 import inspect
 
 asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
-import concurrent.futures
 import zmq
 import zmq.asyncio
 import torch.multiprocessing as mp
 import torch.distributed as dist
 import multiprocessing
 from typing import Dict, List, Optional
-from .batch import Batch
+from .batch import Batch, Req
 from .model_infer.model_rpc import start_model_process, ModelRpcClient
 from .req_queue import build_req_queue
-from lightllm.utils.infer_utils import calculate_time
 from lightllm.server.core.objs.io_objs import GroupReqIndexes
 from lightllm.server.core.objs import ShmReqManager, StartArgs
 from .dynamic_prompt.radix_cache import RadixCacheReadOnlyClient
```
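Note: this hunk drops imports the diff no longer shows any use for (`copy`, `uuid`, `rpyc`, `threading`, `concurrent.futures`, `calculate_time`) and adds `Req`, which the new `List[Req]` annotation on `_pause_reqs` below relies on.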
```diff
@@ -109,8 +103,7 @@ def __init__(self, args: StartArgs, router_port, detokenization_port, metric_por
 
         # Thread pool used to overlap scheduling with inference
         self.schedule_new_batch: Batch = None
-        self.schedule_lock = asyncio.Lock()
-        self.schedule_sem = asyncio.Semaphore(1)
+        self.schedule_event = asyncio.Event()
         return
 
     async def wait_to_model_ready(self):
```
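Note: the `schedule_lock`/`schedule_sem` pair is collapsed into a single `asyncio.Event`, which models what the router needs here: a level-triggered "a decode step has started, scheduling may run" signal. As the later hunks show, `_decode_batch` sets the event and `loop_for_netio_req` waits on it with a short timeout. A minimal standalone sketch of that handshake (`decode_step` and `schedule_step` are illustrative names, not functions from this file):

```python
import asyncio

schedule_event = asyncio.Event()

async def decode_step():
    schedule_event.set()       # signal before the slow decode call ...
    await asyncio.sleep(0.05)  # ... which this sleep stands in for

async def schedule_step():
    try:
        # Wake as soon as decode signals, or give up after 20 ms.
        await asyncio.wait_for(schedule_event.wait(), timeout=0.02)
    except asyncio.TimeoutError:
        pass
    if schedule_event.is_set():
        print("schedule a new batch")  # stands in for generate_new_batch()
        schedule_event.clear()

async def main():
    await asyncio.gather(decode_step(), schedule_step())

asyncio.run(main())
```

The 20 ms timeout bounds how stale scheduling can get when no decode step fires, and clearing the event after scheduling re-arms the handshake for the next decode.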
```diff
@@ -222,7 +215,6 @@ def add_req(self, group_req_indexes: GroupReqIndexes):
             logger.info(f"router receive req id {req.request_id} cost time {time.time() - req.start_time} s")
         self.req_queue.extend(req_group)
         self.send_to_detokenization.send_pyobj(group_req_indexes, protocol=pickle.HIGHEST_PROTOCOL)
-
         return
 
     async def loop_for_fwd(
```
```diff
@@ -312,14 +304,10 @@ async def _step(self):
         ):
             new_batch = self.schedule_new_batch
             self.schedule_new_batch = None
+            self._add_new_batch_to_running_batch(new_batch=new_batch)
             await self._prefill_batch(new_batch)
             self.stats_tool.count_prompt_tokens(new_batch)
-            self._filter_runing_batch()
-            if not new_batch.is_clear():
-                if self.running_batch is None:
-                    self.running_batch = new_batch
-                else:
-                    self.running_batch.merge(new_batch)
+            self._filter_reqs_from_running_batch()
             self.has_wait_tokens = 0
 
         # Check if need pause some requests for decode.
```
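Note: the new batch is now merged into `self.running_batch` before prefill (via `_add_new_batch_to_running_batch`) instead of afterwards, so finished requests can be dropped in one place by `_filter_reqs_from_running_batch`. Under the old order, the merge only happened when the new batch was not already clear, and `_prefill_batch` had to filter finished requests itself; the next hunk removes that call and adds both helpers.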
```diff
@@ -335,38 +323,43 @@ async def _step(self):
 
         # Decode
         self.stats_tool.count_output_tokens(self.running_batch)
-        await self._decode_batch(self.running_batch)
-        self._filter_runing_batch()
+        await self._decode_batch()
+        self._filter_reqs_from_running_batch()
         self.has_wait_tokens += 1
         return
 
     async def _prefill_batch(self, batch: Batch):
+        # Add the new requests
         reqs = [r.to_router_rpc_obj() for r in batch.reqs]
         await self.model_rpc_client.prefill(reqs)
-        batch.filter_out_finished_req(self.shm_req_manager)
         self._send_detokenization_pack()
-
         logger.debug(f"Prefill Batch: {batch.simple_log()} \n")
         return
 
-    async def _decode_batch(self, batch: Batch):
+    async def _decode_batch(self):
+        self.schedule_event.set()
         await self.model_rpc_client.decode()
-        # When self.is_multinode_and_multidp is True, the batch passed in may be None.
-        if batch is not None:
-            batch.filter_out_finished_req(self.shm_req_manager)
-
         self._send_detokenization_pack()
         return
 
-    async def _pause_reqs(self, pasue_reqs):
+    async def _pause_reqs(self, pasue_reqs: List[Req]):
         pasue_req_ids = [r.request_id for r in pasue_reqs]
         await self.model_rpc_client.pause_reqs(pasue_req_ids)
         return
 
-    def _filter_runing_batch(self):
-        if self.running_batch is not None and self.running_batch.is_clear():
-            self.running_batch = None
-        return
+    def _add_new_batch_to_running_batch(self, new_batch: Batch):
+        if self.running_batch is None:
+            self.running_batch = new_batch
+        else:
+            self.running_batch.merge(new_batch)
+        return
+
+    def _filter_reqs_from_running_batch(self):
+        if self.running_batch is not None:
+            self.running_batch.filter_out_finished_req(self.shm_req_manager)
+            if self.running_batch.is_clear():
+                self.running_batch = None
+        return
 
     def _can_decode(self, batch: Batch, dp_index: int):
         if self.is_pd_run_mode or self.is_safe_schedule or batch is None:
```
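Note: `_decode_batch` no longer takes a batch argument (which could be `None` in the multinode multi-dp case); it works on shared state and sets `schedule_event` before awaiting the decode RPC, so the scheduling done in `loop_for_netio_req` (next hunk) overlaps with decoding. The old `_filter_runing_batch`, which only discarded the running batch once it was fully clear, is replaced by `_filter_reqs_from_running_batch`, which first filters finished requests out of the batch and then discards it when empty.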
```diff
@@ -394,21 +387,35 @@ def get_used_tokens(self, dp_index):
         return self.max_total_token_num - self.read_only_statics_mem_manager.get_unrefed_token_num(dp_index)
 
     async def loop_for_netio_req(self):
+        recv_max_count = 66
+
         while True:
             try:
-                # Fetch at most 24 requests from zmq at a time, so a backlog in the zmq queue cannot block the main loop.
-                for _ in range(36):
+                # Fetch at most recv_max_count requests from zmq at a time, so a backlog in the zmq queue cannot block the main loop.
+                for _ in range(recv_max_count):
                     recv_req: GroupReqIndexes = self.recv_from_httpserver.recv_pyobj(zmq.NOBLOCK)
                     if isinstance(recv_req, GroupReqIndexes):
                         self.add_req(recv_req)
                     else:
                         assert False, f"Error Req Inf {recv_req}"
+
+                # When many requests are queued, raise the per-round receive cap
+                recv_max_count = min(int(recv_max_count * 1.3), 300)
+
             except zmq.ZMQError:
+                # Once the queue starts to drain, lower the per-round receive cap
+                recv_max_count = 66
+
+            try:
+                await asyncio.wait_for(self.schedule_event.wait(), timeout=0.02)
+            except asyncio.TimeoutError:
                 pass
 
-            # Schedule a new batch
-            self.generate_new_batch()
-            await asyncio.sleep(0.005)
+            if self.schedule_event.is_set():
+                self.generate_new_batch()
+                self.schedule_event.clear()
+
+        return
 
     def clean_up(self):
         return
```
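Note: the old loop drained at most a fixed 36 requests per pass (while its comment claimed 24), then unconditionally slept 5 ms before scheduling. The new loop sizes the per-pass cap adaptively: start at 66, grow by 1.3x up to 300 while the zmq queue stays saturated, and reset to 66 as soon as a `zmq.ZMQError` shows the queue ran dry; the fixed sleep becomes the bounded wait on `schedule_event` sketched earlier. A standalone sketch of the adaptive-drain idea, assuming a plain zmq socket and a hypothetical `handle` callback in place of `add_req`:

```python
import zmq

def drain_once(sock: zmq.Socket, recv_max_count: int, handle) -> int:
    """One pass of the adaptive drain; returns the cap for the next pass.

    The growth factor (1.3), ceiling (300) and floor (66) mirror the diff;
    sock and handle stand in for the router's socket and add_req.
    """
    try:
        for _ in range(recv_max_count):
            # Non-blocking receive raises zmq.ZMQError when the queue is empty.
            handle(sock.recv_pyobj(zmq.NOBLOCK))
        # A full pass completed without the queue running dry: raise the cap.
        return min(int(recv_max_count * 1.3), 300)
    except zmq.ZMQError:
        # The queue drained mid-pass: fall back to the floor.
        return 66
```

Multiplicative increase with an instant reset keeps per-iteration latency bounded when traffic is light while still catching up quickly under bursts.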
