@@ -352,16 +352,16 @@ def _prefill(
         model_input: ModelInput,
     ):
         infer_state = self._create_inferstate(model_input)
-        infer_state.init_some_extra_state(self, model_input.input_ids)
         init_req_to_token_indexes(
             req_to_token_indexs=self.req_manager.req_to_token_indexs,
             b_req_idx=infer_state.b_req_idx,
             b_seq_len=infer_state.b_seq_len,
             b_ready_cache_len=infer_state.b_ready_cache_len,
-            b_start_loc=infer_state.b_start_loc,
+            b_start_loc=model_input.b_prefill_start_loc,
             alloc_mem_index=infer_state.mem_index,
             max_q_seq_len=infer_state.max_q_seq_len,
         )
+        infer_state.init_some_extra_state(self, model_input.input_ids)
         return self._context_forward(model_input.input_ids, infer_state)
 
     def _decode(
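A note on this hunk (an inference from the diff, not stated in it): `init_some_extra_state` now runs after `init_req_to_token_indexes`, and the kernel takes its per-request start offsets from the host-built `model_input.b_prefill_start_loc` instead of `infer_state.b_start_loc`. Judging from the `_init_padded_req` hunk further down, `b_prefill_start_loc` holds the exclusive prefix sum of each request's new-token count. A minimal sketch of that relationship (the helper name is hypothetical):

```python
import torch


def build_b_prefill_start_loc(b_seq_len: torch.Tensor, b_ready_cache_len: torch.Tensor) -> torch.Tensor:
    # b_seq_len[i]: total tokens request i will have after this prefill.
    # b_ready_cache_len[i]: tokens of request i already in the KV cache.
    # Their difference is the number of new tokens each request feeds in,
    # and the exclusive prefix sum of those counts is each request's start
    # offset in the flattened input_ids buffer.
    b_q_seq_len = b_seq_len - b_ready_cache_len
    return b_q_seq_len.cumsum(dim=0, dtype=torch.int32) - b_q_seq_len
```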
@@ -491,28 +491,28 @@ def microbatch_overlap_prefill(self, model_input0: ModelInput, model_input1: Mod
         input_ids0, input_ids1 = model_input0.input_ids, model_input1.input_ids
 
         infer_state0 = self._create_inferstate(model_input0, 0)
-        infer_state0.init_some_extra_state(self, input_ids0)
         init_req_to_token_indexes(
             req_to_token_indexs=self.req_manager.req_to_token_indexs,
             b_req_idx=infer_state0.b_req_idx,
             b_seq_len=infer_state0.b_seq_len,
             b_ready_cache_len=infer_state0.b_ready_cache_len,
-            b_start_loc=infer_state0.b_start_loc,
+            b_start_loc=model_input0.b_prefill_start_loc,
             alloc_mem_index=infer_state0.mem_index,
             max_q_seq_len=infer_state0.max_q_seq_len,
         )
+        infer_state0.init_some_extra_state(self, input_ids0)
 
         infer_state1 = self._create_inferstate(model_input1, 1)
-        infer_state1.init_some_extra_state(self, input_ids1)
         init_req_to_token_indexes(
             req_to_token_indexs=self.req_manager.req_to_token_indexs,
             b_req_idx=infer_state1.b_req_idx,
             b_seq_len=infer_state1.b_seq_len,
             b_ready_cache_len=infer_state1.b_ready_cache_len,
-            b_start_loc=infer_state1.b_start_loc,
+            b_start_loc=model_input1.b_prefill_start_loc,
             alloc_mem_index=infer_state1.mem_index,
             max_q_seq_len=infer_state1.max_q_seq_len,
         )
+        infer_state1.init_some_extra_state(self, input_ids1)
 
         model_output0, model_output1 = self._overlap_tpsp_context_forward(
             input_ids0, infer_state0, input_ids1=input_ids1, infer_state1=infer_state1
@@ -713,6 +713,7 @@ def _check_max_len_infer(self):
             b_seq_len = torch.ones(1, dtype=torch.int32, device="cuda")
             b_seq_len[:] = self.batch_max_tokens
             b_ready_cache_len = torch.zeros(1, dtype=torch.int32, device="cuda")
+            b_prefill_start_loc = torch.zeros(1, dtype=torch.int32, device="cuda")
             total_token_num = self.batch_max_tokens
             b_mtp_index = torch.zeros(1, dtype=torch.int32, device="cuda")
             model_input = ModelInput(
@@ -730,6 +731,7 @@ def _check_max_len_infer(self):
                 b_mtp_index=b_mtp_index,
                 is_prefill=True,
                 b_ready_cache_len=b_ready_cache_len,
+                b_prefill_start_loc=b_prefill_start_loc,
             )
             model_output = self.forward(
                 model_input,
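For this single-request dummy batch the exclusive prefix sum degenerates to `[0]`, so `torch.zeros(1, ...)` is exactly what the general formula would produce. A CPU sketch of that check (the `8192` stand-in for `batch_max_tokens` is made up):

```python
import torch

# One dummy request covering the whole budget, nothing cached yet.
b_seq_len = torch.tensor([8192], dtype=torch.int32)  # stand-in for batch_max_tokens
b_ready_cache_len = torch.zeros(1, dtype=torch.int32)

b_q_seq_len = b_seq_len - b_ready_cache_len
b_prefill_start_loc = b_q_seq_len.cumsum(dim=0, dtype=torch.int32) - b_q_seq_len

# The single request starts at offset 0 of the flattened token buffer.
assert torch.equal(b_prefill_start_loc, torch.zeros(1, dtype=torch.int32))
```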
@@ -787,6 +789,7 @@ def _autotune_warmup(self):
             b_seq_len = torch.ones(1, dtype=torch.int32, device="cuda")
             b_seq_len[:] = input_len
             b_ready_cache_len = torch.zeros(1, dtype=torch.int32, device="cuda")
+            b_prefill_start_loc = torch.zeros(1, dtype=torch.int32, device="cuda")
             total_token_num = input_len
             b_mtp_index = torch.zeros(1, dtype=torch.int32, device="cuda")
             model_input = ModelInput(
@@ -804,6 +807,7 @@ def _autotune_warmup(self):
                 b_mtp_index=b_mtp_index,
                 is_prefill=True,
                 b_ready_cache_len=b_ready_cache_len,
+                b_prefill_start_loc=b_prefill_start_loc,
                 multimodal_params=[],
                 **self._gen_special_model_input(total_token_num),
             )
@@ -847,6 +851,8 @@ def _init_padded_req(self):
         )
         b_seq_len = torch.ones(batch_size, dtype=torch.int32, device="cuda")
         b_ready_cache_len = torch.zeros(batch_size, dtype=torch.int32, device="cuda")
+        b_q_seq_len = b_seq_len - b_ready_cache_len
+        b_prefill_start_loc = b_q_seq_len.cumsum(dim=0, dtype=torch.int32) - b_q_seq_len
         total_token_num = prefill_input_len * batch_size
         b_mtp_index = torch.zeros(batch_size, dtype=torch.int32, device="cuda")
         model_input = ModelInput(
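The `cumsum(...) - b_q_seq_len` idiom added here computes an exclusive prefix sum in one shot: subtracting each element from the inclusive running total shifts every sum back by one slot. A worked example with a hypothetical three-request batch:

```python
import torch

# Hypothetical batch: three requests contributing 3, 5, and 2 new tokens.
b_q_seq_len = torch.tensor([3, 5, 2], dtype=torch.int32)

inclusive = b_q_seq_len.cumsum(dim=0, dtype=torch.int32)  # tensor([ 3,  8, 10])
exclusive = inclusive - b_q_seq_len                       # tensor([0, 3, 8])

# Request 0 occupies flattened slots 0..2, request 1 slots 3..7,
# request 2 slots 8..9.
assert exclusive.tolist() == [0, 3, 8]
```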
@@ -863,6 +869,7 @@ def _init_padded_req(self):
             b_mtp_index=b_mtp_index,
             b_seq_len=b_seq_len,
             b_ready_cache_len=b_ready_cache_len,
+            b_prefill_start_loc=b_prefill_start_loc,
             is_prefill=True,
             multimodal_params=[],
             **self._gen_special_model_input(total_token_num),