@@ -343,16 +343,16 @@ def _prefill(
         model_input: ModelInput,
     ):
         infer_state = self._create_inferstate(model_input)
-        infer_state.init_some_extra_state(self, model_input.input_ids)
         init_req_to_token_indexes(
             req_to_token_indexs=self.req_manager.req_to_token_indexs,
             b_req_idx=infer_state.b_req_idx,
             b_seq_len=infer_state.b_seq_len,
             b_ready_cache_len=infer_state.b_ready_cache_len,
-            b_start_loc=infer_state.b_start_loc,
+            b_start_loc=model_input.b_prefill_start_loc,
             alloc_mem_index=infer_state.mem_index,
             max_q_seq_len=infer_state.max_q_seq_len,
         )
+        infer_state.init_some_extra_state(self, model_input.input_ids)
         return self._context_forward(model_input.input_ids, infer_state)

     def _decode(
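The prefill path now takes its per-request start offsets from `model_input.b_prefill_start_loc` instead of `infer_state.b_start_loc`, and `init_some_extra_state` runs only after `init_req_to_token_indexes` has populated the token-index table. A minimal sketch of how such a start-loc tensor is derived (the same exclusive cumsum that `_init_padded_req` uses in the last hunk; the sample lengths here are made up):

```python
import torch

# Query tokens per request = total sequence length minus tokens already cached.
b_seq_len = torch.tensor([5, 3, 7], dtype=torch.int32)
b_ready_cache_len = torch.tensor([2, 0, 4], dtype=torch.int32)
b_q_seq_len = b_seq_len - b_ready_cache_len  # -> [3, 3, 3]

# Exclusive prefix sum: offset of each request's first query token
# in the flattened prefill token stream.
b_prefill_start_loc = b_q_seq_len.cumsum(dim=0, dtype=torch.int32) - b_q_seq_len
print(b_prefill_start_loc.tolist())  # [0, 3, 6]
```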
@@ -482,28 +482,28 @@ def microbatch_overlap_prefill(self, model_input0: ModelInput, model_input1: Mod
         input_ids0, input_ids1 = model_input0.input_ids, model_input1.input_ids

         infer_state0 = self._create_inferstate(model_input0, 0)
-        infer_state0.init_some_extra_state(self, input_ids0)
         init_req_to_token_indexes(
             req_to_token_indexs=self.req_manager.req_to_token_indexs,
             b_req_idx=infer_state0.b_req_idx,
             b_seq_len=infer_state0.b_seq_len,
             b_ready_cache_len=infer_state0.b_ready_cache_len,
-            b_start_loc=infer_state0.b_start_loc,
+            b_start_loc=model_input0.b_prefill_start_loc,
             alloc_mem_index=infer_state0.mem_index,
             max_q_seq_len=infer_state0.max_q_seq_len,
         )
+        infer_state0.init_some_extra_state(self, input_ids0)

         infer_state1 = self._create_inferstate(model_input1, 1)
-        infer_state1.init_some_extra_state(self, input_ids1)
         init_req_to_token_indexes(
             req_to_token_indexs=self.req_manager.req_to_token_indexs,
             b_req_idx=infer_state1.b_req_idx,
             b_seq_len=infer_state1.b_seq_len,
             b_ready_cache_len=infer_state1.b_ready_cache_len,
-            b_start_loc=infer_state1.b_start_loc,
+            b_start_loc=model_input1.b_prefill_start_loc,
             alloc_mem_index=infer_state1.mem_index,
             max_q_seq_len=infer_state1.max_q_seq_len,
         )
+        infer_state1.init_some_extra_state(self, input_ids1)

         model_output0, model_output1 = self._overlap_tpsp_context_forward(
             input_ids0, infer_state0, input_ids1=input_ids1, infer_state1=infer_state1
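The overlapped prefill applies the same two changes to each micro-batch. Since the setup sequence is now identical for both, a hypothetical helper (not in the diff; `_setup_prefill_state` is an invented name) could capture the required ordering in one place:

```python
def _setup_prefill_state(self, model_input: ModelInput, microbatch_index: int):
    # Hypothetical refactor of the per-micro-batch setup above.
    infer_state = self._create_inferstate(model_input, microbatch_index)
    init_req_to_token_indexes(
        req_to_token_indexs=self.req_manager.req_to_token_indexs,
        b_req_idx=infer_state.b_req_idx,
        b_seq_len=infer_state.b_seq_len,
        b_ready_cache_len=infer_state.b_ready_cache_len,
        b_start_loc=model_input.b_prefill_start_loc,
        alloc_mem_index=infer_state.mem_index,
        max_q_seq_len=infer_state.max_q_seq_len,
    )
    # Extra state is initialized only after the token-index table is filled.
    infer_state.init_some_extra_state(self, model_input.input_ids)
    return infer_state
```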
@@ -704,6 +704,7 @@ def _check_max_len_infer(self):
         b_seq_len = torch.ones(1, dtype=torch.int32, device="cuda")
         b_seq_len[:] = self.batch_max_tokens
         b_ready_cache_len = torch.zeros(1, dtype=torch.int32, device="cuda")
+        b_prefill_start_loc = torch.zeros(1, dtype=torch.int32, device="cuda")
         total_token_num = self.batch_max_tokens
         b_mtp_index = torch.zeros(1, dtype=torch.int32, device="cuda")
         model_input = ModelInput(
@@ -721,6 +722,7 @@ def _check_max_len_infer(self):
             b_mtp_index=b_mtp_index,
             is_prefill=True,
             b_ready_cache_len=b_ready_cache_len,
+            b_prefill_start_loc=b_prefill_start_loc,
         )
         model_output = self.forward(
             model_input,
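`_check_max_len_infer` probes a single max-length request, so the new field reduces to a zero tensor: with one request, the exclusive cumsum of query lengths is trivially `[0]`. A quick CPU-side sanity check of that equivalence (the sample length is made up):

```python
import torch

b_seq_len = torch.tensor([4096], dtype=torch.int32)   # stands in for batch_max_tokens
b_ready_cache_len = torch.zeros(1, dtype=torch.int32)
b_q_seq_len = b_seq_len - b_ready_cache_len
b_prefill_start_loc = b_q_seq_len.cumsum(dim=0, dtype=torch.int32) - b_q_seq_len
assert torch.equal(b_prefill_start_loc, torch.zeros(1, dtype=torch.int32))
```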
@@ -778,6 +780,7 @@ def _autotune_warmup(self):
         b_seq_len = torch.ones(1, dtype=torch.int32, device="cuda")
         b_seq_len[:] = input_len
         b_ready_cache_len = torch.zeros(1, dtype=torch.int32, device="cuda")
+        b_prefill_start_loc = torch.zeros(1, dtype=torch.int32, device="cuda")
         total_token_num = input_len
         b_mtp_index = torch.zeros(1, dtype=torch.int32, device="cuda")
         model_input = ModelInput(
@@ -795,6 +798,7 @@ def _autotune_warmup(self):
             b_mtp_index=b_mtp_index,
             is_prefill=True,
             b_ready_cache_len=b_ready_cache_len,
+            b_prefill_start_loc=b_prefill_start_loc,
             multimodal_params=[],
             **self._gen_special_model_input(total_token_num),
         )
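The warmup path builds the same single-request dummy tensors as `_check_max_len_infer`, so both call sites now have to pass the new `b_prefill_start_loc` kwarg. A hypothetical helper (invented name, not part of the diff) that would keep the two construction sites in sync:

```python
import torch

def _build_dummy_prefill_tensors(input_len: int):
    # Hypothetical helper: single-request dummy tensors built identically
    # in _check_max_len_infer and _autotune_warmup above.
    b_seq_len = torch.full((1,), input_len, dtype=torch.int32, device="cuda")
    b_ready_cache_len = torch.zeros(1, dtype=torch.int32, device="cuda")
    b_prefill_start_loc = torch.zeros(1, dtype=torch.int32, device="cuda")
    b_mtp_index = torch.zeros(1, dtype=torch.int32, device="cuda")
    return b_seq_len, b_ready_cache_len, b_prefill_start_loc, b_mtp_index
```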
@@ -838,6 +842,8 @@ def _init_padded_req(self):
         )
         b_seq_len = torch.ones(batch_size, dtype=torch.int32, device="cuda")
         b_ready_cache_len = torch.zeros(batch_size, dtype=torch.int32, device="cuda")
+        b_q_seq_len = b_seq_len - b_ready_cache_len
+        b_prefill_start_loc = b_q_seq_len.cumsum(dim=0, dtype=torch.int32) - b_q_seq_len
         total_token_num = prefill_input_len * batch_size
         b_mtp_index = torch.zeros(batch_size, dtype=torch.int32, device="cuda")
         model_input = ModelInput(
@@ -854,6 +860,7 @@ def _init_padded_req(self):
             b_mtp_index=b_mtp_index,
             b_seq_len=b_seq_len,
             b_ready_cache_len=b_ready_cache_len,
+            b_prefill_start_loc=b_prefill_start_loc,
             is_prefill=True,
             multimodal_params=[],
             **self._gen_special_model_input(total_token_num),
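`_init_padded_req` is the only call site that computes the offsets rather than zero-filling them. With `b_seq_len` all ones and `b_ready_cache_len` all zeros, each padded request contributes exactly one query token, so the exclusive cumsum yields consecutive offsets. A CPU-side check of that construction (batch size made up):

```python
import torch

batch_size = 4
b_seq_len = torch.ones(batch_size, dtype=torch.int32)
b_ready_cache_len = torch.zeros(batch_size, dtype=torch.int32)
b_q_seq_len = b_seq_len - b_ready_cache_len
b_prefill_start_loc = b_q_seq_len.cumsum(dim=0, dtype=torch.int32) - b_q_seq_len
assert b_prefill_start_loc.tolist() == [0, 1, 2, 3]
```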