Skip to content

Commit f07da4a

Browse files
authored
Fix bug in FA ops calc
1 parent aa1c8ff commit f07da4a

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

python/perf-kernels/flash-attention.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1990,7 +1990,7 @@ def bench_flash_attention(BATCH, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, dtype, causal
1990 1990
# Total zero elements are 1+2+...+(seqlen_q-1) = seqlen_q*(seqlen_q-1)/2
1991 1991
# Total non zero elements are seqlen_q*seqlen_k - (seqlen_q*(seqlen_q-1)/2)
1992 1992
valid_out_elements = ((seqlen_k**2 + seqlen_k) / 2) if seqlen_q > seqlen_k else \
1993-
(seqlen_q * seqlen_k - ((seqlen_q**1 - seqlen_q) / 2))
1993+
(seqlen_q * seqlen_k - ((seqlen_q**2 - seqlen_q) / 2))
1994 1994
flops_per_matmul += valid_out_elements * HQ * D_HEAD * 2
1995 1995
else:
1996 1996
flops_per_matmul += seqlen_q * seqlen_k * HQ * D_HEAD * 2

0 commit comments

Comments
 (0)