@@ -152,9 +152,11 @@ def _init_logits_processor(self, request):
             schemata_key,
         )
 
-    def insert_prefill_inputs(self, req_dicts: List[Request]):
+    def insert_prefill_inputs(self, req_dicts: List[Request], num_running_requests: int = None):
         """
         Process inputs for prefill tasks and insert it to share_inputs buffer
+        req_dict: A list of Request dict
+        num_running_requests: batch_size
         """
 
         if req_dicts[-1].disaggregate_info is not None and req_dicts[-1].disaggregate_info["role"] == "prefill":
@@ -193,7 +195,7 @@ def get_attr_from_request(request, attr, default_value=None):
                 self.share_inputs["prompt_ids"][idx : idx + 1, :length] = np.array(request.prompt_token_ids)
                 self.share_inputs["seq_lens_encoder"][idx : idx + 1] = 0
                 self.share_inputs["seq_lens_decoder"][idx : idx + 1] = length
-                self.share_inputs["seq_lens_this_time"][idx : idx + 1] = 1
+                self.seq_lens_this_time_buffer[idx : idx + 1] = 1
                 self.share_inputs["step_seq_lens_encoder"][idx : idx + 1] = 0
                 self.share_inputs["step_seq_lens_decoder"][idx : idx + 1] = length
                 self.share_inputs["prompt_lens"][idx : idx + 1] = length
@@ -205,7 +207,7 @@ def get_attr_from_request(request, attr, default_value=None):
                         request.draft_token_ids[0:num_prefill_send_token],
                         dtype="int64",
                     )
-                    self.share_inputs["seq_lens_this_time"][idx : idx + 1] = num_prefill_send_token
+                    self.seq_lens_this_time_buffer[idx : idx + 1] = num_prefill_send_token
             else:
                 self.share_inputs["pre_ids"][idx : idx + 1] = -1
                 self.share_inputs["step_idx"][idx : idx + 1] = 0
@@ -222,14 +224,14 @@ def get_attr_from_request(request, attr, default_value=None):
                     )
                     self.share_inputs["seq_lens_decoder"][idx : idx + 1] = request.get("seq_lens_decoder", 0)
                     self.share_inputs["step_seq_lens_decoder"][idx : idx + 1] = request.get("seq_lens_decoder", 0)
-                    self.share_inputs["seq_lens_this_time"][idx : idx + 1] = token_chunk_size
+                    self.seq_lens_this_time_buffer[idx : idx + 1] = token_chunk_size
                     self.share_inputs["step_seq_lens_encoder"][idx : idx + 1] = token_chunk_size
                     self.share_inputs["seq_lens_encoder"][idx : idx + 1] = token_chunk_size
                     self.share_inputs["prompt_lens"][idx : idx + 1] = token_chunk_size
                 else:
                     self.share_inputs["seq_lens_decoder"][idx : idx + 1] = request.get("seq_lens_decoder", 0)
                     self.share_inputs["step_seq_lens_decoder"][idx : idx + 1] = request.get("seq_lens_decoder", 0)
-                    self.share_inputs["seq_lens_this_time"][idx : idx + 1] = length
+                    self.seq_lens_this_time_buffer[idx : idx + 1] = length
                     self.share_inputs["step_seq_lens_encoder"][idx : idx + 1] = length
                     self.share_inputs["seq_lens_encoder"][idx : idx + 1] = length
                     self.share_inputs["prompt_lens"][idx : idx + 1] = length
@@ -293,6 +295,7 @@ def get_attr_from_request(request, attr, default_value=None):
 
         if self.speculative_method in ["mtp"]:
             self.proposer.insert_prefill_inputs(req_dicts)
+        self.share_inputs["seq_lens_this_time"] = self.seq_lens_this_time_buffer[:num_running_requests]
 
     def _dummy_prefill_inputs(self, num_tokens: int, batch_size: int, expected_decode_len: int):
         """Set dummy prefill inputs to share_inputs"""
@@ -311,7 +314,7 @@ def _dummy_prefill_inputs(self, num_tokens: int, batch_size: int, expected_decod
             self.share_inputs["input_ids"][idx : idx + 1, :input_length] = np.array([5] * input_length)
             self.share_inputs["prompt_ids"][idx : idx + 1, :input_length] = np.array([5] * input_length)
             self.share_inputs["eos_token_id"][:] = np.array([2], dtype="int64").reshape(-1, 1)
-            self.share_inputs["seq_lens_this_time"][idx : idx + 1] = input_length
+            self.seq_lens_this_time_buffer[idx : idx + 1] = input_length
             self.share_inputs["step_seq_lens_encoder"][idx : idx + 1] = input_length
             self.share_inputs["seq_lens_encoder"][idx : idx + 1] = input_length
             self.share_inputs["seq_lens_decoder"][idx : idx + 1] = 0
@@ -329,6 +332,7 @@ def _dummy_prefill_inputs(self, num_tokens: int, batch_size: int, expected_decod
             self.share_inputs["block_tables"][idx : idx + 1, :block_num] = np.arange(
                 idx * block_num, (idx + 1) * block_num, 1
             )
+        self.share_inputs["seq_lens_this_time"] = self.seq_lens_this_time_buffer
 
     def _init_share_inputs(self, max_num_seqs: int):
         """
@@ -379,7 +383,7 @@ def _init_share_inputs(self, max_num_seqs: int):
         self.share_inputs["max_length"] = paddle.full(
             [max_num_seqs, 1], self.model_config.max_model_len, dtype="int64"
         )
-        self.share_inputs["seq_lens_this_time"] = paddle.full(max_num_seqs, 0, dtype="int32")
+        self.seq_lens_this_time_buffer = paddle.full(max_num_seqs, 0, dtype="int32")
         self.share_inputs["seq_lens_encoder"] = paddle.full([max_num_seqs, 1], 0, dtype="int32")
         self.share_inputs["seq_lens_decoder"] = paddle.full([max_num_seqs, 1], 0, dtype="int32")
         self.share_inputs["step_seq_lens_encoder"] = paddle.full([max_num_seqs, 1], 0, dtype="int32")
@@ -921,13 +925,15 @@ def _get_skip_idx(self, model_forward_batch: Optional[List[Request]] = None):
     def execute_model(
         self,
         model_forward_batch: Optional[List[Request]] = None,
+        num_running_requests: int = None,
     ) -> Optional[ModelRunnerOutput]:
         """
         The Entrance of model execute.
         Args:
             model_forward_batch: 'Request' contains information related to prompt and is an abstract
             class at the server level, which is too granular for ModelRunner.
             We plan to replace it with 'ModelForwardBatch'.
+            num_running_requests: batch_size
             intermediate_tensors:
         """
         # If `not_need_stop` is False, it means the current worker is in an idle state.
@@ -1053,6 +1059,9 @@ class at the server level, which is too granular for ModelRunner.
 
         self._update_chunked_prefill(model_forward_batch)
         self._add_cache(model_forward_batch)
+        self.seq_lens_this_time_buffer[:num_running_requests].copy_(
+            self.share_inputs["seq_lens_this_time"][:num_running_requests], False
+        )
         return None
 
     def _add_cache(self, model_forward_batch) -> None:
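
The last hunk closes the loop: at the end of execute_model, the running prefix of share_inputs["seq_lens_this_time"] (which the model step may have updated in place) is copied back into seq_lens_this_time_buffer, so the next scheduling step starts from a consistent full-size buffer. Below is a hedged sketch of that per-step lifecycle, continuing the snippet above; run_one_step is a hypothetical helper, not part of FastDeploy's API, and the in-place update in the middle merely stands in for whatever the model step does.

import paddle

def run_one_step(seq_lens_buffer: paddle.Tensor, num_running_requests: int) -> dict:
    """Hypothetical per-step lifecycle mirroring the diff, not the runner's real control flow."""
    share_inputs = {}

    # Before the step: expose only the running prefix of the persistent buffer.
    share_inputs["seq_lens_this_time"] = seq_lens_buffer[:num_running_requests]

    # During the step: per-step lengths may be rewritten in place,
    # e.g. request 0 finishes prefill and moves to single-token decode.
    share_inputs["seq_lens_this_time"][0:1] = 1

    # After the step: copy the running prefix back into the persistent buffer,
    # matching the copy_ call added at the end of execute_model in the diff.
    seq_lens_buffer[:num_running_requests].copy_(
        share_inputs["seq_lens_this_time"][:num_running_requests], False
    )
    return share_inputs

buffer = paddle.full([8], 0, dtype="int32")
share_inputs = run_one_step(buffer, 3)
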