@@ -1979,15 +1979,28 @@ def bench_flash_attention(BATCH, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, dtype, causal
19791979 if causal :
19801980 # If seqlen_q != seqlen_k then the causal mask ignores computation
19811981 # depending on which seqlen is larger. Either the lower triangle, or right triangle
1982- causal_correction = seqlen_k if seqlen_q > seqlen_k else seqlen_q
1983- flops_per_matmul += (seqlen_q * seqlen_k - (causal_correction ** 2 ) / 2 ) * HQ * D_HEAD * 2
1982+ # If seqlen_q is greater than seqlen_k, the lower triangle is non zero
1983+ # where the last row has seqlen_k valid elements, the second last row has
1984+ # seqlen_k - 1 valid elements and so on until one element is valid in the
1985+ # seqlen_q - seqlen_k row, hence total valid elements are 1+2+...+seqlen_k
1986+ # which is seqlen_k*(seqlen_k+1)/2
1987+ # If seqlen_q is less than seqlen_k, then we count the zero elements
1988+ # the first row has seqlen_q-1 zero elements, the second row has seqlen_q-2
1989+ # zero elements and so on until the second last row has 1 zero element
1990+ # Total zero elements are 1+2+...+(seqlen_q-1) = seqlen_q*(seqlen_q-1)/2
1991+ # Total non zero elements are seqlen_q*seqlen_k - (seqlen_q*(seqlen_q-1)/2)
1992+ valid_out_elements = ((seqlen_k ** 2 + seqlen_k ) / 2 ) if seqlen_q > seqlen_k else \
1993+ (seqlen_q * seqlen_k - ((seqlen_q ** 2 - seqlen_q ) / 2 ))
1994+ flops_per_matmul += valid_out_elements * HQ * D_HEAD * 2
19841995 else :
19851996 flops_per_matmul += seqlen_q * seqlen_k * HQ * D_HEAD * 2
19861997 else :
19871998 q , k , v , input_metadata = input_helper (BATCH , HQ , HK , N_CTX_Q , N_CTX_K , D_HEAD , dtype , args .layout )
19881999 if causal :
1989- causal_correction = N_CTX_K if N_CTX_Q > N_CTX_K else N_CTX_Q
1990- flops_per_matmul = 2.0 * BATCH * HQ * (N_CTX_Q * N_CTX_K - (causal_correction ** 2 ) / 2 ) * D_HEAD
2000+ # Same calculation as in the varlen causal branch above
2001+ valid_out_elements = ((N_CTX_K ** 2 + N_CTX_K ) / 2 ) if N_CTX_Q > N_CTX_K else \
2002+ (N_CTX_Q * N_CTX_K - ((N_CTX_Q ** 2 - N_CTX_Q ) / 2 ))
2003+ flops_per_matmul = 2.0 * BATCH * HQ * valid_out_elements * D_HEAD
19912004 else :
19922005 flops_per_matmul = 2.0 * BATCH * HQ * N_CTX_Q * N_CTX_K * D_HEAD
19932006 if causal :
0 commit comments