@@ -373,7 +373,7 @@ def __init__(self, fd_config: FDConfig, device: str, rank: int, local_rank: int)
373
373
# Forward meta store the global meta information of the forward
374
374
self .forward_meta : ForwardMeta = None
375
375
376
- def insert_tasks_v1 (self , req_dicts : List [Request ], num_running_requests : int = None ):
376
+ def insert_tasks_v1 (self , req_dicts : List [Request ]):
377
377
"""
378
378
Process scheduler output tasks, used when ENABLE_V1_KVCACHE_SCHEDULER=1
379
379
"""
@@ -403,7 +403,7 @@ def insert_tasks_v1(self, req_dicts: List[Request], num_running_requests: int =
403
403
)
404
404
self .share_inputs ["stop_flags" ][idx : idx + 1 ] = False
405
405
self .share_inputs ["seq_lens_decoder" ][idx : idx + 1 ] = prefill_start_index
406
- self .seq_lens_this_time_buffer [idx : idx + 1 ] = length
406
+ self .share_inputs [ "seq_lens_this_time" ] [idx : idx + 1 ] = length
407
407
self .share_inputs ["seq_lens_encoder" ][idx : idx + 1 ] = length
408
408
self .share_inputs ["step_seq_lens_decoder" ][idx : idx + 1 ] = 0
409
409
self .share_inputs ["prompt_lens" ][idx : idx + 1 ] = len (input_ids )
@@ -425,7 +425,7 @@ def insert_tasks_v1(self, req_dicts: List[Request], num_running_requests: int =
425
425
logger .debug (f"Handle preempted request { request } at idx { idx } " )
426
426
self .share_inputs ["block_tables" ][idx : idx + 1 , :] = - 1
427
427
self .share_inputs ["stop_flags" ][idx : idx + 1 ] = True
428
- self .seq_lens_this_time_buffer [idx : idx + 1 ] = 0
428
+ self .share_inputs [ "seq_lens_this_time" ] [idx : idx + 1 ] = 0
429
429
self .share_inputs ["seq_lens_decoder" ][idx : idx + 1 ] = 0
430
430
self .share_inputs ["seq_lens_encoder" ][idx : idx + 1 ] = 0
431
431
self .share_inputs ["is_block_step" ][idx : idx + 1 ] = False
@@ -462,9 +462,8 @@ def insert_tasks_v1(self, req_dicts: List[Request], num_running_requests: int =
462
462
)
463
463
if has_prefill_task :
464
464
self .share_inputs ["not_need_stop" ][0 ] = True
465
- self .share_inputs ["seq_lens_this_time" ] = self .seq_lens_this_time_buffer [:num_running_requests ]
466
465
467
- def process_prefill_inputs (self , req_dicts : List [Request ], num_running_requests : int = None ):
466
+ def process_prefill_inputs (self , req_dicts : List [Request ]):
468
467
"""Process inputs for prefill tasks and update share_inputs buffer"""
469
468
req_len = len (req_dicts )
470
469
for i in range (req_len ):
@@ -483,7 +482,7 @@ def process_prefill_inputs(self, req_dicts: List[Request], num_running_requests:
483
482
self .share_inputs ["penalty_score" ][idx : idx + 1 ] = request .get ("repetition_penalty" , 1.0 )
484
483
self .share_inputs ["frequency_score" ][idx : idx + 1 ] = request .get ("frequency_penalty" , 0.0 )
485
484
self .share_inputs ["presence_score" ][idx : idx + 1 ] = request .get ("presence_penalty" , 0.0 )
486
- self .seq_lens_this_time_buffer [idx : idx + 1 ] = length
485
+ self .share_inputs [ "seq_lens_this_time" ] [idx : idx + 1 ] = length
487
486
self .share_inputs ["step_seq_lens_encoder" ][idx : idx + 1 ] = length
488
487
self .share_inputs ["seq_lens_encoder" ][idx : idx + 1 ] = length
489
488
self .share_inputs ["seq_lens_decoder" ][idx : idx + 1 ] = 0
@@ -527,7 +526,6 @@ def process_prefill_inputs(self, req_dicts: List[Request], num_running_requests:
527
526
)
528
527
529
528
self .share_inputs ["not_need_stop" ][0 ] = True
530
- self .share_inputs ["seq_lens_this_time" ] = self .seq_lens_this_time_buffer [:num_running_requests ]
531
529
532
530
def _init_share_inputs (self , max_num_seqs : int ):
533
531
"""Initialize all share buffers for model inputs.
@@ -573,7 +571,7 @@ def _init_share_inputs(self, max_num_seqs: int):
573
571
self .share_inputs ["max_length" ] = paddle .full (
574
572
[max_num_seqs , 1 ], self .model_config .max_model_len , dtype = "int64"
575
573
)
576
- self .seq_lens_this_time_buffer = paddle .full (max_num_seqs , 0 , dtype = "int32" )
574
+ self .share_inputs [ "seq_lens_this_time" ] = paddle .full (max_num_seqs , 0 , dtype = "int32" )
577
575
self .share_inputs ["seq_lens_encoder" ] = paddle .full ([max_num_seqs , 1 ], 0 , dtype = "int32" )
578
576
self .share_inputs ["seq_lens_decoder" ] = paddle .full ([max_num_seqs , 1 ], 0 , dtype = "int32" )
579
577
self .share_inputs ["step_seq_lens_encoder" ] = paddle .full ([max_num_seqs , 1 ], 0 , dtype = "int32" )
@@ -815,7 +813,7 @@ def _dummy_prefill_inputs(self, num_tokens: int, batch_size: int):
815
813
idx = i
816
814
self .share_inputs ["input_ids" ][idx : idx + 1 , :input_length ] = np .array ([5 ] * input_length )
817
815
self .share_inputs ["eos_token_id" ][:] = np .array ([2 ], dtype = "int64" ).reshape (- 1 , 1 )
818
- self .seq_lens_this_time_buffer [idx : idx + 1 ] = input_length
816
+ self .share_inputs [ "seq_lens_this_time" ] [idx : idx + 1 ] = input_length
819
817
self .share_inputs ["step_seq_lens_encoder" ][idx : idx + 1 ] = input_length
820
818
self .share_inputs ["seq_lens_encoder" ][idx : idx + 1 ] = input_length
821
819
self .share_inputs ["seq_lens_decoder" ][idx : idx + 1 ] = 0
@@ -831,7 +829,6 @@ def _dummy_prefill_inputs(self, num_tokens: int, batch_size: int):
831
829
self .share_inputs ["block_tables" ][idx : idx + 1 , :block_num ] = np .arange (
832
830
idx * block_num , (idx + 1 ) * block_num , 1
833
831
)
834
- self .share_inputs ["seq_lens_this_time" ] = self .seq_lens_this_time_buffer
835
832
836
833
def _dummy_run (
837
834
self ,
@@ -925,10 +922,6 @@ class at the server level, which is too granular for ModelRunner.
925
922
self .cache_config .block_size ,
926
923
self .cache_config .enc_dec_block_num ,
927
924
)
928
- if num_running_requests is not None :
929
- self .seq_lens_this_time_buffer [:num_running_requests ].copy_ (
930
- self .share_inputs ["seq_lens_this_time" ][:num_running_requests ], False
931
- )
932
925
933
926
return None
934
927
0 commit comments