
Commit 59fec67

[CI][FlexAttn] Run FlexAttention backward

Signed-off-by: Whitney Tsang <[email protected]>

1 parent d478b30 commit 59fec67

File tree

2 files changed: +65 -47 lines changed

.github/workflows/triton-benchmarks.yml
benchmarks/triton_kernels_benchmark/flex_attention_benchmark_causal_mask.py

.github/workflows/triton-benchmarks.yml
Lines changed: 12 additions & 0 deletions

@@ -334,6 +334,18 @@ jobs:
           python build_report.py $REPORTS/flexAttnCausal-performance.csv $REPORTS/flexAttnCausal-batch16-triton-report.csv --benchmark flex-attn-causal-batch16 --compiler triton --param_cols "Z,H_q,H_kv,N_CTX_q,N_CTX_kv,D_HEAD_qk,D_HEAD_v" --tflops_col Triton-TFlops --hbm_col "Triton-GB/s" --tag $TAG
           python build_report.py $REPORTS/flexAttnCausal-performance.csv $REPORTS/flexAttnCausal-batch16-torch-report.csv --benchmark flex-attn-causal-batch16 --compiler torch --param_cols "Z,H_q,H_kv,N_CTX_q,N_CTX_kv,D_HEAD_qk,D_HEAD_v" --tflops_col Torch-TFlops --hbm_col "Torch-GB/s" --tag $TAG
 
+      - name: Run Triton FlexAttention Causal Mask bwd kernel benchmark
+        if: ${{ steps.install.outcome == 'success' && !cancelled() && (inputs.benchmarks == '' || contains(fromJson(inputs.benchmarks || '[]'), 'flex_attention_bwd_benchmark_causal_mask.py')) && !contains(fromJson(inputs.skip_benchmarks || '[]'), 'flex_attention_bwd_benchmark_causal_mask.py') }}
+        run: |
+          export LD_LIBRARY_PATH=$PTI_LIBS_DIR:$LD_LIBRARY_PATH
+          cd benchmarks/triton_kernels_benchmark
+          FA_KERNEL_MODE='bwd' \
+          python flex_attention_benchmark_causal_mask.py --reports $REPORTS --n_runs $N_RUNS
+
+          source ../../scripts/capture-hw-details.sh
+          python build_report.py $REPORTS/flexAttnCausal-performance.csv $REPORTS/flexAttnCausal-bwd-triton-report.csv --benchmark flex-attn-causal-bwd --compiler triton --param_cols "Z,H_q,H_kv,N_CTX_q,N_CTX_kv,D_HEAD_qk,D_HEAD_v" --tflops_col Triton-TFlops --hbm_col "Triton-GB/s" --tag $TAG
+          python build_report.py $REPORTS/flexAttnCausal-performance.csv $REPORTS/flexAttnCausal-bwd-torch-report.csv --benchmark flex-attn-causal-bwd --compiler torch --param_cols "Z,H_q,H_kv,N_CTX_q,N_CTX_kv,D_HEAD_qk,D_HEAD_v" --tflops_col Torch-TFlops --hbm_col "Torch-GB/s" --tag $TAG
+
       - name: Run Triton FlexAttention Custom Masks fwd kernel benchmark
         if: ${{ steps.install.outcome == 'success' && !cancelled() && (inputs.benchmarks == '' || contains(fromJson(inputs.benchmarks || '[]'), 'flex_attention_benchmark_custom_masks.py')) && !contains(fromJson(inputs.skip_benchmarks || '[]'), 'flex_attention_benchmark_custom_masks.py') }}
         run: |
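
The new step mirrors the existing fwd steps: it runs the same causal-mask benchmark script with FA_KERNEL_MODE='bwd' and publishes results under the flex-attn-causal-bwd tag. To try the step outside CI, the same commands from the diff can be run from a checkout; a minimal sketch, assuming a configured Triton XPU environment ($REPORTS and $N_RUNS are supplied by the workflow in CI, so the values below are illustrative):

  cd benchmarks/triton_kernels_benchmark
  FA_KERNEL_MODE='bwd' \
  python flex_attention_benchmark_causal_mask.py --reports ./reports --n_runs 1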

benchmarks/triton_kernels_benchmark/flex_attention_benchmark_causal_mask.py
Lines changed: 53 additions & 47 deletions

@@ -86,53 +86,58 @@ def causal_mask(_, __, q_idx, kv_idx):
 @benchmark_suite.perf_report(
     benchmark_suite.Benchmark(
         x_names=['Z', 'H_q', 'H_kv', 'N_CTX_q', 'N_CTX_kv', 'D_HEAD_qk', 'D_HEAD_v', 'MODE'],
-        x_vals=
-        # Multi-head attention. H_q equals H_kv
-        # Prefill shapes of Phi3-mini-4k-instruct
-        [[z, 32, 32, 1024, 1024, 96, 96, fa_kernel_mode] for z in batch_sizes] +
-        # Prefill shapes of Qwen3-4B
-        [[z, 32, 32, 1024, 1024, 128, 128, fa_kernel_mode] for z in batch_sizes] +
-        # Prefill shapes of DeepSeek-v3
-        [[z, 128, 128, 1024, 1024, 192, 128, fa_kernel_mode] for z in batch_sizes] +
-        # Append shapes of Phi3-mini-4k-instruct
-        [[z, 32, 32, 512, 1024 + 128 + 512, 96, 96, fa_kernel_mode] for z in batch_sizes] +
-
-        ## Multi-query attention. H_kv equals 1.
-        # Append shapes of Deepseek-v3
-        [[z, 128, 1, 512, 1024 + 128 + 512, 576, 512, fa_kernel_mode] for z in batch_sizes] +
-
-        # Grouped-query attention. H_q / H_kv > 1
-        # Prefill shapes of Llama-3.1-8B
-        [[z, 32, 8, 1024, 1024, 128, 128, fa_kernel_mode] for z in batch_sizes] +
-        # Prefill shapes of meta-llama-Llama-3.2-3B
-        [[z, 24, 8, 1024, 1024, 128, 128, fa_kernel_mode] for z in batch_sizes] +
-        # Prefill shapes of Deepseek-R1-Distill-Qwen-14B
-        [[z, 40, 8, 1024, 1024, 128, 128, fa_kernel_mode] for z in batch_sizes] +
-        # Append shapes of Llama-3.1-8B
-        [[z, 32, 8, 512, 1024 + 128 + 512, 128, 128, fa_kernel_mode] for z in batch_sizes] +
-        # Append shapes of meta-llama-Llama-3.2-3B
-        [[z, 24, 8, 512, 1024 + 128 + 512, 128, 128, fa_kernel_mode] for z in batch_sizes] +
-        # Append shapes of Qwen3-4B
-        [[z, 32, 8, 512, 1024 + 128 + 512, 128, 128, fa_kernel_mode] for z in batch_sizes] +
-
-        # FlexDecoding configuration. N_CTX_q equals 1. N_CTX_kv >= 1k
-        # Decode shapes of Llama-3.1-8B
-        [[z, 32, 8, 1, 1024 + 64, 128, 128, fa_kernel_mode] for z in batch_sizes] +
-        # Decode shapes of meta-llama-Llama-3.2-3B
-        [[z, 24, 8, 1, 1024 + 64, 128, 128, fa_kernel_mode] for z in batch_sizes] +
-        # Decode shapes of Phi3-mini-4k-instruct
-        [
-            # acc = acc.reshape(G, BLOCK_M_PER_HQ, V_HEAD_DIM)
-            # ValueError: Shape element 2 must be a power of 2
-            # [z, 32, 32, 1, 1024 + 64, 96, 96, fa_kernel_mode] for z in batch_sizes
-        ] +
-        # Decode shapes of Qwen3-4B
-        [[z, 32, 8, 1, 1024 + 64, 128, 128, fa_kernel_mode] for z in batch_sizes] +
-        # Decode shapes of Deepseek-R1-Distill-Qwen-14B
-        [[z, 40, 8, 1, 1024 + 64, 128, 128, fa_kernel_mode] for z in batch_sizes] +
-        # Decode shapes of Deepseek-v3
-        [
-            # [z, 128, 1, 1, 1024 + 64, 576, 512, fa_kernel_mode] for z in batch_sizes
+        x_vals=[
+            x_val for x_val in
+            # Multi-head attention. H_q equals H_kv
+            # Prefill shapes of Phi3-mini-4k-instruct
+            [[z, 32, 32, 1024, 1024, 96, 96, fa_kernel_mode] for z in batch_sizes] +
+            # Prefill shapes of Qwen3-4B
+            [[z, 32, 32, 1024, 1024, 128, 128, fa_kernel_mode] for z in batch_sizes] +
+            # Prefill shapes of DeepSeek-v3
+            [[z, 128, 128, 1024, 1024, 192, 128, fa_kernel_mode] for z in batch_sizes] +
+            # Append shapes of Phi3-mini-4k-instruct
+            [[z, 32, 32, 512, 1024 + 128 + 512, 96, 96, fa_kernel_mode] for z in batch_sizes] +
+
+            # Multi-query attention. H_kv equals 1.
+            # Append shapes of Deepseek-v3
+            ([[z, 128, 1, 512, 1024 + 128 + 512, 576, 512, fa_kernel_mode]
+              for z in batch_sizes] if fa_kernel_mode != 'bwd' else []) +
+
+            # Grouped-query attention. H_q / H_kv > 1
+            # Prefill shapes of Llama-3.1-8B
+            [[z, 32, 8, 1024, 1024, 128, 128, fa_kernel_mode] for z in batch_sizes] +
+            # Prefill shapes of meta-llama-Llama-3.2-3B
+            [[z, 24, 8, 1024, 1024, 128, 128, fa_kernel_mode] for z in batch_sizes] +
+            # Prefill shapes of Deepseek-R1-Distill-Qwen-14B
+            [[z, 40, 8, 1024, 1024, 128, 128, fa_kernel_mode] for z in batch_sizes] +
+            # Append shapes of Llama-3.1-8B
+            [[z, 32, 8, 512, 1024 + 128 + 512, 128, 128, fa_kernel_mode] for z in batch_sizes] +
+            # Append shapes of meta-llama-Llama-3.2-3B
+            [[z, 24, 8, 512, 1024 + 128 + 512, 128, 128, fa_kernel_mode] for z in batch_sizes] +
+            # Append shapes of Qwen3-4B
+            [[z, 32, 8, 512, 1024 + 128 + 512, 128, 128, fa_kernel_mode] for z in batch_sizes] +
+
+            # FlexDecoding configuration. N_CTX_q equals 1. N_CTX_kv >= 1k
+            # Decode shapes of Llama-3.1-8B
+            [[z, 32, 8, 1, 1024 + 64, 128, 128, fa_kernel_mode] for z in batch_sizes] +
+            # Decode shapes of meta-llama-Llama-3.2-3B
+            [[z, 24, 8, 1, 1024 + 64, 128, 128, fa_kernel_mode] for z in batch_sizes] +
+            # Decode shapes of Phi3-mini-4k-instruct
+            [
+                # acc = acc.reshape(G, BLOCK_M_PER_HQ, V_HEAD_DIM)
+                # ValueError: Shape element 2 must be a power of 2
+                # [z, 32, 32, 1, 1024 + 64, 96, 96, fa_kernel_mode] for z in batch_sizes
+            ] +
+            # Decode shapes of Qwen3-4B
+            [[z, 32, 8, 1, 1024 + 64, 128, 128, fa_kernel_mode] for z in batch_sizes] +
+            # Decode shapes of Deepseek-R1-Distill-Qwen-14B
+            [[z, 40, 8, 1, 1024 + 64, 128, 128, fa_kernel_mode] for z in batch_sizes] +
+            # Decode shapes of Deepseek-v3
+            [
+                # [z, 128, 1, 1, 1024 + 64, 576, 512, fa_kernel_mode] for z in batch_sizes
+            ]
+            # FIXME: Reenable when PyTorch fixes config in https://github.com/pytorch/pytorch/blob/main/torch/_inductor/template_heuristics/triton.py#L1509
+            if x_val[-1] == 'fwd' or x_val[5] != 128
         ],
         line_arg='provider',
         line_vals=['triton', 'torch'],
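
The key change above: x_vals is now a filtering comprehension over the concatenated shape lists, so backward-mode shapes with D_HEAD_qk == 128 are dropped until the referenced PyTorch heuristic is fixed (in the x_names ordering, x_val[5] is D_HEAD_qk and x_val[-1] is MODE). A minimal standalone sketch of the same pattern, with an illustrative subset of shapes; batch_sizes and fa_kernel_mode stand in for the module-level values (fa_kernel_mode is presumably derived from the FA_KERNEL_MODE variable the workflow sets):

  # Standalone sketch of the x_vals filtering pattern; the shapes are an
  # illustrative subset, not the full benchmark matrix.
  batch_sizes = [1, 16, 32]
  fa_kernel_mode = 'bwd'  # assumption: mirrors FA_KERNEL_MODE from the environment

  shapes = (
      # Prefill shapes of Phi3-mini-4k-instruct (D_HEAD_qk = 96)
      [[z, 32, 32, 1024, 1024, 96, 96, fa_kernel_mode] for z in batch_sizes] +
      # Prefill shapes of Llama-3.1-8B (D_HEAD_qk = 128)
      [[z, 32, 8, 1024, 1024, 128, 128, fa_kernel_mode] for z in batch_sizes])

  # Keep every fwd shape; for bwd, drop shapes whose D_HEAD_qk is 128.
  x_vals = [x_val for x_val in shapes if x_val[-1] == 'fwd' or x_val[5] != 128]

  print(len(shapes), '->', len(x_vals))  # 6 -> 3 when fa_kernel_mode == 'bwd'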
@@ -143,6 +148,7 @@ def causal_mask(_, __, q_idx, kv_idx):
         args={},
     ))
 def benchmark(Z, H_q, H_kv, N_CTX_q, N_CTX_kv, D_HEAD_qk, D_HEAD_v, MODE, provider):
+    print(f'Running benchmark for Z={Z}, H_q={H_q}, H_kv={H_kv}, N_CTX_q={N_CTX_q}, N_CTX_kv={N_CTX_kv}, ')
     # Maximum across torch=200, triton=600
     do_bench = benchmark_suite.get_do_bench(n_warmup=600, n_repeat=10, quantiles=[0.5, 0.0, 1.0])
     if MODE not in ('fwd', 'bwd'):
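
For context on what MODE='bwd' exercises: benchmark suites like this one typically time the forward call directly in fwd mode, and in bwd mode run the forward once and then time repeated backward passes from a fixed output gradient. A hedged sketch of that pattern against the FlexAttention API, not this file's actual wiring (shapes, device, and the fn plumbing below are illustrative assumptions):

  import torch
  from torch.nn.attention.flex_attention import create_block_mask, flex_attention

  def causal_mask(_, __, q_idx, kv_idx):
      return q_idx >= kv_idx

  # Illustrative single shape; the real benchmark sweeps the x_vals matrix above.
  Z, H, N_CTX, D_HEAD = 4, 8, 1024, 128
  device = 'xpu'  # assumption: Intel GPU backend; 'cuda' on NVIDIA
  q, k, v = (torch.randn(Z, H, N_CTX, D_HEAD, device=device, dtype=torch.float16,
                         requires_grad=True) for _ in range(3))
  block_mask = create_block_mask(causal_mask, B=None, H=None,
                                 Q_LEN=N_CTX, KV_LEN=N_CTX, device=device)

  MODE = 'bwd'
  if MODE == 'fwd':
      fn = lambda: flex_attention(q, k, v, block_mask=block_mask)
  else:
      # bwd: run the forward once, then time backward with retain_graph=True
      # so the same graph can be replayed across timing iterations.
      out = flex_attention(q, k, v, block_mask=block_mask)
      grad_out = torch.randn_like(out)
      fn = lambda: out.backward(grad_out, retain_graph=True)

  fn()  # the suite would hand fn to its do_bench timer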
