@@ -348,12 +348,10 @@ def _create_padded_decode_model_input(self, model_input: ModelInput, new_batch_s
         return new_model_input
 
     def _create_padded_prefill_model_input(self, model_input: ModelInput, new_handle_token_num: int):
-        if model_input.total_token_num - model_input.prefix_total_token_num == new_handle_token_num:
-            return model_input
-
         assert model_input.total_token_num - model_input.prefix_total_token_num < new_handle_token_num
 
         padded_token_num = new_handle_token_num - (model_input.total_token_num - model_input.prefix_total_token_num)
+        assert padded_token_num > 0
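+        # the pad tokens are attributed to one extra dummy request, so the
+        # padded input matches one of the prefill graph's supported token counts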
         new_model_input = copy.copy(model_input)
         new_model_input.batch_size = model_input.batch_size + 1
         new_model_input.total_token_num += padded_token_num
@@ -405,16 +403,12 @@ def _create_unpad_decode_model_output(self, model_output: ModelOutput, origin_ba
 
         return new_model_output
 
-    def _create_unpad_prefill_model_output(self, model_output: ModelOutput, origin_handle_token_num: int):
-        handle_token_num = model_output.logits.shape[0]
-        if handle_token_num == origin_handle_token_num:
-            return model_output
-
+    def _create_unpad_prefill_model_output(self, padded_model_output: ModelOutput, origin_handle_token_num: int):
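+        # strip the logits produced for the pad tokens so callers only ever
+        # see logits for the tokens that were really in the batch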
         if self.return_all_prompt_logics:
-            new_model_output = copy.copy(model_output)
+            new_model_output = copy.copy(padded_model_output)
             new_model_output.logits = new_model_output.logits[0:origin_handle_token_num]
         else:
-            new_model_output = copy.copy(model_output)
+            new_model_output = copy.copy(padded_model_output)
             # drop the logits belonging to the extra padded dummy req
             new_model_output.logits = new_model_output.logits[0:-1]
 
@@ -429,14 +423,18 @@ def _prefill(
         self,
         model_input: ModelInput,
     ):
-        handle_token_num = model_input.total_token_num - model_input.prefix_total_token_num
-        if self.prefill_graph is not None and self.prefill_graph.can_run(handle_token_num=handle_token_num):
+        origin_handle_token_num = model_input.total_token_num - model_input.prefix_total_token_num
+
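+        # pad the prefill input up to the closest captured graph size and
+        # remember that we did, so the output can be unpadded afterwards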
+        is_padded_model_input = False
+        if self.prefill_graph is not None and self.prefill_graph.can_run(handle_token_num=origin_handle_token_num):
             finded_handle_token_num = self.prefill_graph.find_closest_graph_handle_token_num(
-                handle_token_num=handle_token_num
-            )
-            model_input = self._create_padded_prefill_model_input(
-                model_input=model_input, new_handle_token_num=finded_handle_token_num
+                handle_token_num=origin_handle_token_num
             )
+            if finded_handle_token_num != origin_handle_token_num:
+                is_padded_model_input = True
+                model_input = self._create_padded_prefill_model_input(
+                    model_input=model_input, new_handle_token_num=finded_handle_token_num
+                )
 
         infer_state = self._create_inferstate(model_input)
         init_req_to_token_indexes(
@@ -453,7 +451,10 @@ def _prefill(
 
         infer_state.init_some_extra_state(self, model_input.input_ids)
         model_output = self._context_forward(model_input.input_ids, infer_state)
-        model_output = self._create_unpad_prefill_model_output(model_output, origin_handle_token_num=handle_token_num)
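+        # only unpad when padding was actually applied on the way in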
+        if is_padded_model_input:
+            model_output = self._create_unpad_prefill_model_output(
+                model_output, origin_handle_token_num=origin_handle_token_num
+            )
         model_output.prefill_mem_indexes_ready_event = prefill_mem_indexes_ready_event
         return model_output
 