@@ -86,53 +86,58 @@ def causal_mask(_, __, q_idx, kv_idx):
 @benchmark_suite.perf_report(
     benchmark_suite.Benchmark(
         x_names=['Z', 'H_q', 'H_kv', 'N_CTX_q', 'N_CTX_kv', 'D_HEAD_qk', 'D_HEAD_v', 'MODE'],
-        x_vals=
-        # Multi-head attention. H_q equals H_kv
-        # Prefill shapes of Phi3-mini-4k-instruct
-        [[z, 32, 32, 1024, 1024, 96, 96, fa_kernel_mode] for z in batch_sizes] +
-        # Prefill shapes of Qwen3-4B
-        [[z, 32, 32, 1024, 1024, 128, 128, fa_kernel_mode] for z in batch_sizes] +
-        # Prefill shapes of DeepSeek-v3
-        [[z, 128, 128, 1024, 1024, 192, 128, fa_kernel_mode] for z in batch_sizes] +
-        # Append shapes of Phi3-mini-4k-instruct
-        [[z, 32, 32, 512, 1024 + 128 + 512, 96, 96, fa_kernel_mode] for z in batch_sizes] +
-
-        ## Multi-query attention. H_kv equals 1.
-        # Append shapes of Deepseek-v3
-        [[z, 128, 1, 512, 1024 + 128 + 512, 576, 512, fa_kernel_mode] for z in batch_sizes] +
-
-        # Grouped-query attention. H_q / H_kv > 1
-        # Prefill shapes of Llama-3.1-8B
-        [[z, 32, 8, 1024, 1024, 128, 128, fa_kernel_mode] for z in batch_sizes] +
-        # Prefill shapes of meta-llama-Llama-3.2-3B
-        [[z, 24, 8, 1024, 1024, 128, 128, fa_kernel_mode] for z in batch_sizes] +
-        # Prefill shapes of Deepseek-R1-Distill-Qwen-14B
-        [[z, 40, 8, 1024, 1024, 128, 128, fa_kernel_mode] for z in batch_sizes] +
-        # Append shapes of Llama-3.1-8B
-        [[z, 32, 8, 512, 1024 + 128 + 512, 128, 128, fa_kernel_mode] for z in batch_sizes] +
-        # Append shapes of meta-llama-Llama-3.2-3B
-        [[z, 24, 8, 512, 1024 + 128 + 512, 128, 128, fa_kernel_mode] for z in batch_sizes] +
-        # Append shapes of Qwen3-4B
-        [[z, 32, 8, 512, 1024 + 128 + 512, 128, 128, fa_kernel_mode] for z in batch_sizes] +
-
-        # FlexDecoding configuration. N_CTX_q equals 1. N_CTX_kv >= 1k
-        # Decode shapes of Llama-3.1-8B
-        [[z, 32, 8, 1, 1024 + 64, 128, 128, fa_kernel_mode] for z in batch_sizes] +
-        # Decode shapes of meta-llama-Llama-3.2-3B
-        [[z, 24, 8, 1, 1024 + 64, 128, 128, fa_kernel_mode] for z in batch_sizes] +
-        # Decode shapes of Phi3-mini-4k-instruct
-        [
-            # acc = acc.reshape(G, BLOCK_M_PER_HQ, V_HEAD_DIM)
-            # ValueError: Shape element 2 must be a power of 2
-            # [z, 32, 32, 1, 1024 + 64, 96, 96, fa_kernel_mode] for z in batch_sizes
-        ] +
-        # Decode shapes of Qwen3-4B
-        [[z, 32, 8, 1, 1024 + 64, 128, 128, fa_kernel_mode] for z in batch_sizes] +
-        # Decode shapes of Deepseek-R1-Distill-Qwen-14B
-        [[z, 40, 8, 1, 1024 + 64, 128, 128, fa_kernel_mode] for z in batch_sizes] +
-        # Decode shapes of Deepseek-v3
-        [
-            # [z, 128, 1, 1, 1024 + 64, 576, 512, fa_kernel_mode] for z in batch_sizes
+        x_vals=[
+            x_val for x_val in
+            # Multi-head attention. H_q equals H_kv
+            # Prefill shapes of Phi3-mini-4k-instruct
+            [[z, 32, 32, 1024, 1024, 96, 96, fa_kernel_mode] for z in batch_sizes] +
+            # Prefill shapes of Qwen3-4B
+            [[z, 32, 32, 1024, 1024, 128, 128, fa_kernel_mode] for z in batch_sizes] +
+            # Prefill shapes of DeepSeek-v3
+            [[z, 128, 128, 1024, 1024, 192, 128, fa_kernel_mode] for z in batch_sizes] +
+            # Append shapes of Phi3-mini-4k-instruct
+            [[z, 32, 32, 512, 1024 + 128 + 512, 96, 96, fa_kernel_mode] for z in batch_sizes] +
+
+            # Multi-query attention. H_kv equals 1.
+            # Append shapes of Deepseek-v3
+            ([[z, 128, 1, 512, 1024 + 128 + 512, 576, 512, fa_kernel_mode]
+              for z in batch_sizes] if fa_kernel_mode != 'bwd' else []) +
+
+            # Grouped-query attention. H_q / H_kv > 1
+            # Prefill shapes of Llama-3.1-8B
+            [[z, 32, 8, 1024, 1024, 128, 128, fa_kernel_mode] for z in batch_sizes] +
+            # Prefill shapes of meta-llama-Llama-3.2-3B
+            [[z, 24, 8, 1024, 1024, 128, 128, fa_kernel_mode] for z in batch_sizes] +
+            # Prefill shapes of Deepseek-R1-Distill-Qwen-14B
+            [[z, 40, 8, 1024, 1024, 128, 128, fa_kernel_mode] for z in batch_sizes] +
+            # Append shapes of Llama-3.1-8B
+            [[z, 32, 8, 512, 1024 + 128 + 512, 128, 128, fa_kernel_mode] for z in batch_sizes] +
+            # Append shapes of meta-llama-Llama-3.2-3B
+            [[z, 24, 8, 512, 1024 + 128 + 512, 128, 128, fa_kernel_mode] for z in batch_sizes] +
+            # Append shapes of Qwen3-4B
+            [[z, 32, 8, 512, 1024 + 128 + 512, 128, 128, fa_kernel_mode] for z in batch_sizes] +
+
+            # FlexDecoding configuration. N_CTX_q equals 1. N_CTX_kv >= 1k
+            # Decode shapes of Llama-3.1-8B
+            [[z, 32, 8, 1, 1024 + 64, 128, 128, fa_kernel_mode] for z in batch_sizes] +
+            # Decode shapes of meta-llama-Llama-3.2-3B
+            [[z, 24, 8, 1, 1024 + 64, 128, 128, fa_kernel_mode] for z in batch_sizes] +
+            # Decode shapes of Phi3-mini-4k-instruct
+            [
+                # acc = acc.reshape(G, BLOCK_M_PER_HQ, V_HEAD_DIM)
+                # ValueError: Shape element 2 must be a power of 2
+                # [z, 32, 32, 1, 1024 + 64, 96, 96, fa_kernel_mode] for z in batch_sizes
+            ] +
+            # Decode shapes of Qwen3-4B
+            [[z, 32, 8, 1, 1024 + 64, 128, 128, fa_kernel_mode] for z in batch_sizes] +
+            # Decode shapes of Deepseek-R1-Distill-Qwen-14B
+            [[z, 40, 8, 1, 1024 + 64, 128, 128, fa_kernel_mode] for z in batch_sizes] +
+            # Decode shapes of Deepseek-v3
+            [
+                # [z, 128, 1, 1, 1024 + 64, 576, 512, fa_kernel_mode] for z in batch_sizes
+            ]
+            # FIXME: Reenable when PyTorch fixes config in https://github.com/pytorch/pytorch/blob/main/torch/_inductor/template_heuristics/triton.py#L1509
+            if x_val[-1] == 'fwd' or x_val[5] != 128
         ],
         line_arg='provider',
         line_vals=['triton', 'torch'],
@@ -143,6 +148,7 @@ def causal_mask(_, __, q_idx, kv_idx):
         args={},
     ))
 def benchmark(Z, H_q, H_kv, N_CTX_q, N_CTX_kv, D_HEAD_qk, D_HEAD_v, MODE, provider):
+    print(f'Running benchmark for Z={Z}, H_q={H_q}, H_kv={H_kv}, N_CTX_q={N_CTX_q}, N_CTX_kv={N_CTX_kv}, ')
     # Maximum across torch=200, triton=600
     do_bench = benchmark_suite.get_do_bench(n_warmup=600, n_repeat=10, quantiles=[0.5, 0.0, 1.0])
     if MODE not in ('fwd', 'bwd'):
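
Note: a minimal standalone sketch (not part of the diff) of the filtering pattern the new x_vals comprehension relies on. Each x_val follows the x_names order above, so index 5 is D_HEAD_qk and the last element is MODE; the FIXME guard drops 'bwd' configurations with D_HEAD_qk == 128 until the linked PyTorch heuristic is fixed. The batch_sizes and fa_kernel_mode values below are assumed purely for illustration.

# Hypothetical, self-contained illustration of the x_vals filter; values are assumed.
batch_sizes = [1, 16]      # assumed for illustration
fa_kernel_mode = 'bwd'     # assumed for illustration

candidates = (
    # D_HEAD_qk == 128 shapes: filtered out while MODE is 'bwd'
    [[z, 32, 8, 1024, 1024, 128, 128, fa_kernel_mode] for z in batch_sizes] +
    # D_HEAD_qk == 96 shapes: always kept
    [[z, 32, 32, 1024, 1024, 96, 96, fa_kernel_mode] for z in batch_sizes])

kept = [x_val for x_val in candidates if x_val[-1] == 'fwd' or x_val[5] != 128]
print(kept)  # only the D_HEAD_qk == 96 shapes remain when fa_kernel_mode == 'bwd'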