@@ -244,17 +244,19 @@ async def loop_for_fwd(
244244 estimated_peak_token_count = self .shared_token_load .get_estimated_peak_token_count (d_i )
245245 logger .debug (
246246 f"dp_i { d_i } current batch size: { len (self .running_batch .reqs )} \n "
247- f"dp_i { d_i } paused req num: { self .req_queue .get_paused_req_num ()} \n "
247+ f"dp_i { d_i } paused req num: { self .req_queue .get_paused_req_num (d_i )} \n "
248248 f"dp_i { d_i } frozen token num: { frozen_token_num } \n "
249249 f"dp_i { d_i } estimated_peak_token_count: { estimated_peak_token_count } \n "
250250 f"dp_i { d_i } token used ratio: { token_ratio1 } not contain prompt cache tree unrefed token\n "
251251 f"dp_i { d_i } token used ratio: { token_ratio2 } contain prompt cache tree unrefed token"
252252 )
253+ self .metric_client .gauge_set (
254+ "lightllm_batch_pause_size" , self .req_queue .get_paused_req_num (d_i )
255+ )
253256 # pd decode mode need to update token_load more frequently
254257 self .req_queue .update_token_load (self .running_batch , force_update = self .is_pd_decode_mode )
255258 self .stats_tool .print_stats ()
256259 self .metric_client .gauge_set ("lightllm_batch_current_size" , len (self .running_batch .reqs ))
257- self .metric_client .gauge_set ("lightllm_batch_pause_size" , self .req_queue .get_paused_req_num ())
258260 self .metric_client .gauge_set ("lightllm_queue_size" , self .req_queue .get_wait_req_num ())
259261 self .metric_client .gauge_set (
260262 "lightllm_batch_current_max_tokens" ,
@@ -356,23 +358,22 @@ async def _step(self):
356358 self .running_batch .merge (new_mini_batch )
357359 return
358360
359- # 正常 decode 阶段, 如果可以直接decode就直接decode,否则通过暂停策略暂停一些请求
360- # 释放一些管理的 token
361- if self ._can_decode (self .running_batch ):
362- self .stats_tool .count_output_tokens (self .running_batch )
363- await self ._decode_batch (self .running_batch )
364- self ._filter_runing_batch ()
365- self .has_wait_tokens += 1
366- return
367- else :
368- # pause strategy
369- paused_reqs = select_paused_reqs (
370- self .running_batch , self .pause_strategy , self .req_queue , self .max_total_token_num
371- )
372- await self ._pause_reqs (paused_reqs )
373- logger .debug (f"pasued req num: { self .req_queue .get_paused_req_num ()} " )
374- self .has_wait_tokens = 0
375- return
361+ # Check if need pause some requests for decode.
362+ for dp_index in range (self .dp_size_in_node ):
363+ while not self ._can_decode (self .running_batch , dp_index = dp_index ):
364+ # pause strategy
365+ paused_reqs = select_paused_reqs (
366+ self .running_batch , self .pause_strategy , self .req_queue , self .max_total_token_num , dp_index = dp_index
367+ )
368+ await self ._pause_reqs (paused_reqs )
369 logger .debug (f"DP index { dp_index } pauses req num: { self .req_queue .get_paused_req_num (dp_index )} " )
370+ self .has_wait_tokens = 0
371+
372+ # Decode
373+ self .stats_tool .count_output_tokens (self .running_batch )
374+ await self ._decode_batch (self .running_batch )
375+ self ._filter_runing_batch ()
376+ self .has_wait_tokens += 1
376377 return
377378
378379 async def _prefill_batch (self , batch : Batch ):
@@ -416,16 +417,12 @@ def _filter_runing_batch(self):
416417 self .running_batch = None
417418 return
418419
419- def _can_decode (self , batch : Batch ):
420- # p d 分离模式下,目前只能使用保守调度,保证请求放入进行decode的时候
421- # 显存token肯定是够用的。
422- # deepseekv2 dp 模式下,采用保守调度,也肯定够用
423- if self .is_pd_run_mode or self .dp_size_in_node > 1 or self .is_safe_schedule :
420+ def _can_decode (self , batch : Batch , dp_index : int ):
421+ if self .is_pd_run_mode or self .is_safe_schedule :
424422 return True
425-
426- # 下面的判定条件,只在 dp 为 1 的情况下启用
427- assert self .dp_size_in_node == 1
428- return batch .get_batch_decode_need_tokens ()[0 ] + self .get_used_tokens (0 ) <= self .max_total_token_num
423+ return (
424+ batch .get_batch_decode_need_tokens ()[dp_index ] + self .get_used_tokens (dp_index ) <= self .max_total_token_num
425+ )
429426
430427 def get_used_tokens (self , dp_index ):
431428 if self .args .use_dynamic_prompt_cache :
0 commit comments