Skip to content

Commit 72061fb

Browse files
committed
fix
1 parent f36a0e4 commit 72061fb

File tree

2 files changed

+34
-26
lines changed

2 files changed

+34
-26
lines changed

lightllm/server/router/manager.py

Lines changed: 31 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -108,8 +108,9 @@ def __init__(self, args: StartArgs, router_port, detokenization_port, metric_por
108108
g_router_lock.obj = self.router_lock
109109

110110
# 调度和推理进行折叠使用的线程池
111-
# self.overlap_thread_pool = concurrent.futures.ThreadPoolExecutor(max_workers=1)
112-
# self.schedule_task = None
111+
self.schedule_new_batch : Batch = None
112+
self.schedule_lock = asyncio.Lock()
113+
self.schedule_sem = asyncio.Semaphore(1)
113114
return
114115

115116
async def wait_to_model_ready(self):
@@ -285,34 +286,37 @@ async def loop_for_fwd(
285286
if self.running_batch is None:
286287
await asyncio.sleep(0.01) # 10ms
287288

288-
def get_new_batch(self):
289+
def generate_new_batch(self):
289290
limit_router_queue_length = None
290291
if self.is_multinode_tp:
291292
# 使用 all_reduce 获取最小值
292293
limit_router_queue_length = len(self.req_queue.waiting_req_list)
293294
limit_router_queue_length_tensor = torch.tensor(limit_router_queue_length, dtype=torch.int32, device="cpu")
294295
dist.all_reduce(limit_router_queue_length_tensor, op=dist.ReduceOp.MIN, group=self.mulitnode_group)
295296
limit_router_queue_length = limit_router_queue_length_tensor.item()
296-
297-
new_batch = self.req_queue.generate_new_batch(self.running_batch, limit_router_queue_length)
298-
return new_batch
297+
298+
# 调度的时候需要考虑当前运行的batch,和调度了但是暂时还没有推理的部分请求。
299+
new_batch = self.req_queue.generate_new_batch(Batch.merge(self.running_batch, self.schedule_new_batch), limit_router_queue_length)
300+
self.schedule_new_batch = Batch.merge(self.schedule_new_batch, new_batch)
301+
return
299302

300303
async def _step(self):
301304
"""
302305
事件处理循环
303306
"""
304307
# 删除所有已经 finished 的 req
305308
# 当前无运行请求时
306-
new_batch = None
307-
if not self.batch_queue.empty():
308-
new_batch = self.batch_queue.get_nowait()
309+
new_batch = self.schedule_new_batch
310+
self.schedule_new_batch = None
309311
if new_batch is not None:
310312
await self._prefill_batch(new_batch)
313+
self.stats_tool.count_prompt_tokens(new_batch)
311314
self._filter_runing_batch()
312-
if self.running_batch is None:
313-
self.running_batch = new_batch
314-
else:
315-
self.running_batch.merge(new_batch)
315+
if not new_batch.is_clear():
316+
if self.running_batch is None:
317+
self.running_batch = new_batch
318+
else:
319+
self.running_batch.merge(new_batch)
316320

317321
# Check if need pause some requests for decode.
318322
for dp_index in range(self.dp_size_in_node):
@@ -391,17 +395,20 @@ def get_used_tokens(self, dp_index):
391395
async def loop_for_netio_req(self):
392396
while True:
393397
try:
394-
recv_req: GroupReqIndexes = self.recv_from_httpserver.recv_pyobj(zmq.NOBLOCK)
395-
if isinstance(recv_req, GroupReqIndexes):
396-
self.add_req(recv_req)
397-
else:
398-
assert False, f"Error Req Inf {recv_req}"
398+
# 一次最多从 zmq 中取 36 个请求,防止 zmq 队列中请求数量过多导致阻塞了主循环。
399+
for _ in range(36):
400+
recv_req: GroupReqIndexes = self.recv_from_httpserver.recv_pyobj(zmq.NOBLOCK)
401+
if isinstance(recv_req, GroupReqIndexes):
402+
self.add_req(recv_req)
403+
else:
404+
assert False, f"Error Req Inf {recv_req}"
399405
except zmq.ZMQError:
400-
new_batch = self.get_new_batch()
401-
if new_batch is not None:
402-
self.batch_queue.put_nowait(new_batch)
403-
await asyncio.sleep(0.005)
404-
continue
406+
pass
407+
408+
# 调度新的 batch
409+
410+
self.generate_new_batch()
411+
await asyncio.sleep(0.005)
405412

406413
def clean_up(self):
407414
return
@@ -440,8 +447,8 @@ def handle_exception(loop, context):
440447
loop = asyncio.new_event_loop()
441448
loop.set_exception_handler(handle_exception)
442449
asyncio.set_event_loop(loop)
443-
router.batch_queue = asyncio.Queue()
444450

445451
loop.create_task(router.loop_for_fwd())
446452
loop.run_until_complete(router.loop_for_netio_req())
447453
return
454+

lightllm/server/router/stats.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import time
22
from lightllm.utils.log_utils import init_logger
3+
from .batch import Batch
34

45
logger = init_logger(__name__)
56

@@ -14,14 +15,14 @@ def __init__(self, log_status, log_stats_interval) -> None:
1415
self.prompt_tokens = 0
1516
return
1617

17-
def count_prompt_tokens(self, run_batch):
18+
def count_prompt_tokens(self, run_batch: Batch):
1819
if self.log_stats:
1920
tokens = run_batch.input_tokens()
2021
self.prompt_tokens += tokens
2122
self.all_tokens += tokens
2223
return
2324

24-
def count_output_tokens(self, run_batch):
25+
def count_output_tokens(self, run_batch: Batch):
2526
if self.log_stats:
2627
tokens = len(run_batch.reqs)
2728
self.output_tokens += tokens

0 commit comments

Comments
 (0)