@@ -155,25 +155,27 @@ static constexpr __device__ int get_mmq_y_device() {
155155#define MMQ_DP4A_TXS_Q6_K tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE/QI6_K + mmq_y/QI6_K, mmq_y*WARP_SIZE/8 + mmq_y/8 }
156156
157157static constexpr __host__ __device__ tile_x_sizes mmq_get_dp4a_tile_x_sizes (ggml_type type, int mmq_y) {
158- return type == GGML_TYPE_Q4_0 ? MMQ_DP4A_TXS_Q4_0 :
159- type == GGML_TYPE_Q4_1 ? MMQ_DP4A_TXS_Q4_1 :
160- type == GGML_TYPE_Q5_0 ? MMQ_DP4A_TXS_Q8_0 :
161- type == GGML_TYPE_Q5_1 ? MMQ_DP4A_TXS_Q8_1 :
162- type == GGML_TYPE_Q8_0 ? MMQ_DP4A_TXS_Q8_0 :
163- type == GGML_TYPE_Q2_K ? MMQ_DP4A_TXS_Q2_K :
164- type == GGML_TYPE_Q3_K ? MMQ_DP4A_TXS_Q3_K :
165- type == GGML_TYPE_Q4_K ? MMQ_DP4A_TXS_Q4_K :
166- type == GGML_TYPE_Q5_K ? MMQ_DP4A_TXS_Q5_K :
167- type == GGML_TYPE_Q6_K ? MMQ_DP4A_TXS_Q6_K :
168- type == GGML_TYPE_IQ2_XXS ? MMQ_DP4A_TXS_Q8_0 :
169- type == GGML_TYPE_IQ2_XS ? MMQ_DP4A_TXS_Q8_0_16 :
170- type == GGML_TYPE_IQ2_S ? MMQ_DP4A_TXS_Q8_0_16 :
171- type == GGML_TYPE_IQ3_XXS ? MMQ_DP4A_TXS_Q8_0 :
172- type == GGML_TYPE_IQ3_S ? MMQ_DP4A_TXS_Q8_0 :
173- type == GGML_TYPE_IQ1_S ? MMQ_DP4A_TXS_Q8_0 :
174- type == GGML_TYPE_IQ4_XS ? MMQ_DP4A_TXS_Q8_0 :
175- type == GGML_TYPE_IQ4_NL ? MMQ_DP4A_TXS_Q8_0 :
176- tile_x_sizes{0 , 0 , 0 };
158+ switch (type) {
159+ case GGML_TYPE_Q4_0: return MMQ_DP4A_TXS_Q4_0;
160+ case GGML_TYPE_Q4_1: return MMQ_DP4A_TXS_Q4_1;
161+ case GGML_TYPE_Q5_0: return MMQ_DP4A_TXS_Q8_0;
162+ case GGML_TYPE_Q5_1: return MMQ_DP4A_TXS_Q8_1;
163+ case GGML_TYPE_Q8_0: return MMQ_DP4A_TXS_Q8_0;
164+ case GGML_TYPE_Q2_K: return MMQ_DP4A_TXS_Q2_K;
165+ case GGML_TYPE_Q3_K: return MMQ_DP4A_TXS_Q3_K;
166+ case GGML_TYPE_Q4_K: return MMQ_DP4A_TXS_Q4_K;
167+ case GGML_TYPE_Q5_K: return MMQ_DP4A_TXS_Q5_K;
168+ case GGML_TYPE_Q6_K: return MMQ_DP4A_TXS_Q6_K;
169+ case GGML_TYPE_IQ2_XXS: return MMQ_DP4A_TXS_Q8_0;
170+ case GGML_TYPE_IQ2_XS: return MMQ_DP4A_TXS_Q8_0_16;
171+ case GGML_TYPE_IQ2_S: return MMQ_DP4A_TXS_Q8_0_16;
172+ case GGML_TYPE_IQ3_XXS: return MMQ_DP4A_TXS_Q8_0;
173+ case GGML_TYPE_IQ3_S: return MMQ_DP4A_TXS_Q8_0;
174+ case GGML_TYPE_IQ1_S: return MMQ_DP4A_TXS_Q8_0;
175+ case GGML_TYPE_IQ4_XS: return MMQ_DP4A_TXS_Q8_0;
176+ case GGML_TYPE_IQ4_NL: return MMQ_DP4A_TXS_Q8_0;
177+ default : return tile_x_sizes{0 , 0 , 0 };
178+ }
177179}
178180
179181#define MMQ_MMA_TILE_X_K_Q8_0 (2 *WARP_SIZE + 2 *WARP_SIZE/QI8_0 + 4 )
@@ -189,25 +191,27 @@ static_assert(MMQ_MMA_TILE_X_K_Q3_K % 8 == 4, "Wrong padding.");
189191static_assert (MMQ_MMA_TILE_X_K_Q6_K % 8 == 4 , " Wrong padding." );
190192
191193static constexpr __host__ __device__ int mmq_get_mma_tile_x_k (ggml_type type) {
192- return type == GGML_TYPE_Q4_0 ? MMQ_MMA_TILE_X_K_Q8_0 :
193- type == GGML_TYPE_Q4_1 ? MMQ_MMA_TILE_X_K_Q8_1 :
194- type == GGML_TYPE_Q5_0 ? MMQ_MMA_TILE_X_K_Q8_0 :
195- type == GGML_TYPE_Q5_1 ? MMQ_MMA_TILE_X_K_Q8_1 :
196- type == GGML_TYPE_Q8_0 ? MMQ_MMA_TILE_X_K_Q8_0 :
197- type == GGML_TYPE_Q2_K ? MMQ_MMA_TILE_X_K_Q2_K :
198- type == GGML_TYPE_Q3_K ? MMQ_MMA_TILE_X_K_Q3_K :
199- type == GGML_TYPE_Q4_K ? MMQ_MMA_TILE_X_K_Q8_1 :
200- type == GGML_TYPE_Q5_K ? MMQ_MMA_TILE_X_K_Q8_1 :
201- type == GGML_TYPE_Q6_K ? MMQ_MMA_TILE_X_K_Q6_K :
202- type == GGML_TYPE_IQ2_XXS ? MMQ_MMA_TILE_X_K_Q8_0 :
203- type == GGML_TYPE_IQ2_XS ? MMQ_MMA_TILE_X_K_Q3_K :
204- type == GGML_TYPE_IQ2_S ? MMQ_MMA_TILE_X_K_Q3_K :
205- type == GGML_TYPE_IQ3_XXS ? MMQ_MMA_TILE_X_K_Q8_0 :
206- type == GGML_TYPE_IQ3_S ? MMQ_MMA_TILE_X_K_Q8_0 :
207- type == GGML_TYPE_IQ1_S ? MMQ_MMA_TILE_X_K_Q8_0 :
208- type == GGML_TYPE_IQ4_XS ? MMQ_MMA_TILE_X_K_Q8_0 :
209- type == GGML_TYPE_IQ4_NL ? MMQ_MMA_TILE_X_K_Q8_0 :
210- 0 ;
194+ switch (type) {
195+ case GGML_TYPE_Q4_0: return MMQ_MMA_TILE_X_K_Q8_0;
196+ case GGML_TYPE_Q4_1: return MMQ_MMA_TILE_X_K_Q8_1;
197+ case GGML_TYPE_Q5_0: return MMQ_MMA_TILE_X_K_Q8_0;
198+ case GGML_TYPE_Q5_1: return MMQ_MMA_TILE_X_K_Q8_1;
199+ case GGML_TYPE_Q8_0: return MMQ_MMA_TILE_X_K_Q8_0;
200+ case GGML_TYPE_Q2_K: return MMQ_MMA_TILE_X_K_Q2_K;
201+ case GGML_TYPE_Q3_K: return MMQ_MMA_TILE_X_K_Q3_K;
202+ case GGML_TYPE_Q4_K: return MMQ_MMA_TILE_X_K_Q8_1;
203+ case GGML_TYPE_Q5_K: return MMQ_MMA_TILE_X_K_Q8_1;
204+ case GGML_TYPE_Q6_K: return MMQ_MMA_TILE_X_K_Q6_K;
205+ case GGML_TYPE_IQ2_XXS: return MMQ_MMA_TILE_X_K_Q8_0;
206+ case GGML_TYPE_IQ2_XS: return MMQ_MMA_TILE_X_K_Q3_K;
207+ case GGML_TYPE_IQ2_S: return MMQ_MMA_TILE_X_K_Q3_K;
208+ case GGML_TYPE_IQ3_XXS: return MMQ_MMA_TILE_X_K_Q8_0;
209+ case GGML_TYPE_IQ3_S: return MMQ_MMA_TILE_X_K_Q8_0;
210+ case GGML_TYPE_IQ1_S: return MMQ_MMA_TILE_X_K_Q8_0;
211+ case GGML_TYPE_IQ4_XS: return MMQ_MMA_TILE_X_K_Q8_0;
212+ case GGML_TYPE_IQ4_NL: return MMQ_MMA_TILE_X_K_Q8_0;
213+ default : return 0 ;
214+ }
211215}
212216
213217#define MMQ_TILE_Y_K (WARP_SIZE + WARP_SIZE/QI8_1)
0 commit comments