@@ -86,58 +86,54 @@ def causal_mask(_, __, q_idx, kv_idx):
 @benchmark_suite.perf_report(
     benchmark_suite.Benchmark(
         x_names=['Z', 'H_q', 'H_kv', 'N_CTX_q', 'N_CTX_kv', 'D_HEAD_qk', 'D_HEAD_v', 'MODE'],
-        x_vals=[
-            x_val for x_val in
-            # Multi-head attention. H_q equals H_kv
-            # Prefill shapes of Phi3-mini-4k-instruct
-            [[z, 32, 32, 1024, 1024, 96, 96, fa_kernel_mode] for z in batch_sizes] +
-            # Prefill shapes of Qwen3-4B
-            [[z, 32, 32, 1024, 1024, 128, 128, fa_kernel_mode] for z in batch_sizes] +
-            # Prefill shapes of DeepSeek-v3
-            [[z, 128, 128, 1024, 1024, 192, 128, fa_kernel_mode] for z in batch_sizes] +
-            # Append shapes of Phi3-mini-4k-instruct
-            [[z, 32, 32, 512, 1024 + 128 + 512, 96, 96, fa_kernel_mode] for z in batch_sizes] +
-
-            # Multi-query attention. H_kv equals 1.
-            # Append shapes of Deepseek-v3
-            ([[z, 128, 1, 512, 1024 + 128 + 512, 576, 512, fa_kernel_mode]
-              for z in batch_sizes] if fa_kernel_mode != 'bwd' else []) +
-
-            # Grouped-query attention. H_q / H_kv > 1
-            # Prefill shapes of Llama-3.1-8B
-            [[z, 32, 8, 1024, 1024, 128, 128, fa_kernel_mode] for z in batch_sizes] +
-            # Prefill shapes of meta-llama-Llama-3.2-3B
-            [[z, 24, 8, 1024, 1024, 128, 128, fa_kernel_mode] for z in batch_sizes] +
-            # Prefill shapes of Deepseek-R1-Distill-Qwen-14B
-            [[z, 40, 8, 1024, 1024, 128, 128, fa_kernel_mode] for z in batch_sizes] +
-            # Append shapes of Llama-3.1-8B
-            [[z, 32, 8, 512, 1024 + 128 + 512, 128, 128, fa_kernel_mode] for z in batch_sizes] +
-            # Append shapes of meta-llama-Llama-3.2-3B
-            [[z, 24, 8, 512, 1024 + 128 + 512, 128, 128, fa_kernel_mode] for z in batch_sizes] +
-            # Append shapes of Qwen3-4B
-            [[z, 32, 8, 512, 1024 + 128 + 512, 128, 128, fa_kernel_mode] for z in batch_sizes] +
-
-            # FlexDecoding configuration. N_CTX_q equals 1. N_CTX_kv >= 1k
-            # Decode shapes of Llama-3.1-8B
-            [[z, 32, 8, 1, 1024 + 64, 128, 128, fa_kernel_mode] for z in batch_sizes] +
-            # Decode shapes of meta-llama-Llama-3.2-3B
-            [[z, 24, 8, 1, 1024 + 64, 128, 128, fa_kernel_mode] for z in batch_sizes] +
-            # Decode shapes of Phi3-mini-4k-instruct
-            [
-                # acc = acc.reshape(G, BLOCK_M_PER_HQ, V_HEAD_DIM)
-                # ValueError: Shape element 2 must be a power of 2
-                # [z, 32, 32, 1, 1024 + 64, 96, 96, fa_kernel_mode] for z in batch_sizes
-            ] +
-            # Decode shapes of Qwen3-4B
-            [[z, 32, 8, 1, 1024 + 64, 128, 128, fa_kernel_mode] for z in batch_sizes] +
-            # Decode shapes of Deepseek-R1-Distill-Qwen-14B
-            [[z, 40, 8, 1, 1024 + 64, 128, 128, fa_kernel_mode] for z in batch_sizes] +
-            # Decode shapes of Deepseek-v3
-            [
-                # [z, 128, 1, 1, 1024 + 64, 576, 512, fa_kernel_mode] for z in batch_sizes
-            ]
-            # FIXME: Reenable when PyTorch fixes config in https://github.com/pytorch/pytorch/blob/main/torch/_inductor/template_heuristics/triton.py#L1509
-            if x_val[-1] == 'fwd' or x_val[5] != 128
+        x_vals=
+        # Multi-head attention. H_q equals H_kv
+        # Prefill shapes of Phi3-mini-4k-instruct
+        [[z, 32, 32, 1024, 1024, 96, 96, fa_kernel_mode] for z in batch_sizes] +
+        # Prefill shapes of Qwen3-4B
+        [[z, 32, 32, 1024, 1024, 128, 128, fa_kernel_mode] for z in batch_sizes] +
+        # Prefill shapes of DeepSeek-v3
+        [[z, 128, 128, 1024, 1024, 192, 128, fa_kernel_mode] for z in batch_sizes] +
+        # Append shapes of Phi3-mini-4k-instruct
+        [[z, 32, 32, 512, 1024 + 128 + 512, 96, 96, fa_kernel_mode] for z in batch_sizes] +
+
+        # Multi-query attention. H_kv equals 1.
+        # Append shapes of Deepseek-v3
+        ([[z, 128, 1, 512, 1024 + 128 + 512, 576, 512, fa_kernel_mode]
+          for z in batch_sizes] if fa_kernel_mode != 'bwd' else []) +
+
+        # Grouped-query attention. H_q / H_kv > 1
+        # Prefill shapes of Llama-3.1-8B
+        [[z, 32, 8, 1024, 1024, 128, 128, fa_kernel_mode] for z in batch_sizes] +
+        # Prefill shapes of meta-llama-Llama-3.2-3B
+        [[z, 24, 8, 1024, 1024, 128, 128, fa_kernel_mode] for z in batch_sizes] +
+        # Prefill shapes of Deepseek-R1-Distill-Qwen-14B
+        [[z, 40, 8, 1024, 1024, 128, 128, fa_kernel_mode] for z in batch_sizes] +
+        # Append shapes of Llama-3.1-8B
+        [[z, 32, 8, 512, 1024 + 128 + 512, 128, 128, fa_kernel_mode] for z in batch_sizes] +
+        # Append shapes of meta-llama-Llama-3.2-3B
+        [[z, 24, 8, 512, 1024 + 128 + 512, 128, 128, fa_kernel_mode] for z in batch_sizes] +
+        # Append shapes of Qwen3-4B
+        [[z, 32, 8, 512, 1024 + 128 + 512, 128, 128, fa_kernel_mode] for z in batch_sizes] +
+
+        # FlexDecoding configuration. N_CTX_q equals 1. N_CTX_kv >= 1k
+        # Decode shapes of Llama-3.1-8B
+        [[z, 32, 8, 1, 1024 + 64, 128, 128, fa_kernel_mode] for z in batch_sizes] +
+        # Decode shapes of meta-llama-Llama-3.2-3B
+        [[z, 24, 8, 1, 1024 + 64, 128, 128, fa_kernel_mode] for z in batch_sizes] +
+        # Decode shapes of Phi3-mini-4k-instruct
+        [
+            # acc = acc.reshape(G, BLOCK_M_PER_HQ, V_HEAD_DIM)
+            # ValueError: Shape element 2 must be a power of 2
+            # [z, 32, 32, 1, 1024 + 64, 96, 96, fa_kernel_mode] for z in batch_sizes
+        ] +
+        # Decode shapes of Qwen3-4B
+        [[z, 32, 8, 1, 1024 + 64, 128, 128, fa_kernel_mode] for z in batch_sizes] +
+        # Decode shapes of Deepseek-R1-Distill-Qwen-14B
+        [[z, 40, 8, 1, 1024 + 64, 128, 128, fa_kernel_mode] for z in batch_sizes] +
+        # Decode shapes of Deepseek-v3
+        [
+            # [z, 128, 1, 1, 1024 + 64, 576, 512, fa_kernel_mode] for z in batch_sizes
         ],
         line_arg='provider',
         line_vals=['triton', 'torch'],
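For reference, each x_vals row is consumed in x_names order as (Z, H_q, H_kv, N_CTX_q, N_CTX_kv, D_HEAD_qk, D_HEAD_v, MODE): batch size, query and key/value head counts, query and key/value sequence lengths, query-key and value head dimensions, and the kernel mode ('fwd' or 'bwd'). The sketch below is a minimal illustration, not part of this patch, assuming the 'torch' provider maps a row onto PyTorch's flex_attention with a causal mask_mod like the one defined in this file; the chosen row (a Llama-3.1-8B grouped-query prefill shape), the device, and the dtype are placeholders.

import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention


def causal_mask(_, __, q_idx, kv_idx):
    # Standard causal mask_mod with the same signature as the benchmark's causal_mask;
    # batch and head indices are unused.
    return q_idx >= kv_idx


# Hypothetical row: Llama-3.1-8B prefill, grouped-query attention with H_q / H_kv = 32 / 8.
z, h_q, h_kv, n_ctx_q, n_ctx_kv, d_qk, d_v, mode = 4, 32, 8, 1024, 1024, 128, 128, 'fwd'
device, dtype = 'cuda', torch.float16  # placeholders; the suite may target another backend

q = torch.randn(z, h_q, n_ctx_q, d_qk, device=device, dtype=dtype, requires_grad=(mode == 'bwd'))
k = torch.randn(z, h_kv, n_ctx_kv, d_qk, device=device, dtype=dtype, requires_grad=(mode == 'bwd'))
v = torch.randn(z, h_kv, n_ctx_kv, d_v, device=device, dtype=dtype, requires_grad=(mode == 'bwd'))

# B=1, H=1 broadcasts the mask, since causal_mask ignores the batch and head indices.
block_mask = create_block_mask(causal_mask, 1, 1, n_ctx_q, n_ctx_kv, device=device)
out = flex_attention(q, k, v, block_mask=block_mask, enable_gqa=True)

if mode == 'bwd':
    out.backward(torch.ones_like(out))

enable_gqa=True is what allows H_kv to differ from H_q in this call; the multi-query rows (H_kv = 1) and the multi-head rows (H_kv = H_q) fit the same pattern with their respective head counts.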