Skip to content

Commit aa1c8ff

Browse files
authored
Fix flash attention ops calculation (#758)
1 parent 48455d1 commit aa1c8ff

File tree

1 file changed

+17
-4
lines changed

1 file changed

+17
-4
lines changed

python/perf-kernels/flash-attention.py

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1979,15 +1979,28 @@ def bench_flash_attention(BATCH, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, dtype, causal
19791979
if causal:
19801980
# If seqlen_q != seqlen_k then the causal mask ignores computation
19811981
# depending on which seqlen is larger. Either the lower triangle, or right triangle
1982-
causal_correction = seqlen_k if seqlen_q > seqlen_k else seqlen_q
1983-
flops_per_matmul += (seqlen_q * seqlen_k - (causal_correction**2) / 2) * HQ * D_HEAD * 2
1982+
# If seqlen_q is greater than seqlen_k, the lower triangle is non zero
1983+
# where the last row has seqlen_k valid elements, the second last row has
1984+
# seqlen_k - 1 valid elements and so on until one element is valid in the
1985+
# seqlen_q - seqlen_k row, hence total valid elements are 1+2+...+seqlen_k
1986+
# which is seqlen_k*(seqlen_k+1)/2
1987+
# If seqlen_q is less than seqlen_k, then we count the zero elements
1988+
# the first row has seqlen_q-1 zero elements, the second row has seqlen_q-2
1989+
# zero elements and so on until the second last row has 1 zero element
1990+
# Total zero elements are 1+2+...+(seqlen_q-1) = seqlen_q*(seqlen_q-1)/2
1991+
# Total non zero elements are seqlen_q*seqlen_k - (seqlen_q*(seqlen_q-1)/2)
1992+
valid_out_elements = ((seqlen_k**2 + seqlen_k) / 2) if seqlen_q > seqlen_k else \
1993+
(seqlen_q * seqlen_k - ((seqlen_q**2 - seqlen_q) / 2))
1994+
flops_per_matmul += valid_out_elements * HQ * D_HEAD * 2
19841995
else:
19851996
flops_per_matmul += seqlen_q * seqlen_k * HQ * D_HEAD * 2
19861997
else:
19871998
q, k, v, input_metadata = input_helper(BATCH, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, dtype, args.layout)
19881999
if causal:
1989-
causal_correction = N_CTX_K if N_CTX_Q > N_CTX_K else N_CTX_Q
1990-
flops_per_matmul = 2.0 * BATCH * HQ * (N_CTX_Q * N_CTX_K - (causal_correction**2) / 2) * D_HEAD
2000+
# Same calculation as if varlen/if causal above
2001+
valid_out_elements = ((N_CTX_K**2 + N_CTX_K) / 2) if N_CTX_Q > N_CTX_K else \
2002+
(N_CTX_Q * N_CTX_K - ((N_CTX_Q**2 - N_CTX_Q) / 2))
2003+
flops_per_matmul = 2.0 * BATCH * HQ * valid_out_elements * D_HEAD
19912004
else:
19922005
flops_per_matmul = 2.0 * BATCH * HQ * N_CTX_Q * N_CTX_K * D_HEAD
19932006
if causal:

0 commit comments

Comments
 (0)