We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 1e9b419 commit dff8618 — Copy full SHA for dff8618
lightllm/common/basemodel/cuda_graph.py
@@ -258,13 +258,15 @@ def warmup_overlap(self, model):
258
)
259
b_seq_len = torch.empty(batch_size, dtype=torch.int32, device="cuda")
260
b_seq_len.fill_(seq_len)
261
+ b_mtp_index = torch.zeros(batch_size, dtype=torch.int32, device="cuda")
262
263
micro_batch = ModelInput(
264
is_prefill=False,
265
batch_size=batch_size,
266
total_token_num=total_token_num,
267
max_len_in_batch=max_len_in_batch,
268
input_ids=input_ids,
269
+ b_mtp_index=b_mtp_index,
270
mem_indexes=mem_indexes,
271
b_req_idx=b_req_idx,
272
b_seq_len=b_seq_len,
0 commit comments