File tree Expand file tree Collapse file tree 3 files changed +10
-8
lines changed
req_queue/chunked_prefill Expand file tree Collapse file tree 3 files changed +10
-8
lines changed Original file line number Diff line number Diff line change @@ -197,10 +197,6 @@ async def wait_to_model_ready(self):
197197 return
198198
199199 def _get_schedule_time_interval (self ):
200- if self .running_batch is None :
201- # 没有运行中的 batch 时,每 10ms 触发一次请求调度
202- return 0.01
203-
204200 # dp 模式,为了更好的配平,需要更长的调度间隔,以便于能收到更多的请求
205201 return self .schedule_time_interval
206202
@@ -370,9 +366,7 @@ def _add_req(self, group_req_indexes: GroupReqIndexes):
370366
371367 def _generate_new_batch (self ):
372368 # 调度的时候需要考虑当前运行的batch,和调度了但是暂时还没有推理的部分请求。
373- new_batch = self .req_queue .generate_new_batch (
374- Batch .merge_two_batch (self .running_batch , self .schedule_new_batch )
375- )
369+ new_batch = self .req_queue .generate_new_batch (self .schedule_new_batch )
376370 self .schedule_new_batch = Batch .merge_two_batch (self .schedule_new_batch , new_batch )
377371 return
378372
@@ -469,7 +463,7 @@ async def _recv_new_reqs_and_schedule(self):
469463 if self .is_multinode_tp :
470464 self ._multinode_tp_generate_new_batch ()
471465 else :
472- if self ._get_paused_req_num () == 0 :
466+ if self ._get_paused_req_num () == 0 and self . shm_reqs_io_buffer . is_empty () :
473467 self ._generate_new_batch ()
474468 return
475469
Original file line number Diff line number Diff line change @@ -75,6 +75,7 @@ def init_model(self, kvargs):
7575 self .chunked_prefill_size = self .args .chunked_prefill_size
7676 self .return_all_prompt_logprobs = self .args .return_all_prompt_logprobs
7777 self .use_dynamic_prompt_cache = not self .args .disable_dynamic_prompt_cache
78+ self .batch_max_tokens = self .args .batch_max_tokens
7879 self .eos_id : List [int ] = kvargs .get ("eos_id" , [2 ])
7980 self .disable_cudagraph = self .args .disable_cudagraph
8081 self .is_multinode_tp = self .args .nnodes > 1 and self .args .dp == 1
@@ -391,6 +392,7 @@ def _get_classed_reqs(
391392 # 请求,其逻辑是不适合的。
392393 pause_max_req_num = 2
393394 wait_pause_count = 0
395+ prefill_tokens = 0
394396
395397 # 因为会使用到 radix cache 和 mem_manager 的计数信息
396398 # 所以需要加锁保护。
@@ -439,6 +441,11 @@ def _get_classed_reqs(
439441 wait_pause_count += 1
440442 else :
441443 token_num = req_obj .prefill_need_token_num (is_chuncked_prefill = not self .disable_chunked_prefill )
444+ if prefill_tokens + token_num > self .batch_max_tokens :
445+ # 跳过等下次prefill,避免oom
446+ prefill_tokens = 0
447+ break
448+ prefill_tokens += token_num
442449 if token_num <= can_alloc_token_num :
443450 prefill_reqs .append (req_obj )
444451 can_alloc_token_num -= token_num
Original file line number Diff line number Diff line change @@ -69,6 +69,7 @@ def generate_new_batch(self, current_batch: Batch):
6969 new_batch_first_router_need_tokens = (
7070 0 if current_batch is None else current_batch .get_batch_decode_need_tokens ()[self .dp_index ]
7171 )
72+ print (f"new_batch_first_router_need_tokens: { new_batch_first_router_need_tokens } " )
7273
7374 self ._init_cache_list (current_batch , is_busy )
7475 can_run_list = []
You can’t perform that action at this time.
0 commit comments