Skip to content

Commit ea0ada4

Browse files
committed
update router
1 parent 74cfa55 commit ea0ada4

File tree

3 files changed

+10
-8
lines changed

3 files changed

+10
-8
lines changed

lightllm/server/router/manager.py

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -197,10 +197,6 @@ async def wait_to_model_ready(self):
197197
return
198198

199199
def _get_schedule_time_interval(self):
200-
if self.running_batch is None:
201-
# 没有运行中的 batch 时,每 10ms 触发一次请求调度
202-
return 0.01
203-
204200
# dp 模式,为了更好的配平,需要更长的调度间隔,以便于能收到更多的请求
205201
return self.schedule_time_interval
206202

@@ -370,9 +366,7 @@ def _add_req(self, group_req_indexes: GroupReqIndexes):
370366

371367
def _generate_new_batch(self):
372368
# 调度的时候需要考虑当前运行的batch,和调度了但是暂时还没有推理的部分请求。
373-
new_batch = self.req_queue.generate_new_batch(
374-
Batch.merge_two_batch(self.running_batch, self.schedule_new_batch)
375-
)
369+
new_batch = self.req_queue.generate_new_batch(self.schedule_new_batch)
376370
self.schedule_new_batch = Batch.merge_two_batch(self.schedule_new_batch, new_batch)
377371
return
378372

@@ -469,7 +463,7 @@ async def _recv_new_reqs_and_schedule(self):
469463
if self.is_multinode_tp:
470464
self._multinode_tp_generate_new_batch()
471465
else:
472-
if self._get_paused_req_num() == 0:
466+
if self._get_paused_req_num() == 0 and self.shm_reqs_io_buffer.is_empty():
473467
self._generate_new_batch()
474468
return
475469

lightllm/server/router/model_infer/mode_backend/base_backend.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,7 @@ def init_model(self, kvargs):
7575
self.chunked_prefill_size = self.args.chunked_prefill_size
7676
self.return_all_prompt_logprobs = self.args.return_all_prompt_logprobs
7777
self.use_dynamic_prompt_cache = not self.args.disable_dynamic_prompt_cache
78+
self.batch_max_tokens = self.args.batch_max_tokens
7879
self.eos_id: List[int] = kvargs.get("eos_id", [2])
7980
self.disable_cudagraph = self.args.disable_cudagraph
8081
self.is_multinode_tp = self.args.nnodes > 1 and self.args.dp == 1
@@ -391,6 +392,7 @@ def _get_classed_reqs(
391392
# 请求,其逻辑是不适合的。
392393
pause_max_req_num = 2
393394
wait_pause_count = 0
395+
prefill_tokens = 0
394396

395397
# 因为会使用到 radix cache 和 mem_manager 的计数信息
396398
# 所以需要加锁保护。
@@ -439,6 +441,11 @@ def _get_classed_reqs(
439441
wait_pause_count += 1
440442
else:
441443
token_num = req_obj.prefill_need_token_num(is_chuncked_prefill=not self.disable_chunked_prefill)
444+
if prefill_tokens + token_num > self.batch_max_tokens:
445+
# 跳过等下次prefill,避免oom
446+
prefill_tokens = 0
447+
break
448+
prefill_tokens += token_num
442449
if token_num <= can_alloc_token_num:
443450
prefill_reqs.append(req_obj)
444451
can_alloc_token_num -= token_num

lightllm/server/router/req_queue/chunked_prefill/impl.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,7 @@ def generate_new_batch(self, current_batch: Batch):
6969
new_batch_first_router_need_tokens = (
7070
0 if current_batch is None else current_batch.get_batch_decode_need_tokens()[self.dp_index]
7171
)
72+
print(f"new_batch_first_router_need_tokens: {new_batch_first_router_need_tokens}")
7273

7374
self._init_cache_list(current_batch, is_busy)
7475
can_run_list = []

0 commit comments

Comments
 (0)