@@ -108,8 +108,8 @@ def __init__(self, args: StartArgs, router_port, detokenization_port, metric_port):
         g_router_lock.obj = self.router_lock
 
         # Thread pool used to overlap scheduling with inference
-        self.overlap_thread_pool = concurrent.futures.ThreadPoolExecutor(max_workers=1)
-        self.schedule_task = None
+        # self.overlap_thread_pool = concurrent.futures.ThreadPoolExecutor(max_workers=1)
+        # self.schedule_task = None
         return
 
     async def wait_to_model_ready(self):
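
For context, the two attributes commented out above were the old overlap machinery: a single-worker thread pool ran the blocking batch construction while the event loop kept decoding. A minimal standalone sketch of that pattern, with illustrative names (`blocking_schedule` is not a lightllm API):

```python
import asyncio
import concurrent.futures
import time

# max_workers=1 mirrors the removed pool: at most one schedule in flight.
pool = concurrent.futures.ThreadPoolExecutor(max_workers=1)

def blocking_schedule():
    time.sleep(0.003)  # stands in for the CPU-bound generate_new_batch() call
    return "new_batch"

async def inference_loop():
    # Start scheduling in the worker thread, decode while it runs,
    # then collect the pre-scheduled result.
    task = asyncio.get_running_loop().run_in_executor(pool, blocking_schedule)
    await asyncio.sleep(0.001)  # stands in for one decode step
    print(await task)

asyncio.run(inference_loop())
```
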
@@ -285,81 +285,34 @@ async def loop_for_fwd(
             if self.running_batch is None:
                 await asyncio.sleep(0.01)  # 10ms
 
-    async def get_schedule_result(self, running_batch: Batch):
-        if self.schedule_task is None:
-            _start_time = time.time()
-
-            def get_new_batch():
-                if time.time() - _start_time < 0.001:
-                    time.sleep(0.003)
-
-                limit_router_queue_length = None
-                if self.is_multinode_tp:
-                    # Use all_reduce to take the minimum across nodes
-                    limit_router_queue_length = len(self.req_queue.waiting_req_list)
-                    limit_router_queue_length_tensor = torch.tensor(
-                        limit_router_queue_length, dtype=torch.int32, device="cpu"
-                    )
-                    dist.all_reduce(limit_router_queue_length_tensor, op=dist.ReduceOp.MIN, group=self.mulitnode_group)
-                    limit_router_queue_length = limit_router_queue_length_tensor.item()
-
-                new_batch = self.req_queue.generate_new_batch(running_batch, limit_router_queue_length)
-                return new_batch
-
-            self.schedule_task = asyncio.get_running_loop().run_in_executor(self.overlap_thread_pool, get_new_batch)
-            return None
-        else:
-            result = await self.schedule_task
-            self.schedule_task = None
-            return result
+    def get_new_batch(self):
+        limit_router_queue_length = None
+        if self.is_multinode_tp:
+            # Use all_reduce to take the minimum across nodes
+            limit_router_queue_length = len(self.req_queue.waiting_req_list)
+            limit_router_queue_length_tensor = torch.tensor(limit_router_queue_length, dtype=torch.int32, device="cpu")
+            dist.all_reduce(limit_router_queue_length_tensor, op=dist.ReduceOp.MIN, group=self.mulitnode_group)
+            limit_router_queue_length = limit_router_queue_length_tensor.item()
+
+        new_batch = self.req_queue.generate_new_batch(self.running_batch, limit_router_queue_length)
+        return new_batch
 
     async def _step(self):
         """
         Event handling loop
         """
         # Remove all reqs that have already finished
         # When there is no running request
-        if self.running_batch is None:
-            new_batch: Batch = await self.get_schedule_result(self.running_batch)
-            if new_batch is not None:
-                self.metric_client.histogram_observe("lightllm_batch_next_size", len(new_batch.reqs))
-                for req in new_batch.reqs:
-                    self.metric_client.histogram_observe(
-                        "lightllm_request_queue_duration_bucket", time.time() - req.start_time
-                    )
-                self.stats_tool.count_prompt_tokens(new_batch)
+        new_batch = None
+        if not self.batch_queue.empty():
+            new_batch = self.batch_queue.get_nowait()
+        if new_batch is not None:
+            await self._prefill_batch(new_batch)
+            self._filter_runing_batch()
+            if self.running_batch is None:
                 self.running_batch = new_batch
-                await self._prefill_batch(self.running_batch)
-                self._filter_runing_batch()
-
-                # Aggressive schedule control
-                if not self.args.disable_aggressive_schedule:
-                    self.has_wait_tokens = self.max_wait_tokens
-
-            elif self.is_multinode_and_multidp:
-                # In multinode multi-dp mode, decode must keep being called even when the local
-                # running_batch is None, because dp ranks on other nodes may still have running
-                # requests; the inference backend pads some fake requests so the step can finish
-                # normally. This mainly serves large models like deepseekv3, whose ep parallel
-                # mode needs all nodes to cooperate.
-                await self._decode_batch(self.running_batch)
-
-            return
-
-        # There is a running batch: schedule again when the consecutive decode count reaches
-        # the threshold, or when a pre-scheduled result from last time exists.
-        if self.has_wait_tokens >= self.max_wait_tokens or self.schedule_task is not None:
-            new_mini_batch = await self.get_schedule_result(self.running_batch)
-            self.has_wait_tokens = 0
-            if new_mini_batch is not None:
-
-                # Aggressive schedule control
-                if not self.args.disable_aggressive_schedule:
-                    self.has_wait_tokens = self.max_wait_tokens
-
-                self.stats_tool.count_prompt_tokens(new_mini_batch)
-                await self._prefill_batch(new_mini_batch)
-                if not new_mini_batch.is_clear():
-                    self.running_batch.merge(new_mini_batch)
-        return
+            else:
+                self.running_batch.merge(new_batch)
 
         # Check if need pause some requests for decode.
         for dp_index in range(self.dp_size_in_node):
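
The `all_reduce` inside the new `get_new_batch` is how multinode TP ranks agree on a common cap: each rank contributes its waiting-queue length and every rank receives the minimum, so no rank schedules requests its peers cannot match. A standalone sketch of just that reduction, using a one-rank gloo group so it runs by itself (real code passes the router's `mulitnode_group`):

```python
import torch
import torch.distributed as dist

# One-rank "cluster" so the example is self-contained.
dist.init_process_group(
    backend="gloo", init_method="tcp://127.0.0.1:29501", rank=0, world_size=1
)

local_queue_len = 7  # stands in for len(self.req_queue.waiting_req_list)
t = torch.tensor(local_queue_len, dtype=torch.int32, device="cpu")
dist.all_reduce(t, op=dist.ReduceOp.MIN)  # every rank ends up with the min
print(t.item())  # 7 here; in general, the smallest queue length across ranks

dist.destroy_process_group()
```
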
@@ -375,36 +328,29 @@ async def _step(self):
         # Decode
         self.stats_tool.count_output_tokens(self.running_batch)
         await self._decode_batch(self.running_batch)
+        if self.world_size // self.nnodes == 1:
+            # When node_world_size == 1 the coroutine never yields and new requests
+            # cannot be scheduled, so sleep 1 ms for now (this could be improved).
+            await asyncio.sleep(0.001)
         self._filter_runing_batch()
         self.has_wait_tokens += 1
         return
 
     async def _prefill_batch(self, batch: Batch):
-        start_time = time.time()
-        self.metric_client.counter_inc("lightllm_batch_inference_count", "prefill")
         reqs = [r.to_router_rpc_obj() for r in batch.reqs]
         await self.model_rpc_client.prefill(reqs)
         batch.filter_out_finished_req(self.shm_req_manager)
         self._send_detokenization_pack()
 
         logger.debug(f"Prefill Batch: {batch.simple_log()} \n")
-        self.metric_client.histogram_observe(
-            "lightllm_batch_inference_duration_bucket", time.time() - start_time, "prefill"
-        )
         return
 
     async def _decode_batch(self, batch: Batch):
-        start_time = time.time()
-        self.metric_client.counter_inc("lightllm_batch_inference_count", "decode")
         await self.model_rpc_client.decode()
         # When self.is_multinode_and_multidp is True, the batch passed in may be None.
         if batch is not None:
             batch.filter_out_finished_req(self.shm_req_manager)
 
         self._send_detokenization_pack()
-        self.metric_client.histogram_observe(
-            "lightllm_batch_inference_duration_bucket", time.time() - start_time, "decode"
-        )
         return
 
     async def _pause_reqs(self, pasue_reqs):
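
The 1 ms sleep added after `_decode_batch` in this hunk guards against event-loop starvation: with a single local rank the decode awaits can complete without yielding, so `loop_for_netio_req` (which now feeds `batch_queue`) would never run. A toy demonstration of the pattern:

```python
import asyncio

async def decode_loop(q: asyncio.Queue):
    for step in range(3):
        # ... one decode step would run here ...
        # Without this yield, all three steps run before netio_loop gets a turn.
        await asyncio.sleep(0.001)
        if not q.empty():
            print(f"step {step}: picked up {q.get_nowait()}")

async def netio_loop(q: asyncio.Queue):
    q.put_nowait("new_batch")  # stands in for a pre-scheduled batch

async def main():
    q = asyncio.Queue()
    await asyncio.gather(decode_loop(q), netio_loop(q))

asyncio.run(main())
```
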
@@ -418,7 +364,7 @@ def _filter_runing_batch(self):
         return
 
     def _can_decode(self, batch: Batch, dp_index: int):
-        if self.is_pd_run_mode or self.is_safe_schedule:
+        if self.is_pd_run_mode or self.is_safe_schedule or batch is None:
             return True
         return (
             batch.get_batch_decode_need_tokens()[dp_index] + self.get_used_tokens(dp_index) <= self.max_total_token_num
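
The `batch is None` arm added to `_can_decode` short-circuits the token-budget check when there is no local batch, which is exactly the multinode multi-dp case where decode still has to run for peer nodes. A simplified standalone version of the check; the token accounting here (`len(batch)` as the per-step need) is a hypothetical stand-in for `get_batch_decode_need_tokens()`:

```python
from typing import List, Optional

def can_decode(
    batch: Optional[List[str]],  # stand-in for a Batch of request ids
    used_tokens: int,            # stand-in for self.get_used_tokens(dp_index)
    max_total_token_num: int,
) -> bool:
    if batch is None:
        # No local batch to account for; let decode proceed for peer nodes.
        return True
    need_tokens = len(batch)  # one more KV slot per request per decode step
    return need_tokens + used_tokens <= max_total_token_num

print(can_decode(["r0", "r1"], used_tokens=4094, max_total_token_num=4096))  # True
print(can_decode(["r0", "r1"], used_tokens=4095, max_total_token_num=4096))  # False
print(can_decode(None, used_tokens=4096, max_total_token_num=4096))          # True
```
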
@@ -451,7 +397,10 @@ async def loop_for_netio_req(self):
                 else:
                     assert False, f"Error Req Inf {recv_req}"
             except zmq.ZMQError:
-                await asyncio.sleep(0.01)
+                new_batch = self.get_new_batch()
+                if new_batch is not None:
+                    self.batch_queue.put_nowait(new_batch)
+                await asyncio.sleep(0.005)
                 continue
 
     def clean_up(self):
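
This hunk turns the idle branch of `loop_for_netio_req` into the producer side of the new design: when the non-blocking recv raises `zmq.ZMQError` (nothing pending), the gap is used to pre-build a batch for `_step` to consume. A sketch of that loop shape; the socket wiring and the `schedule` callable are illustrative, not the router's actual fields:

```python
import asyncio
import zmq

async def netio_loop(sock: zmq.Socket, batch_queue: asyncio.Queue, schedule):
    while True:
        try:
            # Non-blocking recv raises zmq.ZMQError (EAGAIN) when nothing is pending.
            recv_req = sock.recv_pyobj(zmq.NOBLOCK)
            print("got request:", recv_req)  # the real loop dispatches the request
        except zmq.ZMQError:
            # Idle moment: pre-schedule so _step can dequeue a ready batch.
            new_batch = schedule()
            if new_batch is not None:
                batch_queue.put_nowait(new_batch)
            await asyncio.sleep(0.005)  # yield and back off before polling again
```
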
@@ -491,6 +440,8 @@ def handle_exception(loop, context):
     loop = asyncio.new_event_loop()
     loop.set_exception_handler(handle_exception)
     asyncio.set_event_loop(loop)
+    router.batch_queue = asyncio.Queue()
+
     loop.create_task(router.loop_for_fwd())
     loop.run_until_complete(router.loop_for_netio_req())
     return
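
One detail worth noting in this hunk: `router.batch_queue` is created only after `asyncio.set_event_loop(loop)`. On Python versions before 3.10, `asyncio.Queue` captures the current event loop at construction time, so building the queue earlier could bind it to a different loop than the one that runs `loop_for_fwd` and `loop_for_netio_req`. A minimal sketch of the same ordering:

```python
import asyncio

loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
queue = asyncio.Queue()  # constructed after the loop is installed, as above

async def produce():
    queue.put_nowait("batch")

async def consume():
    print(await queue.get())  # prints "batch"

loop.create_task(produce())
loop.run_until_complete(consume())
```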