@@ -156,27 +156,25 @@ static constexpr __device__ int get_mmq_y_device() {
156156#define MMQ_DP4A_TXS_Q6_K tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE/QI6_K + mmq_y/QI6_K, mmq_y*WARP_SIZE/8 + mmq_y/8 }
157157
158158static constexpr __host__ __device__ tile_x_sizes mmq_get_dp4a_tile_x_sizes (ggml_type type, int mmq_y) {
159- switch (type) {
160- case GGML_TYPE_Q4_0: return MMQ_DP4A_TXS_Q4_0;
161- case GGML_TYPE_Q4_1: return MMQ_DP4A_TXS_Q4_1;
162- case GGML_TYPE_Q5_0: return MMQ_DP4A_TXS_Q8_0;
163- case GGML_TYPE_Q5_1: return MMQ_DP4A_TXS_Q8_1;
164- case GGML_TYPE_Q8_0: return MMQ_DP4A_TXS_Q8_0;
165- case GGML_TYPE_Q2_K: return MMQ_DP4A_TXS_Q2_K;
166- case GGML_TYPE_Q3_K: return MMQ_DP4A_TXS_Q3_K;
167- case GGML_TYPE_Q4_K: return MMQ_DP4A_TXS_Q4_K;
168- case GGML_TYPE_Q5_K: return MMQ_DP4A_TXS_Q5_K;
169- case GGML_TYPE_Q6_K: return MMQ_DP4A_TXS_Q6_K;
170- case GGML_TYPE_IQ2_XXS: return MMQ_DP4A_TXS_Q8_0;
171- case GGML_TYPE_IQ2_XS: return MMQ_DP4A_TXS_Q8_0_16;
172- case GGML_TYPE_IQ2_S: return MMQ_DP4A_TXS_Q8_0_16;
173- case GGML_TYPE_IQ3_XXS: return MMQ_DP4A_TXS_Q8_0;
174- case GGML_TYPE_IQ3_S: return MMQ_DP4A_TXS_Q8_0;
175- case GGML_TYPE_IQ1_S: return MMQ_DP4A_TXS_Q8_0;
176- case GGML_TYPE_IQ4_XS: return MMQ_DP4A_TXS_Q8_0;
177- case GGML_TYPE_IQ4_NL: return MMQ_DP4A_TXS_Q8_0;
178- default : return tile_x_sizes{0 , 0 , 0 };
179- }
159+ return type == GGML_TYPE_Q4_0 ? MMQ_DP4A_TXS_Q4_0 :
160+ type == GGML_TYPE_Q4_1 ? MMQ_DP4A_TXS_Q4_1 :
161+ type == GGML_TYPE_Q5_0 ? MMQ_DP4A_TXS_Q8_0 :
162+ type == GGML_TYPE_Q5_1 ? MMQ_DP4A_TXS_Q8_1 :
163+ type == GGML_TYPE_Q8_0 ? MMQ_DP4A_TXS_Q8_0 :
164+ type == GGML_TYPE_Q2_K ? MMQ_DP4A_TXS_Q2_K :
165+ type == GGML_TYPE_Q3_K ? MMQ_DP4A_TXS_Q3_K :
166+ type == GGML_TYPE_Q4_K ? MMQ_DP4A_TXS_Q4_K :
167+ type == GGML_TYPE_Q5_K ? MMQ_DP4A_TXS_Q5_K :
168+ type == GGML_TYPE_Q6_K ? MMQ_DP4A_TXS_Q6_K :
169+ type == GGML_TYPE_IQ2_XXS ? MMQ_DP4A_TXS_Q8_0 :
170+ type == GGML_TYPE_IQ2_XS ? MMQ_DP4A_TXS_Q8_0_16 :
171+ type == GGML_TYPE_IQ2_S ? MMQ_DP4A_TXS_Q8_0_16 :
172+ type == GGML_TYPE_IQ3_XXS ? MMQ_DP4A_TXS_Q8_0 :
173+ type == GGML_TYPE_IQ3_S ? MMQ_DP4A_TXS_Q8_0 :
174+ type == GGML_TYPE_IQ1_S ? MMQ_DP4A_TXS_Q8_0 :
175+ type == GGML_TYPE_IQ4_XS ? MMQ_DP4A_TXS_Q8_0 :
176+ type == GGML_TYPE_IQ4_NL ? MMQ_DP4A_TXS_Q8_0 :
177+ tile_x_sizes{0 , 0 , 0 };
180178}
181179
182180#define MMQ_MMA_TILE_X_K_Q8_0 (2 *WARP_SIZE + 2 *WARP_SIZE/QI8_0 + 4 )
@@ -192,27 +190,25 @@ static_assert(MMQ_MMA_TILE_X_K_Q3_K % 8 == 4, "Wrong padding.");
192190static_assert (MMQ_MMA_TILE_X_K_Q6_K % 8 == 4 , " Wrong padding." );
193191
194192static constexpr __host__ __device__ int mmq_get_mma_tile_x_k (ggml_type type) {
195- switch (type) {
196- case GGML_TYPE_Q4_0: return MMQ_MMA_TILE_X_K_Q8_0;
197- case GGML_TYPE_Q4_1: return MMQ_MMA_TILE_X_K_Q8_1;
198- case GGML_TYPE_Q5_0: return MMQ_MMA_TILE_X_K_Q8_0;
199- case GGML_TYPE_Q5_1: return MMQ_MMA_TILE_X_K_Q8_1;
200- case GGML_TYPE_Q8_0: return MMQ_MMA_TILE_X_K_Q8_0;
201- case GGML_TYPE_Q2_K: return MMQ_MMA_TILE_X_K_Q2_K;
202- case GGML_TYPE_Q3_K: return MMQ_MMA_TILE_X_K_Q3_K;
203- case GGML_TYPE_Q4_K: return MMQ_MMA_TILE_X_K_Q8_1;
204- case GGML_TYPE_Q5_K: return MMQ_MMA_TILE_X_K_Q8_1;
205- case GGML_TYPE_Q6_K: return MMQ_MMA_TILE_X_K_Q6_K;
206- case GGML_TYPE_IQ2_XXS: return MMQ_MMA_TILE_X_K_Q8_0;
207- case GGML_TYPE_IQ2_XS: return MMQ_MMA_TILE_X_K_Q3_K;
208- case GGML_TYPE_IQ2_S: return MMQ_MMA_TILE_X_K_Q3_K;
209- case GGML_TYPE_IQ3_XXS: return MMQ_MMA_TILE_X_K_Q8_0;
210- case GGML_TYPE_IQ3_S: return MMQ_MMA_TILE_X_K_Q8_0;
211- case GGML_TYPE_IQ1_S: return MMQ_MMA_TILE_X_K_Q8_0;
212- case GGML_TYPE_IQ4_XS: return MMQ_MMA_TILE_X_K_Q8_0;
213- case GGML_TYPE_IQ4_NL: return MMQ_MMA_TILE_X_K_Q8_0;
214- default : return 0 ;
215- }
193+ return type == GGML_TYPE_Q4_0 ? MMQ_MMA_TILE_X_K_Q8_0 :
194+ type == GGML_TYPE_Q4_1 ? MMQ_MMA_TILE_X_K_Q8_1 :
195+ type == GGML_TYPE_Q5_0 ? MMQ_MMA_TILE_X_K_Q8_0 :
196+ type == GGML_TYPE_Q5_1 ? MMQ_MMA_TILE_X_K_Q8_1 :
197+ type == GGML_TYPE_Q8_0 ? MMQ_MMA_TILE_X_K_Q8_0 :
198+ type == GGML_TYPE_Q2_K ? MMQ_MMA_TILE_X_K_Q2_K :
199+ type == GGML_TYPE_Q3_K ? MMQ_MMA_TILE_X_K_Q3_K :
200+ type == GGML_TYPE_Q4_K ? MMQ_MMA_TILE_X_K_Q8_1 :
201+ type == GGML_TYPE_Q5_K ? MMQ_MMA_TILE_X_K_Q8_1 :
202+ type == GGML_TYPE_Q6_K ? MMQ_MMA_TILE_X_K_Q6_K :
203+ type == GGML_TYPE_IQ2_XXS ? MMQ_MMA_TILE_X_K_Q8_0 :
204+ type == GGML_TYPE_IQ2_XS ? MMQ_MMA_TILE_X_K_Q3_K :
205+ type == GGML_TYPE_IQ2_S ? MMQ_MMA_TILE_X_K_Q3_K :
206+ type == GGML_TYPE_IQ3_XXS ? MMQ_MMA_TILE_X_K_Q8_0 :
207+ type == GGML_TYPE_IQ3_S ? MMQ_MMA_TILE_X_K_Q8_0 :
208+ type == GGML_TYPE_IQ1_S ? MMQ_MMA_TILE_X_K_Q8_0 :
209+ type == GGML_TYPE_IQ4_XS ? MMQ_MMA_TILE_X_K_Q8_0 :
210+ type == GGML_TYPE_IQ4_NL ? MMQ_MMA_TILE_X_K_Q8_0 :
211+ 0 ;
216212}
217213
218214#define MMQ_TILE_Y_K (WARP_SIZE + WARP_SIZE/QI8_1)
0 commit comments