1 file changed: test/benchmark/static_inference (+8, -0 lines)

@@ -185,17 +185,23 @@ def prefill(
     b_ready_cache_len,
 ):
     b_mtp_index = torch.zeros(batch_size, dtype=torch.int32, device="cpu")
+    b_prefill_start_loc = b_seq_len.cumsum(dim=0, dtype=torch.int32) - b_seq_len
     model_input = ModelInput(
         batch_size=batch_size,
         total_token_num=total_token_num,
         max_len_in_batch=max_len_in_batch,
+        max_q_seq_len=max_len_in_batch,
+        max_kv_seq_len=max_len_in_batch,
+        max_cache_len=0,
         input_ids=input_ids,
         b_req_idx=b_req_idx,
         b_seq_len=b_seq_len,
         b_mtp_index=b_mtp_index,
         mem_indexes_cpu=mem_indexes,
         is_prefill=True,
         b_ready_cache_len=b_ready_cache_len,  # b_ready_cache_len
+        b_prefill_start_loc=b_prefill_start_loc,
+        prefix_total_token_num=0,  # the default kvcache len is zero.
     )

     model_output = model_part.forward(model_input)
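
A short reading of the prefill hunk above (a reviewer's note, not part of the patch): `b_prefill_start_loc` is the exclusive cumulative sum of `b_seq_len`, i.e. the offset in the flattened `input_ids` where each request's prompt begins, and `prefix_total_token_num=0` reflects that this benchmark starts with an empty KV cache. A minimal sketch of the start-location computation, using made-up lengths purely for illustration:

import torch

# Hypothetical batch: three prompts of 3, 5 and 2 tokens.
b_seq_len = torch.tensor([3, 5, 2], dtype=torch.int32)

# Exclusive prefix sum, exactly as in the diff: cumsum minus the lengths themselves.
b_prefill_start_loc = b_seq_len.cumsum(dim=0, dtype=torch.int32) - b_seq_len
print(b_prefill_start_loc)  # tensor([0, 3, 8], dtype=torch.int32)
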
@@ -209,6 +215,8 @@ def decode(
         batch_size=batch_size,
         total_token_num=total_token_num,
         max_len_in_batch=max_len_in_batch,
+        max_q_seq_len=1,
+        max_kv_seq_len=max_len_in_batch,
         input_ids=input_ids,
         b_req_idx=b_req_idx,
         b_seq_len=b_seq_len,
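
For the decode hunk, the rationale (again a gloss, not the author's wording): each decode step feeds exactly one new query token per request, so `max_q_seq_len` is fixed at 1, while `max_kv_seq_len` still has to cover the longest sequence in the batch, hence `max_len_in_batch`. A small sketch, with hypothetical lengths, of how the two values relate to the batch tensors:

import torch

# Hypothetical decode-step batch: per-request sequence lengths including the new token.
b_seq_len = torch.tensor([12, 7, 30], dtype=torch.int32)

max_q_seq_len = 1                             # one query token per request at decode time
max_kv_seq_len = int(b_seq_len.max().item())  # longest KV history in the batch (30 here)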