@@ -480,7 +480,7 @@ def microbatch_overlap_prefill(self, model_input0: ModelInput, model_input1: Mod
480480 model_input0 .max_len_in_batch ,
481481 infer_state0 .mem_index ,
482482 )
483- infer_state0 .init_some_extra_state (self , input_ids0 )
483+ infer_state0 .init_some_extra_state (self , model_input0 )
484484
485485 infer_state1 = self ._create_inferstate (model_input1 , 1 )
486486 init_req_to_token_indexes (
@@ -491,7 +491,7 @@ def microbatch_overlap_prefill(self, model_input0: ModelInput, model_input1: Mod
491491 model_input1 .max_len_in_batch ,
492492 infer_state1 .mem_index ,
493493 )
494- infer_state1 .init_some_extra_state (self , input_ids1 )
494+ infer_state1 .init_some_extra_state (self , model_input1 )
495495
496496 model_output0 , model_output1 = self ._overlap_tpsp_context_forward (
497497 input_ids0 , infer_state0 , input_ids1 = input_ids1 , infer_state1 = infer_state1
@@ -576,15 +576,15 @@ def microbatch_overlap_decode(self, model_input0: ModelInput, model_input1: Mode
576576 infer_state0 .b_seq_len ,
577577 infer_state0 .mem_index ,
578578 )
579- infer_state0 .init_some_extra_state (self , model_input0 . input_ids )
579+ infer_state0 .init_some_extra_state (self , model_input0 )
580580 infer_state1 = self ._create_inferstate (model_input1 , 1 )
581581 copy_kv_index_to_req (
582582 self .req_manager .req_to_token_indexs ,
583583 infer_state1 .b_req_idx ,
584584 infer_state1 .b_seq_len ,
585585 infer_state1 .mem_index ,
586586 )
587- infer_state1 .init_some_extra_state (self , model_input1 . input_ids )
587+ infer_state1 .init_some_extra_state (self , model_input1 )
588588
589589 model_output0 , model_output1 = self ._overlap_tpsp_token_forward (
590590 model_input0 .input_ids , infer_state0 , input_ids1 = model_input1 .input_ids , infer_state1 = infer_state1
0 commit comments