@@ -159,27 +159,29 @@ static constexpr __device__ int get_mmq_y_device() {
// Q6_K entry returned by mmq_get_dp4a_tile_x_sizes. Expands to a tile_x_sizes
// initializer whose three values are computed from `mmq_y`, so an identifier
// named `mmq_y` must be in scope wherever this macro is expanded.
#define MMQ_DP4A_TXS_Q6_K tile_x_sizes{      \
    mmq_y*WARP_SIZE*2     + mmq_y,           \
    mmq_y*WARP_SIZE/QI6_K + mmq_y/QI6_K,     \
    mmq_y*WARP_SIZE/8     + mmq_y/8}
160160
// Picks the tile_x_sizes for a quantization type.
//
// type:  quantization type to look up.
// mmq_y: not used directly in this body; the MMQ_DP4A_TXS_* macros are expanded
//        here and (at least MMQ_DP4A_TXS_Q6_K) reference an identifier named
//        `mmq_y`, which this parameter provides.
//
// Returns the matching MMQ_DP4A_TXS_* value, or tile_x_sizes{0, 0, 0} for any
// type that is not listed.
static constexpr __host__ __device__ tile_x_sizes mmq_get_dp4a_tile_x_sizes(ggml_type type, int mmq_y) {
    switch (type) {
        // Types with their own dedicated entry.
        case GGML_TYPE_Q4_0:
            return MMQ_DP4A_TXS_Q4_0;
        case GGML_TYPE_Q4_1:
            return MMQ_DP4A_TXS_Q4_1;
        case GGML_TYPE_Q2_K:
            return MMQ_DP4A_TXS_Q2_K;
        case GGML_TYPE_Q3_K:
            return MMQ_DP4A_TXS_Q3_K;
        case GGML_TYPE_Q4_K:
            return MMQ_DP4A_TXS_Q4_K;
        case GGML_TYPE_Q5_K:
            return MMQ_DP4A_TXS_Q5_K;
        case GGML_TYPE_Q6_K:
            return MMQ_DP4A_TXS_Q6_K;

        // Types that reuse the Q8_1 entry.
        case GGML_TYPE_Q5_1:
            return MMQ_DP4A_TXS_Q8_1;

        // Types that reuse the Q8_0 entry.
        case GGML_TYPE_Q5_0:
        case GGML_TYPE_Q6_0:
        case GGML_TYPE_Q8_0:
        case GGML_TYPE_TQ2_0:
        case GGML_TYPE_IQ2_XXS:
        case GGML_TYPE_IQ3_XXS:
        case GGML_TYPE_IQ3_S:
        case GGML_TYPE_IQ1_S:
        case GGML_TYPE_IQ4_XS:
        case GGML_TYPE_IQ4_NL:
            return MMQ_DP4A_TXS_Q8_0;

        // Types that reuse the Q8_0_16 entry.
        case GGML_TYPE_IQ2_XS:
        case GGML_TYPE_IQ2_S:
            return MMQ_DP4A_TXS_Q8_0_16;

        // Anything else: all-zero sizes.
        default:
            return tile_x_sizes{0, 0, 0};
    }
}
184186
// Value returned by mmq_get_mma_tile_x_k for every type mapped to the Q8_0 entry.
// NOTE(review): the trailing +4 presumably provides the padding that the nearby
// static_asserts check (`% 8 == 4`) for the other MMQ_MMA_TILE_X_K_* values — confirm.
#define MMQ_MMA_TILE_X_K_Q8_0 (2*WARP_SIZE + 2*WARP_SIZE/QI8_0 + 4)
@@ -195,27 +197,29 @@ static_assert(MMQ_MMA_TILE_X_K_Q3_K % 8 == 4, "Wrong padding.");
// Compile-time check that the Q6_K value is congruent to 4 modulo 8.
// The message matches the wording of the sibling Q3_K assertion.
static_assert(MMQ_MMA_TILE_X_K_Q6_K % 8 == 4, "Wrong padding.");
196198
// Maps a quantization type to its MMQ_MMA_TILE_X_K_* constant.
//
// type: quantization type to look up.
//
// Returns the matching constant, or 0 for any type that is not listed.
static constexpr __host__ __device__ int mmq_get_mma_tile_x_k(ggml_type type) {
    int tile_x_k = 0; // result for unlisted types

    switch (type) {
        // Types that use the Q8_0 value.
        case GGML_TYPE_Q4_0:
        case GGML_TYPE_Q5_0:
        case GGML_TYPE_Q6_0:
        case GGML_TYPE_Q8_0:
        case GGML_TYPE_TQ2_0:
        case GGML_TYPE_IQ2_XXS:
        case GGML_TYPE_IQ3_XXS:
        case GGML_TYPE_IQ3_S:
        case GGML_TYPE_IQ1_S:
        case GGML_TYPE_IQ4_XS:
        case GGML_TYPE_IQ4_NL:
            tile_x_k = MMQ_MMA_TILE_X_K_Q8_0;
            break;

        // Types that use the Q8_1 value.
        case GGML_TYPE_Q4_1:
        case GGML_TYPE_Q5_1:
        case GGML_TYPE_Q4_K:
        case GGML_TYPE_Q5_K:
            tile_x_k = MMQ_MMA_TILE_X_K_Q8_1;
            break;

        case GGML_TYPE_Q2_K:
            tile_x_k = MMQ_MMA_TILE_X_K_Q2_K;
            break;

        // Types that use the Q3_K value.
        case GGML_TYPE_Q3_K:
        case GGML_TYPE_IQ2_XS:
        case GGML_TYPE_IQ2_S:
            tile_x_k = MMQ_MMA_TILE_X_K_Q3_K;
            break;

        case GGML_TYPE_Q6_K:
            tile_x_k = MMQ_MMA_TILE_X_K_Q6_K;
            break;

        default:
            break;
    }

    return tile_x_k;
}
220224
// WARP_SIZE plus WARP_SIZE/QI8_1 (integer division).
// NOTE(review): by its name this looks like the per-row element count of the
// y tile — confirm against the code that uses it.
#define MMQ_TILE_Y_K (WARP_SIZE + WARP_SIZE/QI8_1)
0 commit comments