
Commit ecbfe9c

static test fix (#1089)
1 parent 77a92be

1 file changed

test/benchmark/static_inference/model_infer.py

Lines changed: 8 additions & 0 deletions
@@ -185,17 +185,23 @@ def prefill(
     b_ready_cache_len,
 ):
     b_mtp_index = torch.zeros(batch_size, dtype=torch.int32, device="cpu")
+    b_prefill_start_loc = b_seq_len.cumsum(dim=0, dtype=torch.int32) - b_seq_len
     model_input = ModelInput(
         batch_size=batch_size,
         total_token_num=total_token_num,
         max_len_in_batch=max_len_in_batch,
+        max_q_seq_len=max_len_in_batch,
+        max_kv_seq_len=max_len_in_batch,
+        max_cache_len=0,
         input_ids=input_ids,
         b_req_idx=b_req_idx,
         b_seq_len=b_seq_len,
         b_mtp_index=b_mtp_index,
         mem_indexes_cpu=mem_indexes,
         is_prefill=True,
         b_ready_cache_len=b_ready_cache_len, # b_ready_cache_len
+        b_prefill_start_loc=b_prefill_start_loc,
+        prefix_total_token_num=0, # the default kvcache len is zero.
     )

     model_output = model_part.forward(model_input)
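A note on the added b_prefill_start_loc line: subtracting b_seq_len from its inclusive cumulative sum yields an exclusive prefix sum, i.e. the offset at which each request's tokens begin in the packed prefill batch. A minimal standalone sketch of the same arithmetic (the example lengths are illustrative, not taken from the commit):

    import torch

    # Three requests with 3, 5, and 2 prompt tokens respectively.
    b_seq_len = torch.tensor([3, 5, 2], dtype=torch.int32)

    # Inclusive cumsum [3, 8, 10] minus the lengths gives the start
    # offsets [0, 3, 8] of each request in the flattened token stream.
    b_prefill_start_loc = b_seq_len.cumsum(dim=0, dtype=torch.int32) - b_seq_len
    print(b_prefill_start_loc)  # tensor([0, 3, 8], dtype=torch.int32)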
@@ -209,6 +215,8 @@ def decode(
         batch_size=batch_size,
         total_token_num=total_token_num,
         max_len_in_batch=max_len_in_batch,
+        max_q_seq_len=1,
+        max_kv_seq_len=max_len_in_batch,
         input_ids=input_ids,
         b_req_idx=b_req_idx,
         b_seq_len=b_seq_len,
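For the decode hunk, the new bounds follow from decoding one token at a time: each request contributes a single query token per step (max_q_seq_len=1), while attention still spans the full cached sequence (max_kv_seq_len=max_len_in_batch). A hedged sketch of that relationship (shapes and the vocab size are assumptions for illustration, not taken from this repo):

    import torch

    batch_size = 3
    b_seq_len = torch.tensor([10, 7, 12], dtype=torch.int32)  # cached tokens per request
    input_ids = torch.randint(0, 32000, (batch_size,))        # one new token per request

    max_q_seq_len = 1                              # decode: a single query token per request
    max_kv_seq_len = int(b_seq_len.max().item())   # keys/values span the longest sequence,
                                                   # presumably what max_len_in_batch tracks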
