
Commit 6136e6d

overlap router
1 parent 39e3814 commit 6136e6d

File tree

1 file changed (+32, -81 lines)

lightllm/server/router/manager.py

Lines changed: 32 additions & 81 deletions
@@ -108,8 +108,8 @@ def __init__(self, args: StartArgs, router_port, detokenization_port, metric_por
         g_router_lock.obj = self.router_lock

         # thread pool used to overlap scheduling with inference
-        self.overlap_thread_pool = concurrent.futures.ThreadPoolExecutor(max_workers=1)
-        self.schedule_task = None
+        # self.overlap_thread_pool = concurrent.futures.ThreadPoolExecutor(max_workers=1)
+        # self.schedule_task = None
         return

     async def wait_to_model_ready(self):
@@ -285,81 +285,34 @@ async def loop_for_fwd(
             if self.running_batch is None:
                 await asyncio.sleep(0.01)  # 10ms

-    async def get_schedule_result(self, running_batch: Batch):
-        if self.schedule_task is None:
-            _start_time = time.time()
-
-            def get_new_batch():
-                if time.time() - _start_time < 0.001:
-                    time.sleep(0.003)
-
-                limit_router_queue_length = None
-                if self.is_multinode_tp:
-                    # use all_reduce to get the minimum value
-                    limit_router_queue_length = len(self.req_queue.waiting_req_list)
-                    limit_router_queue_length_tensor = torch.tensor(
-                        limit_router_queue_length, dtype=torch.int32, device="cpu"
-                    )
-                    dist.all_reduce(limit_router_queue_length_tensor, op=dist.ReduceOp.MIN, group=self.mulitnode_group)
-                    limit_router_queue_length = limit_router_queue_length_tensor.item()
-
-                new_batch = self.req_queue.generate_new_batch(running_batch, limit_router_queue_length)
-                return new_batch
-
-            self.schedule_task = asyncio.get_running_loop().run_in_executor(self.overlap_thread_pool, get_new_batch)
-            return None
-        else:
-            result = await self.schedule_task
-            self.schedule_task = None
-            return result
+    def get_new_batch(self):
+        limit_router_queue_length = None
+        if self.is_multinode_tp:
+            # use all_reduce to get the minimum value
+            limit_router_queue_length = len(self.req_queue.waiting_req_list)
+            limit_router_queue_length_tensor = torch.tensor(limit_router_queue_length, dtype=torch.int32, device="cpu")
+            dist.all_reduce(limit_router_queue_length_tensor, op=dist.ReduceOp.MIN, group=self.mulitnode_group)
+            limit_router_queue_length = limit_router_queue_length_tensor.item()
+
+        new_batch = self.req_queue.generate_new_batch(self.running_batch, limit_router_queue_length)
+        return new_batch

     async def _step(self):
         """
         Event handling loop
         """
         # remove all reqs that have already finished
         # when there is currently no running request
-        if self.running_batch is None:
-            new_batch: Batch = await self.get_schedule_result(self.running_batch)
-            if new_batch is not None:
-                self.metric_client.histogram_observe("lightllm_batch_next_size", len(new_batch.reqs))
-                for req in new_batch.reqs:
-                    self.metric_client.histogram_observe(
-                        "lightllm_request_queue_duration_bucket", time.time() - req.start_time
-                    )
-                self.stats_tool.count_prompt_tokens(new_batch)
+        new_batch = None
+        if not self.batch_queue.empty():
+            new_batch = self.batch_queue.get_nowait()
+        if new_batch is not None:
+            await self._prefill_batch(new_batch)
+            self._filter_runing_batch()
+            if self.running_batch is None:
                 self.running_batch = new_batch
-                await self._prefill_batch(self.running_batch)
-                self._filter_runing_batch()
-
-                # aggressive scheduling control
-                if not self.args.disable_aggressive_schedule:
-                    self.has_wait_tokens = self.max_wait_tokens
-
-            elif self.is_multinode_and_multidp:
-                # In multi-node multi-dp mode, even if the current running_batch is None, decode
-                # still needs to be called continuously, because dp ranks on other nodes may have
-                # running requests, so this node must also call decode; the inference backend will
-                # pad some fake requests so inference can complete normally. This is mainly for
-                # large models like deepseekv3, whose ep parallel mode needs all nodes to cooperate.
-                await self._decode_batch(self.running_batch)
-
-            return
-
-        # There are running requests: the number of consecutive decodes has reached a threshold,
-        # or a previous pre-scheduling result exists.
-        if self.has_wait_tokens >= self.max_wait_tokens or self.schedule_task is not None:
-            new_mini_batch = await self.get_schedule_result(self.running_batch)
-            self.has_wait_tokens = 0
-            if new_mini_batch is not None:
-
-                # aggressive scheduling control
-                if not self.args.disable_aggressive_schedule:
-                    self.has_wait_tokens = self.max_wait_tokens
-
-                self.stats_tool.count_prompt_tokens(new_mini_batch)
-                await self._prefill_batch(new_mini_batch)
-                if not new_mini_batch.is_clear():
-                    self.running_batch.merge(new_mini_batch)
-                return
+            else:
+                self.running_batch.merge(new_batch)

         # Check if need pause some requests for decode.
         for dp_index in range(self.dp_size_in_node):
@@ -375,36 +328,29 @@ async def _step(self):
         # Decode
         self.stats_tool.count_output_tokens(self.running_batch)
         await self._decode_batch(self.running_batch)
+        if self.world_size // self.nnodes == 1:
+            # when node_world_size == 1 the coroutine never yields, so new requests cannot be scheduled; sleep 1ms here (could be improved)
+            await asyncio.sleep(0.001)
         self._filter_runing_batch()
         self.has_wait_tokens += 1
         return

     async def _prefill_batch(self, batch: Batch):
-        start_time = time.time()
-        self.metric_client.counter_inc("lightllm_batch_inference_count", "prefill")
         reqs = [r.to_router_rpc_obj() for r in batch.reqs]
         await self.model_rpc_client.prefill(reqs)
         batch.filter_out_finished_req(self.shm_req_manager)
         self._send_detokenization_pack()

         logger.debug(f"Prefill Batch: {batch.simple_log()} \n")
-        self.metric_client.histogram_observe(
-            "lightllm_batch_inference_duration_bucket", time.time() - start_time, "prefill"
-        )
         return

     async def _decode_batch(self, batch: Batch):
-        start_time = time.time()
-        self.metric_client.counter_inc("lightllm_batch_inference_count", "decode")
         await self.model_rpc_client.decode()
         # when self.is_multinode_and_multidp is True, the batch object passed in may be None.
         if batch is not None:
             batch.filter_out_finished_req(self.shm_req_manager)

         self._send_detokenization_pack()
-        self.metric_client.histogram_observe(
-            "lightllm_batch_inference_duration_bucket", time.time() - start_time, "decode"
-        )
         return

     async def _pause_reqs(self, pasue_reqs):
@@ -418,7 +364,7 @@ def _filter_runing_batch(self):
         return

     def _can_decode(self, batch: Batch, dp_index: int):
-        if self.is_pd_run_mode or self.is_safe_schedule:
+        if self.is_pd_run_mode or self.is_safe_schedule or batch is None:
             return True
         return (
             batch.get_batch_decode_need_tokens()[dp_index] + self.get_used_tokens(dp_index) <= self.max_total_token_num
@@ -451,7 +397,10 @@ async def loop_for_netio_req(self):
                 else:
                     assert False, f"Error Req Inf {recv_req}"
             except zmq.ZMQError:
-                await asyncio.sleep(0.01)
+                new_batch = self.get_new_batch()
+                if new_batch is not None:
+                    self.batch_queue.put_nowait(new_batch)
+                await asyncio.sleep(0.005)
                 continue

     def clean_up(self):
@@ -491,6 +440,8 @@ def handle_exception(loop, context):
     loop = asyncio.new_event_loop()
     loop.set_exception_handler(handle_exception)
     asyncio.set_event_loop(loop)
+    router.batch_queue = asyncio.Queue()
+
     loop.create_task(router.loop_for_fwd())
     loop.run_until_complete(router.loop_for_netio_req())
     return
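
In short, this commit retires the ThreadPoolExecutor-based pre-scheduling (get_schedule_result) and instead has loop_for_netio_req build new batches whenever the zmq poll comes up empty, pushing them onto an asyncio.Queue (router.batch_queue) that _step drains before prefilling and decoding. Below is a minimal, self-contained sketch of that producer/consumer shape, not the lightllm code itself; Batch, netio_loop, fwd_loop, and waiting_reqs are simplified stand-ins for the router's real classes and methods.

import asyncio


class Batch:
    # Toy stand-in for the router's Batch class.
    def __init__(self, reqs):
        self.reqs = reqs


async def netio_loop(batch_queue: asyncio.Queue, waiting_reqs: list) -> None:
    # Producer: mirrors loop_for_netio_req's ZMQError branch -- when no request
    # arrives, try to schedule a new batch and hand it over through the queue.
    while True:
        if waiting_reqs:
            new_batch = Batch([waiting_reqs.pop(0)])  # stand-in for req_queue.generate_new_batch(...)
            batch_queue.put_nowait(new_batch)
        await asyncio.sleep(0.005)  # the 5 ms sleep from the commit


async def fwd_loop(batch_queue: asyncio.Queue) -> None:
    # Consumer: mirrors _step -- take a pre-scheduled batch if one is ready,
    # "prefill" it, merge it into the running batch, then keep "decoding".
    running_batch = None
    while True:
        if not batch_queue.empty():
            new_batch = batch_queue.get_nowait()
            print(f"prefill batch with reqs: {new_batch.reqs}")  # stand-in for _prefill_batch
            running_batch = new_batch if running_batch is None else Batch(running_batch.reqs + new_batch.reqs)
        await asyncio.sleep(0.001)  # stand-in for _decode_batch plus the 1 ms yield


async def main() -> None:
    batch_queue: asyncio.Queue = asyncio.Queue()
    waiting_reqs = ["req-1", "req-2"]
    tasks = [asyncio.create_task(netio_loop(batch_queue, waiting_reqs)),
             asyncio.create_task(fwd_loop(batch_queue))]
    await asyncio.sleep(0.05)  # let both loops run briefly
    for task in tasks:
        task.cancel()


if __name__ == "__main__":
    asyncio.run(main())

Compared with the old overlap_thread_pool path, everything now stays on the event loop; that is why the decode path adds the extra asyncio.sleep(0.001) when node_world_size == 1, since without a yield point the forward coroutine would never let the netio loop schedule new batches.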
