@@ -82,52 +82,52 @@ def causal_mask(_, __, q_idx, kv_idx):
         x_names=['Z', 'H_q', 'H_kv', 'N_CTX_q', 'N_CTX_kv', 'D_HEAD_qk', 'D_HEAD_v', 'MODE'],
         x_vals=
         # Multi-head attention. H_q equals H_kv
-        # Prefill shapes of Phi3-mini-3.8B
+        # Prefill shapes of Phi3-mini-4k-instruct
         [[z, 32, 32, 1024, 1024, 96, 96, fa_kernel_mode] for z in batch_sizes] +
-        # Prefill shapes of Deepseek-v3
+        # Prefill shapes of Qwen3-4B
+        [[z, 32, 32, 1024, 1024, 128, 128, fa_kernel_mode] for z in batch_sizes] +
+        # Prefill shapes of DeepSeek-v3
         [[z, 128, 128, 1024, 1024, 192, 128, fa_kernel_mode] for z in batch_sizes] +
-        # Append shapes of Phi3-mini-3.8B
+        # Append shapes of Phi3-mini-4k-instruct
         [[z, 32, 32, 512, 1024 + 128 + 512, 96, 96, fa_kernel_mode] for z in batch_sizes] +
 
-        # Multi-query attention. H_kv equals 1.
-        # Append shapes of Deepseek-v3 (Nope)
-        [[z, 128, 1, 512, 1024 + 128 + 512, 64, 512, fa_kernel_mode] for z in batch_sizes] +
-        # Append shapes of Deepseek-v3 (Rope)
-        [] +
+        ## Multi-query attention. H_kv equals 1.
+        # Append shapes of Deepseek-v3
+        [[z, 128, 1, 512, 1024 + 128 + 512, 576, 512, fa_kernel_mode] for z in batch_sizes] +
 
         # Grouped-query attention. H_q / H_kv > 1
         # Prefill shapes of Llama-3.1-8B
         [[z, 32, 8, 1024, 1024, 128, 128, fa_kernel_mode] for z in batch_sizes] +
-        # Prefill shapes of Qwen2-7B
-        [[z, 28, 4, 1024, 1024, 128, 128, fa_kernel_mode] for z in batch_sizes] +
+        # Prefill shapes of meta-llama-Llama-3.2-3B
+        [[z, 24, 8, 1024, 1024, 128, 128, fa_kernel_mode] for z in batch_sizes] +
+        # Prefill shapes of Deepseek-R1-Distill-Qwen-14B
+        [[z, 40, 8, 1024, 1024, 128, 128, fa_kernel_mode] for z in batch_sizes] +
         # Append shapes of Llama-3.1-8B
         [[z, 32, 8, 512, 1024 + 128 + 512, 128, 128, fa_kernel_mode] for z in batch_sizes] +
-        # Append shapes of Qwen2-7B
-        [[z, 28, 4, 512, 1024 + 128 + 512, 128, 128, fa_kernel_mode] for z in batch_sizes] +
+        # Append shapes of meta-llama-Llama-3.2-3B
+        [[z, 24, 8, 512, 1024 + 128 + 512, 128, 128, fa_kernel_mode] for z in batch_sizes] +
+        # Append shapes of Qwen3-4B
+        [[z, 32, 8, 512, 1024 + 128 + 512, 128, 128, fa_kernel_mode] for z in batch_sizes] +
 
         # FlexDecoding configuration. N_CTX_q equals 1. N_CTX_kv >= 1k
         # Decode shapes of Llama-3.1-8B
         [[z, 32, 8, 1, 1024 + 64, 128, 128, fa_kernel_mode] for z in batch_sizes] +
-        # Decode shapes of Phi3-mini-3.8B
+        # Decode shapes of meta-llama-Llama-3.2-3B
+        [[z, 24, 8, 1, 1024 + 64, 128, 128, fa_kernel_mode] for z in batch_sizes] +
+        # Decode shapes of Phi3-mini-4k-instruct
         [
             # acc = acc.reshape(G, BLOCK_M_PER_HQ, V_HEAD_DIM)
             # ValueError: Shape element 2 must be a power of 2
             # [z, 32, 32, 1, 1024 + 64, 96, 96, fa_kernel_mode] for z in batch_sizes
         ] +
-        # Decode shapes of Qwen2-7B
-        [
-            # torch._inductor.exc.InductorError: LoweringException: ValueError: Number of shared query heads sharing the same KV head must be power of 2.
-            # [z, 28, 4, 1, 1024 + 64, 128, 128, fa_kernel_mode] for z in batch_sizes
-        ] +
-        # Decode shapes of Deepseek-v3 (Nope)
+        # Decode shapes of Qwen3-4B
+        [[z, 32, 8, 1, 1024 + 64, 128, 128, fa_kernel_mode] for z in batch_sizes] +
+        # Decode shapes of Deepseek-R1-Distill-Qwen-14B
+        [[z, 40, 8, 1, 1024 + 64, 128, 128, fa_kernel_mode] for z in batch_sizes] +
+        # Decode shapes of Deepseek-v3
         [
-            # There is a known issue in IGC for kernels with extreme register pressure.
-            # Enable this case later with a new IGC.
-            # RuntimeError: ZE_RESULT_ERROR_INVALID_KERNEL_NAME
-            # [z, 128, 1, 1, 1024, 64, 512, fa_kernel_mode] for z in batch_sizes
-        ] +
-        # Decode shapes of Deepseek-v3 (Rope)
-        [],
+            # [z, 128, 1, 1, 1024 + 64, 576, 512, fa_kernel_mode] for z in batch_sizes
+        ],
         line_arg='provider',
         line_vals=['triton', 'torch'],
         line_names=['Triton', 'Torch'],
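The hunk header shows these `x_vals` feed a flex_attention benchmark built around a `causal_mask` mask_mod. Below is a minimal sketch of that path, assuming PyTorch >= 2.5 and a CUDA device; the tensor sizes are illustrative, not the PR's code.

```python
# Minimal sketch of the flex_attention path these shapes exercise; not the
# PR's code. Assumes PyTorch >= 2.5 and a CUDA device.
import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention

def causal_mask(_, __, q_idx, kv_idx):
    # Matches the mask_mod named in the hunk header: query position q_idx
    # may only attend to key positions at or before it.
    return q_idx >= kv_idx

# A grouped-query row like the Llama-3.1-8B prefill shape: H_q=32, H_kv=8, D=128.
Z, H_q, H_kv, N_CTX, D = 1, 32, 8, 1024, 128
q = torch.randn(Z, H_q, N_CTX, D, device='cuda', dtype=torch.float16)
k = torch.randn(Z, H_kv, N_CTX, D, device='cuda', dtype=torch.float16)
v = torch.randn(Z, H_kv, N_CTX, D, device='cuda', dtype=torch.float16)

# Precompute the sparse block mask once; B=None / H=None broadcast it over
# batch and heads.
block_mask = create_block_mask(causal_mask, B=None, H=None, Q_LEN=N_CTX,
                               KV_LEN=N_CTX, device='cuda')
# enable_gqa=True lets H_q / H_kv query heads share each KV head.
out = flex_attention(q, k, v, block_mask=block_mask, enable_gqa=True)
print(out.shape)  # torch.Size([1, 32, 1024, 128])
```

This GQA grouping is also the constraint quoted in the disabled decode case that this hunk removes: the number of query heads sharing a KV head must be a power of two, which Qwen2-7B's 28/4 = 7 violates, while the newly added Qwen3-4B and Deepseek-R1-Distill-Qwen-14B rows (32/8 and 40/8) satisfy it.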
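For orientation, here is a runnable sketch of how an `x_vals` table like the one in this hunk drives `triton.testing.perf_report`. It is not the PR's file: `batch_sizes` and `fa_kernel_mode` are placeholders mirroring names in the diff, and the timed body is a stand-in for the compiled flex_attention kernel the real benchmark dispatches to.

```python
# Runnable sketch, not the PR's file: one x_vals row is unpacked by name into
# the decorated function, once per provider in line_vals.
import torch
import triton
import triton.testing

batch_sizes = [1, 4]      # assumption: the real file sweeps a batch list
fa_kernel_mode = 'fwd'    # assumption: forward-only benchmarking mode

# Prefer Intel XPU when present (this repo targets it), else CUDA.
device = 'xpu' if getattr(torch, 'xpu', None) and torch.xpu.is_available() else 'cuda'

@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=['Z', 'H_q', 'H_kv', 'N_CTX_q', 'N_CTX_kv', 'D_HEAD_qk', 'D_HEAD_v', 'MODE'],
        # One row per shape; here, only the Phi3-mini-4k-instruct prefill rows.
        x_vals=[[z, 32, 32, 1024, 1024, 96, 96, fa_kernel_mode] for z in batch_sizes],
        line_arg='provider',
        line_vals=['triton', 'torch'],
        line_names=['Triton', 'Torch'],
        styles=[('green', '-'), ('blue', '-')],
        ylabel='ms',
        plot_name='flex-attention-sketch',
        args={},
    ))
def benchmark(Z, H_q, H_kv, N_CTX_q, N_CTX_kv, D_HEAD_qk, D_HEAD_v, MODE, provider):
    dtype = torch.float16
    q = torch.randn(Z, H_q, N_CTX_q, D_HEAD_qk, device=device, dtype=dtype)
    k = torch.randn(Z, H_kv, N_CTX_kv, D_HEAD_qk, device=device, dtype=dtype)
    v = torch.randn(Z, H_kv, N_CTX_kv, D_HEAD_v, device=device, dtype=dtype)
    # Placeholder workload: the real benchmark runs a compiled flex_attention
    # kernel for 'triton' and a reference attention path for 'torch'.
    fn = lambda: torch.nn.functional.scaled_dot_product_attention(q, k, v, is_causal=True)
    return triton.testing.do_bench(fn)

if __name__ == '__main__':
    benchmark.run(show_plots=False, print_data=True)
```

Because each `x_vals` row is bound to the function parameters by position in `x_names`, adding or retiring a model is a one-line list-comprehension change, which is exactly the shape of the edits in this hunk.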