Commit 8d6528f

[FlexAttn] Update model shapes (#5351)
Signed-off-by: Whitney Tsang <[email protected]>
1 parent 556d42f commit 8d6528f

1 file changed: +26 −26

benchmarks/triton_kernels_benchmark/flex_attention_benchmark_causal_mask.py

Lines changed: 26 additions & 26 deletions
@@ -82,52 +82,52 @@ def causal_mask(_, __, q_idx, kv_idx):
         x_names=['Z', 'H_q', 'H_kv', 'N_CTX_q', 'N_CTX_kv', 'D_HEAD_qk', 'D_HEAD_v', 'MODE'],
         x_vals=
         # Multi-head attention. H_q equals H_kv
-        # Prefill shapes of Phi3-mini-3.8B
+        # Prefill shapes of Phi3-mini-4k-instruct
         [[z, 32, 32, 1024, 1024, 96, 96, fa_kernel_mode] for z in batch_sizes] +
-        # Prefill shapes of Deepseek-v3
+        # Prefill shapes of Qwen3-4B
+        [[z, 32, 32, 1024, 1024, 128, 128, fa_kernel_mode] for z in batch_sizes] +
+        # Prefill shapes of DeepSeek-v3
         [[z, 128, 128, 1024, 1024, 192, 128, fa_kernel_mode] for z in batch_sizes] +
-        # Append shapes of Phi3-mini-3.8B
+        # Append shapes of Phi3-mini-4k-instruct
         [[z, 32, 32, 512, 1024 + 128 + 512, 96, 96, fa_kernel_mode] for z in batch_sizes] +

-        # Multi-query attention. H_kv equals 1.
-        # Append shapes of Deepseek-v3 (Nope)
-        [[z, 128, 1, 512, 1024 + 128 + 512, 64, 512, fa_kernel_mode] for z in batch_sizes] +
-        # Append shapes of Deepseek-v3 (Rope)
-        [] +
+        ## Multi-query attention. H_kv equals 1.
+        # Append shapes of Deepseek-v3
+        [[z, 128, 1, 512, 1024 + 128 + 512, 576, 512, fa_kernel_mode] for z in batch_sizes] +

         # Grouped-query attention. H_q / H_kv > 1
         # Prefill shapes of Llama-3.1-8B
         [[z, 32, 8, 1024, 1024, 128, 128, fa_kernel_mode] for z in batch_sizes] +
-        # Prefill shapes of Qwen2-7B
-        [[z, 28, 4, 1024, 1024, 128, 128, fa_kernel_mode] for z in batch_sizes] +
+        # Prefill shapes of meta-llama-Llama-3.2-3B
+        [[z, 24, 8, 1024, 1024, 128, 128, fa_kernel_mode] for z in batch_sizes] +
+        # Prefill shapes of Deepseek-R1-Distill-Qwen-14B
+        [[z, 40, 8, 1024, 1024, 128, 128, fa_kernel_mode] for z in batch_sizes] +
         # Append shapes of Llama-3.1-8B
         [[z, 32, 8, 512, 1024 + 128 + 512, 128, 128, fa_kernel_mode] for z in batch_sizes] +
-        # Append shapes of Qwen2-7B
-        [[z, 28, 4, 512, 1024 + 128 + 512, 128, 128, fa_kernel_mode] for z in batch_sizes] +
+        # Append shapes of meta-llama-Llama-3.2-3B
+        [[z, 24, 8, 512, 1024 + 128 + 512, 128, 128, fa_kernel_mode] for z in batch_sizes] +
+        # Append shapes of Qwen3-4B
+        [[z, 32, 8, 512, 1024 + 128 + 512, 128, 128, fa_kernel_mode] for z in batch_sizes] +

         # FlexDecoding configuration. N_CTX_q equals 1. N_CTX_kv >= 1k
         # Decode shapes of Llama-3.1-8B
         [[z, 32, 8, 1, 1024 + 64, 128, 128, fa_kernel_mode] for z in batch_sizes] +
-        # Decode shapes of Phi3-mini-3.8B
+        # Decode shapes of meta-llama-Llama-3.2-3B
+        [[z, 24, 8, 1, 1024 + 64, 128, 128, fa_kernel_mode] for z in batch_sizes] +
+        # Decode shapes of Phi3-mini-4k-instruct
         [
             # acc = acc.reshape(G, BLOCK_M_PER_HQ, V_HEAD_DIM)
             # ValueError: Shape element 2 must be a power of 2
             # [z, 32, 32, 1, 1024 + 64, 96, 96, fa_kernel_mode] for z in batch_sizes
         ] +
-        # Decode shapes of Qwen2-7B
-        [
-            # torch._inductor.exc.InductorError: LoweringException: ValueError: Number of shared query heads sharing the same KV head must be power of 2.
-            # [z, 28, 4, 1, 1024 + 64, 128, 128, fa_kernel_mode] for z in batch_sizes
-        ] +
-        # Decode shapes of Deepseek-v3 (Nope)
+        # Decode shapes of Qwen3-4B
+        [[z, 32, 8, 1, 1024 + 64, 128, 128, fa_kernel_mode] for z in batch_sizes] +
+        # Decode shapes of Deepseek-R1-Distill-Qwen-14B
+        [[z, 40, 8, 1, 1024 + 64, 128, 128, fa_kernel_mode] for z in batch_sizes] +
+        # Decode shapes of Deepseek-v3
         [
-            # There is an known issue in IGC for kernel with extreme register pressure.
-            # Enable this case later with new IGC.
-            # RuntimeError: ZE_RESULT_ERROR_INVALID_KERNEL_NAME
-            # [z, 128, 1, 1, 1024, 64, 512, fa_kernel_mode] for z in batch_sizes
-        ] +
-        # Decode shapes of Deepseek-v3 (Rope)
-        [],
+            # [z, 128, 1, 1, 1024 + 64, 576, 512, fa_kernel_mode] for z in batch_sizes
+        ],
         line_arg='provider',
         line_vals=['triton', 'torch'],
         line_names=['Triton', 'Torch'],
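
Each x_vals row is a shape tuple (Z, H_q, H_kv, N_CTX_q, N_CTX_kv, D_HEAD_qk, D_HEAD_v, MODE): batch size, query heads, key/value heads, query and key/value sequence lengths, query/key head dimension, value head dimension, and the fa_kernel_mode flag from the surrounding file. The sketch below, which is not part of this commit, shows how one such row (the Llama-3.1-8B prefill shape) could be exercised against torch's FlexAttention API with the causal mask_mod named in the hunk header; the batch size, tensor names, dtype, device choice, and torch.compile step are illustrative assumptions, not taken from the benchmark file.

import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention


# Causal mask_mod, matching the signature shown in the hunk header above.
def causal_mask(_, __, q_idx, kv_idx):
    return q_idx >= kv_idx


# One GQA row from x_vals: Llama-3.1-8B prefill, with an illustrative batch size Z=4.
Z, H_q, H_kv, N_CTX_q, N_CTX_kv, D_HEAD_qk, D_HEAD_v = 4, 32, 8, 1024, 1024, 128, 128

device = "cuda" if torch.cuda.is_available() else "cpu"       # assumption; the suite itself targets Intel XPU
dtype = torch.float16 if device == "cuda" else torch.float32  # assumption; dtype is not visible in this hunk

q = torch.randn(Z, H_q, N_CTX_q, D_HEAD_qk, device=device, dtype=dtype)
k = torch.randn(Z, H_kv, N_CTX_kv, D_HEAD_qk, device=device, dtype=dtype)
v = torch.randn(Z, H_kv, N_CTX_kv, D_HEAD_v, device=device, dtype=dtype)

# Build the sparse block mask once per (Q_LEN, KV_LEN) pair; B=None, H=None broadcasts it
# over batch and heads.
block_mask = create_block_mask(causal_mask, B=None, H=None, Q_LEN=N_CTX_q, KV_LEN=N_CTX_kv, device=device)

# H_q / H_kv = 4, so enable_gqa=True lets four query heads share each KV head.
out = torch.compile(flex_attention)(q, k, v, block_mask=block_mask, enable_gqa=True)
print(out.shape)  # (Z, H_q, N_CTX_q, D_HEAD_v)

On the shape changes themselves: the commit relabels the Phi3-mini entries, swaps the Qwen2-7B rows (28/4 heads) for Llama-3.2-3B, Qwen3-4B, and DeepSeek-R1-Distill-Qwen-14B rows, and merges the previously split Deepseek-v3 (Nope)/(Rope) append and decode entries into single rows with a 576-wide query/key head against a 512-wide value head (576 = 512 + 64).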
