
Commit eb3bf93

[BENCHMARK] Update the FLOPs count of causal flex attention to align with the tutorial. (#4805)

Only the lower triangle is computed when causal=True for attention.

Signed-off-by: Lu,Chengjun <[email protected]>
1 parent 99bc7f5 commit eb3bf93

1 file changed: +2 additions, -2 deletions

benchmarks/triton_kernels_benchmark/flex_attention_benchmark_causal_mask.py

Lines changed: 2 additions & 2 deletions
@@ -184,8 +184,8 @@ def benchmark(Z, H_q, H_kv, N_CTX_q, N_CTX_kv, D_HEAD_qk, D_HEAD_v, MODE, provid
     else:
         raise NotImplementedError(f'Unsupported provider {provider}')

-    qk_flops = H_q * N_CTX_q * N_CTX_kv * D_HEAD_qk * 2  # mul + add
-    pv_flops = H_q * N_CTX_q * D_HEAD_v * N_CTX_kv * 2  # mul + add
+    qk_flops = H_q * N_CTX_q * N_CTX_kv * D_HEAD_qk  # mul + add, causal=True. Only the lower triangle is computed.
+    pv_flops = H_q * N_CTX_q * D_HEAD_v * N_CTX_kv  # mul + add, causal=True. Only the lower triangle is computed.
     tflops = lambda mean: Z * (qk_flops + pv_flops) * (1e-12) / (mean * 1e-3)

     q_elems = H_q * N_CTX_q * D_HEAD_qk
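
Why the trailing "* 2" disappears in the updated lines: a full (non-causal) attention score computation costs 2 * N_CTX_q * N_CTX_kv * D_HEAD_qk FLOPs per query head (one multiply plus one add per multiply-accumulate), while with causal=True only the lower triangle of the score matrix, roughly half of it, is computed, so the two factors cancel. Below is a minimal Python sketch of that arithmetic; it is not part of the commit, and the helper name causal_qk_flops is made up for illustration.

    # Sketch only: why the benchmark drops the "* 2" factor for causal attention.
    # Assumption (from the commit message): causal masking computes roughly half of the
    # N_CTX_q x N_CTX_kv score matrix, i.e. the lower triangle.

    def causal_qk_flops(H_q, N_CTX_q, N_CTX_kv, D_HEAD_qk):
        full_flops = H_q * N_CTX_q * N_CTX_kv * D_HEAD_qk * 2  # 2 = one mul + one add per MAC
        return full_flops // 2                                  # lower triangle only -> ~half the work

    # The result equals H_q * N_CTX_q * N_CTX_kv * D_HEAD_qk, matching the updated
    # qk_flops expression in the diff above; the same reasoning applies to pv_flops.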
