Commit f9984d2

Enable Flex Attention bwd shapes with D_HEAD_qk==128
Signed-off-by: Whitney Tsang <[email protected]>
1 parent 4583119 commit f9984d2
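
In the benchmark's shape table, each x_val is [Z, H_q, H_kv, N_CTX_q, N_CTX_kv, D_HEAD_qk, D_HEAD_v, MODE], and the old x_vals comprehension filtered out every backward-mode entry whose D_HEAD_qk was 128. A minimal sketch of the guard this commit removes (the two-entry shape list below is illustrative, not the benchmark's full table):

# Sketch of the removed guard. Each x_val is
# [Z, H_q, H_kv, N_CTX_q, N_CTX_kv, D_HEAD_qk, D_HEAD_v, MODE],
# so x_val[5] is D_HEAD_qk and x_val[-1] is MODE.
candidate_shapes = [
    [4, 32, 8, 1024, 1024, 128, 128, 'bwd'],  # dropped by the old filter
    [4, 32, 32, 1024, 1024, 96, 96, 'bwd'],   # kept by the old filter
]

# Before: keep a shape only if it runs forward, or its D_HEAD_qk is not 128.
old_x_vals = [x_val for x_val in candidate_shapes
              if x_val[-1] == 'fwd' or x_val[5] != 128]

# After: no filter, so bwd shapes with D_HEAD_qk == 128 are benchmarked too.
new_x_vals = candidate_shapes

assert len(old_x_vals) == 1 and len(new_x_vals) == 2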

File tree

1 file changed (+48, -52)

benchmarks/triton_kernels_benchmark/flex_attention_benchmark_causal_mask.py

Lines changed: 48 additions & 52 deletions
@@ -86,58 +86,54 @@ def causal_mask(_, __, q_idx, kv_idx):
 @benchmark_suite.perf_report(
     benchmark_suite.Benchmark(
         x_names=['Z', 'H_q', 'H_kv', 'N_CTX_q', 'N_CTX_kv', 'D_HEAD_qk', 'D_HEAD_v', 'MODE'],
-        x_vals=[
-            x_val for x_val in
-            # Multi-head attention. H_q equals H_kv
-            # Prefill shapes of Phi3-mini-4k-instruct
-            [[z, 32, 32, 1024, 1024, 96, 96, fa_kernel_mode] for z in batch_sizes] +
-            # Prefill shapes of Qwen3-4B
-            [[z, 32, 32, 1024, 1024, 128, 128, fa_kernel_mode] for z in batch_sizes] +
-            # Prefill shapes of DeepSeek-v3
-            [[z, 128, 128, 1024, 1024, 192, 128, fa_kernel_mode] for z in batch_sizes] +
-            # Append shapes of Phi3-mini-4k-instruct
-            [[z, 32, 32, 512, 1024 + 128 + 512, 96, 96, fa_kernel_mode] for z in batch_sizes] +
-
-            # Multi-query attention. H_kv equals 1.
-            # Append shapes of Deepseek-v3
-            ([[z, 128, 1, 512, 1024 + 128 + 512, 576, 512, fa_kernel_mode]
-              for z in batch_sizes] if fa_kernel_mode != 'bwd' else []) +
-
-            # Grouped-query attention. H_q / H_kv > 1
-            # Prefill shapes of Llama-3.1-8B
-            [[z, 32, 8, 1024, 1024, 128, 128, fa_kernel_mode] for z in batch_sizes] +
-            # Prefill shapes of meta-llama-Llama-3.2-3B
-            [[z, 24, 8, 1024, 1024, 128, 128, fa_kernel_mode] for z in batch_sizes] +
-            # Prefill shapes of Deepseek-R1-Distill-Qwen-14B
-            [[z, 40, 8, 1024, 1024, 128, 128, fa_kernel_mode] for z in batch_sizes] +
-            # Append shapes of Llama-3.1-8B
-            [[z, 32, 8, 512, 1024 + 128 + 512, 128, 128, fa_kernel_mode] for z in batch_sizes] +
-            # Append shapes of meta-llama-Llama-3.2-3B
-            [[z, 24, 8, 512, 1024 + 128 + 512, 128, 128, fa_kernel_mode] for z in batch_sizes] +
-            # Append shapes of Qwen3-4B
-            [[z, 32, 8, 512, 1024 + 128 + 512, 128, 128, fa_kernel_mode] for z in batch_sizes] +
-
-            # FlexDecoding configuration. N_CTX_q equals 1. N_CTX_kv >= 1k
-            # Decode shapes of Llama-3.1-8B
-            [[z, 32, 8, 1, 1024 + 64, 128, 128, fa_kernel_mode] for z in batch_sizes] +
-            # Decode shapes of meta-llama-Llama-3.2-3B
-            [[z, 24, 8, 1, 1024 + 64, 128, 128, fa_kernel_mode] for z in batch_sizes] +
-            # Decode shapes of Phi3-mini-4k-instruct
-            [
-                # acc = acc.reshape(G, BLOCK_M_PER_HQ, V_HEAD_DIM)
-                # ValueError: Shape element 2 must be a power of 2
-                # [z, 32, 32, 1, 1024 + 64, 96, 96, fa_kernel_mode] for z in batch_sizes
-            ] +
-            # Decode shapes of Qwen3-4B
-            [[z, 32, 8, 1, 1024 + 64, 128, 128, fa_kernel_mode] for z in batch_sizes] +
-            # Decode shapes of Deepseek-R1-Distill-Qwen-14B
-            [[z, 40, 8, 1, 1024 + 64, 128, 128, fa_kernel_mode] for z in batch_sizes] +
-            # Decode shapes of Deepseek-v3
-            [
-                # [z, 128, 1, 1, 1024 + 64, 576, 512, fa_kernel_mode] for z in batch_sizes
-            ]
-            # FIXME: Reenable when PyTorch fixes config in https://github.com/pytorch/pytorch/blob/main/torch/_inductor/template_heuristics/triton.py#L1509
-            if x_val[-1] == 'fwd' or x_val[5] != 128
+        x_vals=
+        # Multi-head attention. H_q equals H_kv
+        # Prefill shapes of Phi3-mini-4k-instruct
+        [[z, 32, 32, 1024, 1024, 96, 96, fa_kernel_mode] for z in batch_sizes] +
+        # Prefill shapes of Qwen3-4B
+        [[z, 32, 32, 1024, 1024, 128, 128, fa_kernel_mode] for z in batch_sizes] +
+        # Prefill shapes of DeepSeek-v3
+        [[z, 128, 128, 1024, 1024, 192, 128, fa_kernel_mode] for z in batch_sizes] +
+        # Append shapes of Phi3-mini-4k-instruct
+        [[z, 32, 32, 512, 1024 + 128 + 512, 96, 96, fa_kernel_mode] for z in batch_sizes] +
+
+        # Multi-query attention. H_kv equals 1.
+        # Append shapes of Deepseek-v3
+        ([[z, 128, 1, 512, 1024 + 128 + 512, 576, 512, fa_kernel_mode]
+          for z in batch_sizes] if fa_kernel_mode != 'bwd' else []) +
+
+        # Grouped-query attention. H_q / H_kv > 1
+        # Prefill shapes of Llama-3.1-8B
+        [[z, 32, 8, 1024, 1024, 128, 128, fa_kernel_mode] for z in batch_sizes] +
+        # Prefill shapes of meta-llama-Llama-3.2-3B
+        [[z, 24, 8, 1024, 1024, 128, 128, fa_kernel_mode] for z in batch_sizes] +
+        # Prefill shapes of Deepseek-R1-Distill-Qwen-14B
+        [[z, 40, 8, 1024, 1024, 128, 128, fa_kernel_mode] for z in batch_sizes] +
+        # Append shapes of Llama-3.1-8B
+        [[z, 32, 8, 512, 1024 + 128 + 512, 128, 128, fa_kernel_mode] for z in batch_sizes] +
+        # Append shapes of meta-llama-Llama-3.2-3B
+        [[z, 24, 8, 512, 1024 + 128 + 512, 128, 128, fa_kernel_mode] for z in batch_sizes] +
+        # Append shapes of Qwen3-4B
+        [[z, 32, 8, 512, 1024 + 128 + 512, 128, 128, fa_kernel_mode] for z in batch_sizes] +
+
+        # FlexDecoding configuration. N_CTX_q equals 1. N_CTX_kv >= 1k
+        # Decode shapes of Llama-3.1-8B
+        [[z, 32, 8, 1, 1024 + 64, 128, 128, fa_kernel_mode] for z in batch_sizes] +
+        # Decode shapes of meta-llama-Llama-3.2-3B
+        [[z, 24, 8, 1, 1024 + 64, 128, 128, fa_kernel_mode] for z in batch_sizes] +
+        # Decode shapes of Phi3-mini-4k-instruct
+        [
+            # acc = acc.reshape(G, BLOCK_M_PER_HQ, V_HEAD_DIM)
+            # ValueError: Shape element 2 must be a power of 2
+            # [z, 32, 32, 1, 1024 + 64, 96, 96, fa_kernel_mode] for z in batch_sizes
+        ] +
+        # Decode shapes of Qwen3-4B
+        [[z, 32, 8, 1, 1024 + 64, 128, 128, fa_kernel_mode] for z in batch_sizes] +
+        # Decode shapes of Deepseek-R1-Distill-Qwen-14B
+        [[z, 40, 8, 1, 1024 + 64, 128, 128, fa_kernel_mode] for z in batch_sizes] +
+        # Decode shapes of Deepseek-v3
+        [
+            # [z, 128, 1, 1, 1024 + 64, 576, 512, fa_kernel_mode] for z in batch_sizes
         ],
         line_arg='provider',
         line_vals=['triton', 'torch'],
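
For context, here is a minimal, self-contained sketch of the kind of workload one of the newly enabled rows describes (the Llama-3.1-8B prefill row [z, 32, 8, 1024, 1024, 128, 128, 'bwd']), assuming PyTorch's public FlexAttention API. It is not taken from the benchmark file; the device selection, dtype, and torch.compile usage are illustrative.

import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention

def causal_mask(_, __, q_idx, kv_idx):  # same mask_mod shape as the benchmark's
    return q_idx >= kv_idx

# Illustrative device pick: Intel XPU if available, else CUDA.
device = 'xpu' if hasattr(torch, 'xpu') and torch.xpu.is_available() else 'cuda'

Z, H_q, H_kv, N_CTX, D_HEAD = 1, 32, 8, 1024, 128  # D_HEAD_qk == D_HEAD_v == 128
q = torch.randn(Z, H_q, N_CTX, D_HEAD, device=device, dtype=torch.float16, requires_grad=True)
k = torch.randn(Z, H_kv, N_CTX, D_HEAD, device=device, dtype=torch.float16, requires_grad=True)
v = torch.randn(Z, H_kv, N_CTX, D_HEAD, device=device, dtype=torch.float16, requires_grad=True)

block_mask = create_block_mask(causal_mask, 1, 1, N_CTX, N_CTX, device=device)
compiled = torch.compile(flex_attention)
out = compiled(q, k, v, block_mask=block_mask, enable_gqa=True)  # grouped-query: H_q / H_kv = 4
out.backward(torch.ones_like(out))  # the 'bwd' pass this commit re-enables for D_HEAD_qk == 128

Before this commit, rows like this one were skipped whenever MODE was 'bwd', per the FIXME attached to the removed filter.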
