@@ -152,9 +152,11 @@ def _init_logits_processor(self, request):
             schemata_key,
         )
 
-    def insert_prefill_inputs(self, req_dicts: List[Request]):
+    def insert_prefill_inputs(self, req_dicts: List[Request], num_running_requests: int = None):
         """
         Process inputs for prefill tasks and insert it to share_inputs buffer
+        req_dict: A list of Request dict
+        num_running_requests: batch_size
         """
 
         if req_dicts[-1].disaggregate_info is not None and req_dicts[-1].disaggregate_info["role"] == "prefill":
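For orientation, here is a minimal, self-contained sketch (toy names and a plain list in place of the paddle buffer, not FastDeploy code) of the signature convention this hunk introduces: `num_running_requests` is optional so existing callers keep working, and `None` falls back to treating every incoming request as running.

```python
from typing import List, Optional

class RunnerSketch:
    """Toy stand-in for the model runner; not the real class."""

    def __init__(self, max_num_seqs: int):
        self.lens = [0] * max_num_seqs  # persistent, sized for the maximum batch

    def insert_prefill_inputs(self, req_lens: List[int], num_running_requests: Optional[int] = None):
        # None keeps the old behavior: every incoming request counts as running.
        n = num_running_requests if num_running_requests is not None else len(req_lens)
        for idx, length in enumerate(req_lens[:n]):
            self.lens[idx] = length
        return self.lens[:n]  # the slice downstream code consumes

runner = RunnerSketch(max_num_seqs=8)
print(runner.insert_prefill_inputs([16, 3, 7], num_running_requests=3))  # [16, 3, 7]
```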
@@ -193,7 +195,7 @@ def get_attr_from_request(request, attr, default_value=None):
                 self.share_inputs["prompt_ids"][idx : idx + 1, :length] = np.array(request.prompt_token_ids)
                 self.share_inputs["seq_lens_encoder"][idx : idx + 1] = 0
                 self.share_inputs["seq_lens_decoder"][idx : idx + 1] = length
-                self.share_inputs["seq_lens_this_time"][idx : idx + 1] = 1
+                self.seq_lens_this_time_buffer[idx : idx + 1] = 1
                 self.share_inputs["step_seq_lens_encoder"][idx : idx + 1] = 0
                 self.share_inputs["step_seq_lens_decoder"][idx : idx + 1] = length
                 self.share_inputs["prompt_lens"][idx : idx + 1] = length
@@ -205,7 +207,7 @@ def get_attr_from_request(request, attr, default_value=None):
                        request.draft_token_ids[0:num_prefill_send_token],
                        dtype="int64",
                    )
-                    self.share_inputs["seq_lens_this_time"][idx : idx + 1] = num_prefill_send_token
+                    self.seq_lens_this_time_buffer[idx : idx + 1] = num_prefill_send_token
            else:
                self.share_inputs["pre_ids"][idx : idx + 1] = -1
                self.share_inputs["step_idx"][idx : idx + 1] = 0
@@ -222,14 +224,14 @@ def get_attr_from_request(request, attr, default_value=None):
                    )
                    self.share_inputs["seq_lens_decoder"][idx : idx + 1] = request.get("seq_lens_decoder", 0)
                    self.share_inputs["step_seq_lens_decoder"][idx : idx + 1] = request.get("seq_lens_decoder", 0)
-                    self.share_inputs["seq_lens_this_time"][idx : idx + 1] = token_chunk_size
+                    self.seq_lens_this_time_buffer[idx : idx + 1] = token_chunk_size
                    self.share_inputs["step_seq_lens_encoder"][idx : idx + 1] = token_chunk_size
                    self.share_inputs["seq_lens_encoder"][idx : idx + 1] = token_chunk_size
                    self.share_inputs["prompt_lens"][idx : idx + 1] = token_chunk_size
                else:
                    self.share_inputs["seq_lens_decoder"][idx : idx + 1] = request.get("seq_lens_decoder", 0)
                    self.share_inputs["step_seq_lens_decoder"][idx : idx + 1] = request.get("seq_lens_decoder", 0)
-                    self.share_inputs["seq_lens_this_time"][idx : idx + 1] = length
+                    self.seq_lens_this_time_buffer[idx : idx + 1] = length
                    self.share_inputs["step_seq_lens_encoder"][idx : idx + 1] = length
                    self.share_inputs["seq_lens_encoder"][idx : idx + 1] = length
                    self.share_inputs["prompt_lens"][idx : idx + 1] = length
@@ -295,6 +297,7 @@ def get_attr_from_request(request, attr, default_value=None):
 
        if self.speculative_method in ["mtp"]:
            self.proposer.insert_prefill_inputs(req_dicts)
+        self.share_inputs["seq_lens_this_time"] = self.seq_lens_this_time_buffer[:num_running_requests]
 
    def _dummy_prefill_inputs(self, num_tokens: int, batch_size: int, expected_decode_len: int):
        """Set dummy prefill inputs to share_inputs"""
@@ -313,7 +316,7 @@ def _dummy_prefill_inputs(self, num_tokens: int, batch_size: int, expected_decode_len: int):
            self.share_inputs["input_ids"][idx : idx + 1, :input_length] = np.array([5] * input_length)
            self.share_inputs["prompt_ids"][idx : idx + 1, :input_length] = np.array([5] * input_length)
            self.share_inputs["eos_token_id"][:] = np.array([2], dtype="int64").reshape(-1, 1)
-            self.share_inputs["seq_lens_this_time"][idx : idx + 1] = input_length
+            self.seq_lens_this_time_buffer[idx : idx + 1] = input_length
            self.share_inputs["step_seq_lens_encoder"][idx : idx + 1] = input_length
            self.share_inputs["seq_lens_encoder"][idx : idx + 1] = input_length
            self.share_inputs["seq_lens_decoder"][idx : idx + 1] = 0
@@ -331,6 +334,7 @@ def _dummy_prefill_inputs(self, num_tokens: int, batch_size: int, expected_decode_len: int):
            self.share_inputs["block_tables"][idx : idx + 1, :block_num] = np.arange(
                idx * block_num, (idx + 1) * block_num, 1
            )
+        self.share_inputs["seq_lens_this_time"] = self.seq_lens_this_time_buffer
 
    def _init_share_inputs(self, max_num_seqs: int):
        """
@@ -381,7 +385,7 @@ def _init_share_inputs(self, max_num_seqs: int):
        self.share_inputs["max_length"] = paddle.full(
            [max_num_seqs, 1], self.model_config.max_model_len, dtype="int64"
        )
-        self.share_inputs["seq_lens_this_time"] = paddle.full(max_num_seqs, 0, dtype="int32")
+        self.seq_lens_this_time_buffer = paddle.full(max_num_seqs, 0, dtype="int32")
        self.share_inputs["seq_lens_encoder"] = paddle.full([max_num_seqs, 1], 0, dtype="int32")
        self.share_inputs["seq_lens_decoder"] = paddle.full([max_num_seqs, 1], 0, dtype="int32")
        self.share_inputs["step_seq_lens_encoder"] = paddle.full([max_num_seqs, 1], 0, dtype="int32")
@@ -923,13 +927,15 @@ def _get_skip_idx(self, model_forward_batch: Optional[List[Request]] = None):
    def execute_model(
        self,
        model_forward_batch: Optional[List[Request]] = None,
+        num_running_requests: int = None,
    ) -> Optional[ModelRunnerOutput]:
        """
        The Entrance of model execute.
        Args:
            model_forward_batch: 'Request' contains information related to prompt and is an abstract
                class at the server level, which is too granular for ModelRunner.
                We plan to replace it with 'ModelForwardBatch'.
+            num_running_requests: batch_size
            intermediate_tensors:
        """
        # If `not_need_stop`` is False, it means the current worker is in an idle state.
@@ -1055,6 +1061,9 @@ class at the server level, which is too granular for ModelRunner.
 
        self._update_chunked_prefill(model_forward_batch)
        self._add_cache(model_forward_batch)
+        self.seq_lens_this_time_buffer[:num_running_requests].copy_(
+            self.share_inputs["seq_lens_this_time"][:num_running_requests], False
+        )
        return None
 
    def _add_cache(self, model_forward_batch) -> None:
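Because the forward step may mutate `share_inputs["seq_lens_this_time"]`, this hunk copies the updated lengths back into the persistent buffer before returning, so the next step starts from current state. A toy sketch of that write-back follows, with made-up sizes and values; the trailing `False` mirrors the diff's non-blocking copy.

```python
import paddle

max_num_seqs, num_running_requests = 8, 3
seq_lens_this_time_buffer = paddle.full([max_num_seqs], 0, dtype="int32")
share_inputs = {"seq_lens_this_time": paddle.to_tensor([4, 7, 1], dtype="int32")}

# copy_ writes through the slice into the buffer's underlying storage;
# the second argument (blocking=False) requests an asynchronous copy.
seq_lens_this_time_buffer[:num_running_requests].copy_(
    share_inputs["seq_lens_this_time"][:num_running_requests], False
)
print(seq_lens_this_time_buffer.numpy())  # [4 7 1 0 0 0 0 0]
```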