@@ -160,26 +160,27 @@ static constexpr __device__ int get_mmq_y_device() {
160160#define MMQ_DP4A_TXS_Q6_K tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE/QI6_K + mmq_y/QI6_K, mmq_y*WARP_SIZE/8 + mmq_y/8 }
161161
162162static constexpr __host__ __device__ tile_x_sizes mmq_get_dp4a_tile_x_sizes (ggml_type type, int mmq_y) {
163- return type == GGML_TYPE_Q4_0 ? MMQ_DP4A_TXS_Q4_0 :
164- type == GGML_TYPE_Q4_1 ? MMQ_DP4A_TXS_Q4_1 :
165- type == GGML_TYPE_Q5_0 ? MMQ_DP4A_TXS_Q8_0 :
166- type == GGML_TYPE_Q5_1 ? MMQ_DP4A_TXS_Q8_1 :
167- type == GGML_TYPE_Q6_0 ? MMQ_DP4A_TXS_Q8_0 :
168- type == GGML_TYPE_Q8_0 ? MMQ_DP4A_TXS_Q8_0 :
169- type == GGML_TYPE_Q2_K ? MMQ_DP4A_TXS_Q2_K :
170- type == GGML_TYPE_Q3_K ? MMQ_DP4A_TXS_Q3_K :
171- type == GGML_TYPE_Q4_K ? MMQ_DP4A_TXS_Q4_K :
172- type == GGML_TYPE_Q5_K ? MMQ_DP4A_TXS_Q5_K :
173- type == GGML_TYPE_Q6_K ? MMQ_DP4A_TXS_Q6_K :
174- type == GGML_TYPE_IQ2_XXS ? MMQ_DP4A_TXS_Q8_0 :
175- type == GGML_TYPE_IQ2_XS ? MMQ_DP4A_TXS_Q8_0_16 :
176- type == GGML_TYPE_IQ2_S ? MMQ_DP4A_TXS_Q8_0_16 :
177- type == GGML_TYPE_IQ3_XXS ? MMQ_DP4A_TXS_Q8_0 :
178- type == GGML_TYPE_IQ3_S ? MMQ_DP4A_TXS_Q8_0 :
179- type == GGML_TYPE_IQ1_S ? MMQ_DP4A_TXS_Q8_0 :
180- type == GGML_TYPE_IQ4_XS ? MMQ_DP4A_TXS_Q8_0 :
181- type == GGML_TYPE_IQ4_NL ? MMQ_DP4A_TXS_Q8_0 :
182- tile_x_sizes{0 , 0 , 0 };
163+ switch (type) {
164+ case GGML_TYPE_Q4_1 : return MMQ_DP4A_TXS_Q4_1;
165+ case GGML_TYPE_Q5_0 : return MMQ_DP4A_TXS_Q8_0;
166+ case GGML_TYPE_Q5_1 : return MMQ_DP4A_TXS_Q8_1;
167+ case GGML_TYPE_Q6_0 : return MMQ_DP4A_TXS_Q8_0;
168+ case GGML_TYPE_Q8_0 : return MMQ_DP4A_TXS_Q8_0;
169+ case GGML_TYPE_Q2_K : return MMQ_DP4A_TXS_Q2_K;
170+ case GGML_TYPE_Q3_K : return MMQ_DP4A_TXS_Q3_K;
171+ case GGML_TYPE_Q4_K : return MMQ_DP4A_TXS_Q4_K;
172+ case GGML_TYPE_Q5_K : return MMQ_DP4A_TXS_Q5_K;
173+ case GGML_TYPE_Q6_K : return MMQ_DP4A_TXS_Q6_K;
174+ case GGML_TYPE_IQ2_XXS : return MMQ_DP4A_TXS_Q8_0;
175+ case GGML_TYPE_IQ2_XS : return MMQ_DP4A_TXS_Q8_0_16;
176+ case GGML_TYPE_IQ2_S : return MMQ_DP4A_TXS_Q8_0_16;
177+ case GGML_TYPE_IQ3_XXS : return MMQ_DP4A_TXS_Q8_0;
178+ case GGML_TYPE_IQ3_S : return MMQ_DP4A_TXS_Q8_0;
179+ case GGML_TYPE_IQ1_S : return MMQ_DP4A_TXS_Q8_0;
180+ case GGML_TYPE_IQ4_XS : return MMQ_DP4A_TXS_Q8_0;
181+ case GGML_TYPE_IQ4_NL : return MMQ_DP4A_TXS_Q8_0;
182+ default : return tile_x_sizes{0 , 0 , 0 };
183+ }
183184}
184185
185186#define MMQ_MMA_TILE_X_K_Q8_0 (2 *WARP_SIZE + 2 *WARP_SIZE/QI8_0 + 4 )
@@ -195,26 +196,28 @@ static_assert(MMQ_MMA_TILE_X_K_Q3_K % 8 == 4, "Wrong padding.");
195196static_assert (MMQ_MMA_TILE_X_K_Q6_K % 8 == 4 , " Wrong padding." );
196197
197198static constexpr __host__ __device__ int mmq_get_mma_tile_x_k (ggml_type type) {
198- return type == GGML_TYPE_Q4_0 ? MMQ_MMA_TILE_X_K_Q8_0 :
199- type == GGML_TYPE_Q4_1 ? MMQ_MMA_TILE_X_K_Q8_1 :
200- type == GGML_TYPE_Q5_0 ? MMQ_MMA_TILE_X_K_Q8_0 :
201- type == GGML_TYPE_Q5_1 ? MMQ_MMA_TILE_X_K_Q8_1 :
202- type == GGML_TYPE_Q6_0 ? MMQ_MMA_TILE_X_K_Q8_0 :
203- type == GGML_TYPE_Q8_0 ? MMQ_MMA_TILE_X_K_Q8_0 :
204- type == GGML_TYPE_Q2_K ? MMQ_MMA_TILE_X_K_Q2_K :
205- type == GGML_TYPE_Q3_K ? MMQ_MMA_TILE_X_K_Q3_K :
206- type == GGML_TYPE_Q4_K ? MMQ_MMA_TILE_X_K_Q8_1 :
207- type == GGML_TYPE_Q5_K ? MMQ_MMA_TILE_X_K_Q8_1 :
208- type == GGML_TYPE_Q6_K ? MMQ_MMA_TILE_X_K_Q6_K :
209- type == GGML_TYPE_IQ2_XXS ? MMQ_MMA_TILE_X_K_Q8_0 :
210- type == GGML_TYPE_IQ2_XS ? MMQ_MMA_TILE_X_K_Q3_K :
211- type == GGML_TYPE_IQ2_S ? MMQ_MMA_TILE_X_K_Q3_K :
212- type == GGML_TYPE_IQ3_XXS ? MMQ_MMA_TILE_X_K_Q8_0 :
213- type == GGML_TYPE_IQ3_S ? MMQ_MMA_TILE_X_K_Q8_0 :
214- type == GGML_TYPE_IQ1_S ? MMQ_MMA_TILE_X_K_Q8_0 :
215- type == GGML_TYPE_IQ4_XS ? MMQ_MMA_TILE_X_K_Q8_0 :
216- type == GGML_TYPE_IQ4_NL ? MMQ_MMA_TILE_X_K_Q8_0 :
217- 0 ;
199+ switch (type) {
200+ case GGML_TYPE_Q4_0 : return MMQ_MMA_TILE_X_K_Q8_0;
201+ case GGML_TYPE_Q4_1 : return MMQ_MMA_TILE_X_K_Q8_1;
202+ case GGML_TYPE_Q5_0 : return MMQ_MMA_TILE_X_K_Q8_0;
203+ case GGML_TYPE_Q5_1 : return MMQ_MMA_TILE_X_K_Q8_1;
204+ case GGML_TYPE_Q6_0 : return MMQ_MMA_TILE_X_K_Q8_0;
205+ case GGML_TYPE_Q8_0 : return MMQ_MMA_TILE_X_K_Q8_0;
206+ case GGML_TYPE_Q2_K : return MMQ_MMA_TILE_X_K_Q2_K;
207+ case GGML_TYPE_Q3_K : return MMQ_MMA_TILE_X_K_Q3_K;
208+ case GGML_TYPE_Q4_K : return MMQ_MMA_TILE_X_K_Q8_1;
209+ case GGML_TYPE_Q5_K : return MMQ_MMA_TILE_X_K_Q8_1;
210+ case GGML_TYPE_Q6_K : return MMQ_MMA_TILE_X_K_Q6_K;
211+ case GGML_TYPE_IQ2_XXS : return MMQ_MMA_TILE_X_K_Q8_0;
212+ case GGML_TYPE_IQ2_XS : return MMQ_MMA_TILE_X_K_Q3_K;
213+ case GGML_TYPE_IQ2_S : return MMQ_MMA_TILE_X_K_Q3_K;
214+ case GGML_TYPE_IQ3_XXS : return MMQ_MMA_TILE_X_K_Q8_0;
215+ case GGML_TYPE_IQ3_S : return MMQ_MMA_TILE_X_K_Q8_0;
216+ case GGML_TYPE_IQ1_S : return MMQ_MMA_TILE_X_K_Q8_0;
217+ case GGML_TYPE_IQ4_XS : return MMQ_MMA_TILE_X_K_Q8_0;
218+ case GGML_TYPE_IQ4_NL : return MMQ_MMA_TILE_X_K_Q8_0;
219+ default : return 0 ;
220+ }
218221}
219222
220223#define MMQ_TILE_Y_K (WARP_SIZE + WARP_SIZE/QI8_1)
0 commit comments