Skip to content

Commit 67a1c87

Browse files
committed
fix flops
1 parent 075f65d commit 67a1c87

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

test/benchmark/kernel/benchmark_fa3_decode_mtp.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -184,7 +184,7 @@ def print_error(a, b, name=""):
184184
max_seqlen = cache_seqlens.max().item()
185185
max_seqlen_pad = math.ceil(max_seqlen / 256) * 256 # ?为什么对齐256
186186

187-
total_flops = s_q * (total_seqlens * 2 - batch_mtp) * h_q * (d + dv) * 2
187+
total_flops = s_q * total_seqlens * h_q * (d + dv) * 2 * mtp_size
188188

189189
q = torch.randn(b, s_q, h_q, d, dtype=dtype, device=device)
190190
block_table = torch.arange(batch_mtp * max_seqlen_pad, dtype=torch.int32, device=device).view(

0 commit comments

Comments
 (0)