@@ -226,27 +226,30 @@ static constexpr __host__ __device__ int mmq_get_mma_tile_x_k(ggml_type type) {
226226
227227static int mmq_get_granularity_host (ggml_type type, const int mmq_x, const int cc) {
228228 if (amd_mma_available (cc)) {
229- // 32x32 tile_C -> 32 ; 16x16 tile_C -> 16
230229 switch (type) {
231- case GGML_TYPE_Q4_0: return 16 ; // vec_dot_q8_0_q8_1_mma
232- case GGML_TYPE_Q4_1: return 16 ; // vec_dot_q8_1_q8_1_mma
233- case GGML_TYPE_Q5_0: return 16 ; // vec_dot_q8_0_q8_1_mma
234- case GGML_TYPE_Q5_1: return 16 ; // vec_dot_q8_1_q8_1_mma
235- case GGML_TYPE_Q8_0: return 16 ; // vec_dot_q8_0_q8_1_mma
236- case GGML_TYPE_Q2_K: return 32 ; // vec_dot_q2_K_q8_1_mma
237- case GGML_TYPE_Q3_K: return 32 ; // vec_dot_q8_0_16_q8_1_mma
238- case GGML_TYPE_Q4_K: return 16 ; // vec_dot_q8_1_q8_1_mma
239- case GGML_TYPE_Q5_K: return 16 ; // vec_dot_q8_1_q8_1_mma
240- case GGML_TYPE_Q6_K: return 32 ; // vec_dot_q6_K_q8_1_mma
241- case GGML_TYPE_IQ2_XXS: return 16 ; // vec_dot_q8_0_q8_1_mma
242- case GGML_TYPE_IQ2_XS: return 32 ; // vec_dot_q8_0_16_q8_1_mma
243- case GGML_TYPE_IQ2_S: return 32 ; // vec_dot_q8_0_16_q8_1_mma
244- case GGML_TYPE_IQ3_XXS: return 16 ; // vec_dot_q8_0_q8_1_mma
245- case GGML_TYPE_IQ3_S: return 16 ; // vec_dot_q8_0_q8_1_mma
246- case GGML_TYPE_IQ1_S: return 16 ; // vec_dot_q8_1_q8_1_mma
247- case GGML_TYPE_IQ4_XS: return 16 ; // vec_dot_q8_0_q8_1_mma
248- case GGML_TYPE_IQ4_NL: return 16 ; // vec_dot_q8_0_q8_1_mma
249- default : return 0 ;
230+ // vec_dot_q8_0_q8_1_mma
231+ case GGML_TYPE_Q4_0:
232+ case GGML_TYPE_Q5_0:
233+ case GGML_TYPE_Q8_0:
234+ case GGML_TYPE_IQ2_XXS:
235+ case GGML_TYPE_IQ3_XXS:
236+ case GGML_TYPE_IQ3_S:
237+ case GGML_TYPE_IQ4_XS:
238+ case GGML_TYPE_IQ4_NL:
239+ return mmq_x >= 128 ? 32 : 16 ;
240+ // vec_dot_q8_1_q8_1_mma
241+ case GGML_TYPE_Q4_1:
242+ case GGML_TYPE_Q5_1:
243+ case GGML_TYPE_Q4_K:
244+ case GGML_TYPE_Q5_K:
245+ case GGML_TYPE_IQ1_S:
246+ return mmq_x >= 128 ? 32 : 16 ;
247+ case GGML_TYPE_Q2_K: // vec_dot_q2_K_q8_1_mma
248+ case GGML_TYPE_Q3_K: // vec_dot_q8_0_16_q8_1_mma
249+ case GGML_TYPE_Q6_K: // vec_dot_q6_K_q8_1_mma
250+ case GGML_TYPE_IQ2_XS: // vec_dot_q8_0_16_q8_1_mma
251+ case GGML_TYPE_IQ2_S: // vec_dot_q8_0_16_q8_1_mma
252+ return mmq_x >= 192 ? 64 : 32 ;
250253 }
251254 } else if (new_mma_available (cc) && mmq_x >= 48 ) {
252255 return 16 ;
@@ -256,26 +259,29 @@ static int mmq_get_granularity_host(ggml_type type, const int mmq_x, const int c
256259#if defined(AMD_MMA_AVAILABLE)
257260static constexpr __device__ int mmq_get_granularity_device (ggml_type type, const int mmq_x) {
258261 switch (type) {
259- // 32x32 tile_C -> 32 ; 16x16 tile_C -> 16
260- case GGML_TYPE_Q4_0: return 16 ;
261- case GGML_TYPE_Q4_1: return 16 ;
262- case GGML_TYPE_Q5_0: return 16 ;
263- case GGML_TYPE_Q5_1: return 16 ;
264- case GGML_TYPE_Q8_0: return 16 ;
265- case GGML_TYPE_Q2_K: return 32 ;
266- case GGML_TYPE_Q3_K: return 32 ;
267- case GGML_TYPE_Q4_K: return 16 ;
268- case GGML_TYPE_Q5_K: return 16 ;
269- case GGML_TYPE_Q6_K: return 32 ;
270- case GGML_TYPE_IQ2_XXS: return 16 ;
271- case GGML_TYPE_IQ2_XS: return 32 ;
272- case GGML_TYPE_IQ2_S: return 32 ;
273- case GGML_TYPE_IQ3_XXS: return 16 ;
274- case GGML_TYPE_IQ3_S: return 16 ;
275- case GGML_TYPE_IQ1_S: return 16 ;
276- case GGML_TYPE_IQ4_XS: return 16 ;
277- case GGML_TYPE_IQ4_NL: return 16 ;
278- default : return 0 ;
262+ // vec_dot_q8_0_q8_1_mma
263+ case GGML_TYPE_Q4_0:
264+ case GGML_TYPE_Q5_0:
265+ case GGML_TYPE_Q8_0:
266+ case GGML_TYPE_IQ2_XXS:
267+ case GGML_TYPE_IQ3_XXS:
268+ case GGML_TYPE_IQ3_S:
269+ case GGML_TYPE_IQ4_XS:
270+ case GGML_TYPE_IQ4_NL:
271+ return mmq_x >= 128 ? 32 : 16 ;
272+ // vec_dot_q8_1_q8_1_mma
273+ case GGML_TYPE_Q4_1:
274+ case GGML_TYPE_Q5_1:
275+ case GGML_TYPE_Q4_K:
276+ case GGML_TYPE_Q5_K:
277+ case GGML_TYPE_IQ1_S:
278+ return mmq_x >= 128 ? 32 : 16 ;
279+ case GGML_TYPE_Q2_K: // vec_dot_q2_K_q8_1_mma
280+ case GGML_TYPE_Q3_K: // vec_dot_q8_0_16_q8_1_mma
281+ case GGML_TYPE_Q6_K: // vec_dot_q6_K_q8_1_mma
282+ case GGML_TYPE_IQ2_XS: // vec_dot_q8_0_16_q8_1_mma
283+ case GGML_TYPE_IQ2_S: // vec_dot_q8_0_16_q8_1_mma
284+ return mmq_x >= 192 ? 64 : 32 ;
279285 }
280286}
281287#elif defined(NEW_MMA_AVAILABLE)
@@ -290,27 +296,30 @@ static constexpr __device__ int mmq_get_granularity_device(ggml_type type, const
290296
291297static int get_mmq_nwarps_host (ggml_type type, const int cc) {
292298 if (amd_mma_available (cc)) {
293- // 32x32 tile_C -> 4 ; 16x16 tile_C -> 8
294299 switch (type) {
295- case GGML_TYPE_Q4_0: return 8 ;
296- case GGML_TYPE_Q4_1: return 8 ;
297- case GGML_TYPE_Q5_0: return 8 ;
298- case GGML_TYPE_Q5_1: return 8 ;
299- case GGML_TYPE_Q8_0: return 8 ;
300- case GGML_TYPE_Q2_K: return 4 ;
301- case GGML_TYPE_Q3_K: return 4 ;
302- case GGML_TYPE_Q4_K: return 8 ;
303- case GGML_TYPE_Q5_K: return 8 ;
304- case GGML_TYPE_Q6_K: return 4 ;
305- case GGML_TYPE_IQ2_XXS: return 8 ;
306- case GGML_TYPE_IQ2_XS: return 4 ;
307- case GGML_TYPE_IQ2_S: return 4 ;
308- case GGML_TYPE_IQ3_XXS: return 8 ;
309- case GGML_TYPE_IQ3_S: return 8 ;
310- case GGML_TYPE_IQ1_S: return 8 ;
311- case GGML_TYPE_IQ4_XS: return 8 ;
312- case GGML_TYPE_IQ4_NL: return 8 ;
313- default : return 0 ;
300+ // vec_dot_q8_0_q8_1_mma
301+ case GGML_TYPE_Q4_0:
302+ case GGML_TYPE_Q5_0:
303+ case GGML_TYPE_Q8_0:
304+ case GGML_TYPE_IQ2_XXS:
305+ case GGML_TYPE_IQ3_XXS:
306+ case GGML_TYPE_IQ3_S:
307+ case GGML_TYPE_IQ4_XS:
308+ case GGML_TYPE_IQ4_NL:
309+ return 8 ;
310+ // vec_dot_q8_1_q8_1_mma
311+ case GGML_TYPE_Q4_1:
312+ case GGML_TYPE_Q5_1:
313+ case GGML_TYPE_Q4_K:
314+ case GGML_TYPE_Q5_K:
315+ case GGML_TYPE_IQ1_S:
316+ return 8 ;
317+ case GGML_TYPE_Q2_K: // vec_dot_q2_K_q8_1_mma
318+ case GGML_TYPE_Q3_K: // vec_dot_q8_0_16_q8_1_mma
319+ case GGML_TYPE_Q6_K: // vec_dot_q6_K_q8_1_mma
320+ case GGML_TYPE_IQ2_XS: // vec_dot_q8_0_16_q8_1_mma
321+ case GGML_TYPE_IQ2_S: // vec_dot_q8_0_16_q8_1_mma
322+ return 4 ;
314323 }
315324 } else {
316325 return 8 ;
@@ -319,27 +328,30 @@ static int get_mmq_nwarps_host(ggml_type type, const int cc) {
319328
320329#if defined(AMD_MMA_AVAILABLE)
321330static constexpr __device__ int get_mmq_nwarps_device (ggml_type type) {
322- // 32x32 tile_C -> 4 ; 16x16 tile_C -> 8
323331 switch (type) {
324- case GGML_TYPE_Q4_0: return 8 ;
325- case GGML_TYPE_Q4_1: return 8 ;
326- case GGML_TYPE_Q5_0: return 8 ;
327- case GGML_TYPE_Q5_1: return 8 ;
328- case GGML_TYPE_Q8_0: return 8 ;
329- case GGML_TYPE_Q2_K: return 4 ;
330- case GGML_TYPE_Q3_K: return 4 ;
331- case GGML_TYPE_Q4_K: return 8 ;
332- case GGML_TYPE_Q5_K: return 8 ;
333- case GGML_TYPE_Q6_K: return 4 ;
334- case GGML_TYPE_IQ2_XXS: return 8 ;
335- case GGML_TYPE_IQ2_XS: return 4 ;
336- case GGML_TYPE_IQ2_S: return 4 ;
337- case GGML_TYPE_IQ3_XXS: return 8 ;
338- case GGML_TYPE_IQ3_S: return 8 ;
339- case GGML_TYPE_IQ1_S: return 8 ;
340- case GGML_TYPE_IQ4_XS: return 8 ;
341- case GGML_TYPE_IQ4_NL: return 8 ;
342- default : return 0 ;
332+ // vec_dot_q8_0_q8_1_mma
333+ case GGML_TYPE_Q4_0:
334+ case GGML_TYPE_Q5_0:
335+ case GGML_TYPE_Q8_0:
336+ case GGML_TYPE_IQ2_XXS:
337+ case GGML_TYPE_IQ3_XXS:
338+ case GGML_TYPE_IQ3_S:
339+ case GGML_TYPE_IQ4_XS:
340+ case GGML_TYPE_IQ4_NL:
341+ return 8 ;
342+ // vec_dot_q8_1_q8_1_mma
343+ case GGML_TYPE_Q4_1:
344+ case GGML_TYPE_Q5_1:
345+ case GGML_TYPE_Q4_K:
346+ case GGML_TYPE_Q5_K:
347+ case GGML_TYPE_IQ1_S:
348+ return 8 ;
349+ case GGML_TYPE_Q2_K: // vec_dot_q2_K_q8_1_mma
350+ case GGML_TYPE_Q3_K: // vec_dot_q8_0_16_q8_1_mma
351+ case GGML_TYPE_Q6_K: // vec_dot_q6_K_q8_1_mma
352+ case GGML_TYPE_IQ2_XS: // vec_dot_q8_0_16_q8_1_mma
353+ case GGML_TYPE_IQ2_S: // vec_dot_q8_0_16_q8_1_mma
354+ return 4 ;
343355 }
344356}
345357#else
@@ -896,35 +908,49 @@ static __device__ __forceinline__ void vec_dot_q8_1_q8_1_mma(
896908 typedef tile<16 , 8 , int > tile_B;
897909 typedef tile<16 , 16 , int > tile_C;
898910
911+ constexpr int granularity = mmq_get_granularity_device (GGML_TYPE_Q4_K, mmq_x);
912+ constexpr int rows_per_warp = granularity;
913+ constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.
914+
915+ y += (threadIdx .y % ntx) * (tile_C::J*MMQ_TILE_Y_K);
916+
899917 const int * x_qs = (const int *) x;
900918 const half2 * x_dm = (const half2 *) x_qs + 2 *MMQ_TILE_NE_K;
901919 const int * y_qs = (const int *) y + 4 ;
902920 const half2 * y_dm = (const half2 *) y;
903921
904- const int i0 = threadIdx .y * tile_A::I ;
922+ const int i0 = ( threadIdx .y / ntx) * rows_per_warp ;
905923
906924 for (int k01 = 0 ; k01 < MMQ_TILE_NE_K; k01 += QI8_1) {
907925 const int k0 = k00 + k01;
908926
909- tile_A A;
910- load_ldmatrix (A, x_qs + i0*MMQ_MMA_TILE_X_K_Q8_1 + k0, MMQ_MMA_TILE_X_K_Q8_1);
927+ tile_A A[ntx];
928+ #pragma unroll
929+ for (int n = 0 ; n < ntx; ++n) {
930+ load_ldmatrix (A[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q8_1 + k0, MMQ_MMA_TILE_X_K_Q8_1);
931+ }
911932
912933#pragma unroll
913- for (int j0 = 0 ; j0 < mmq_x; j0 += tile_C::J) {
934+ for (int j0 = 0 ; j0 < mmq_x; j0 += ntx* tile_C::J) {
914935 tile_B B;
915936 load_ldmatrix (B, y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K);
916937
917938 float2 dsB;
918939 const int j = j0 + tile_C::get_j (0 );
919940 dsB = __half22float2 (y_dm[j*MMQ_TILE_Y_K + k01/QI8_1]);
920941
921- tile_C C;
922- mma (C, A, B);
942+ #pragma unroll
943+ for (int n = 0 ; n < ntx; ++n) {
944+ tile_C C;
945+ mma (C, A[n], B);
923946
924- for (int l = 0 ; l < tile_C::ne; ++l) {
925- float2 dmA = __half22float2 (x_dm[(i0 + tile_C::get_i (l))*MMQ_MMA_TILE_X_K_Q8_1 + k0/QI8_1]);
926- sum[(j0/tile_C::J)*tile_C::ne + l] += dmA.x *dsB.x *C.x [l];
927- sum[(j0/tile_C::J)*tile_C::ne + l] += dmA.y *dsB.y ;
947+ #pragma unroll
948+ for (int l = 0 ; l < tile_C::ne; ++l) {
949+ const int i = i0 + n*tile_A::I + tile_C::get_i (l);
950+ float2 dmA = __half22float2 (x_dm[i*MMQ_MMA_TILE_X_K_Q8_1 + k0/QI8_1]);
951+ sum[(j0/tile_C::J + n)*tile_C::ne + l] += dmA.x *dsB.x *C.x [l];
952+ sum[(j0/tile_C::J + n)*tile_C::ne + l] += dmA.y *dsB.y ;
953+ }
928954 }
929955 }
930956 }
@@ -2067,34 +2093,48 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(
20672093 typedef tile<32 , 4 , int > tile_B;
20682094 typedef tile<32 , 32 , int > tile_C;
20692095
2096+ constexpr int granularity = mmq_get_granularity_device (GGML_TYPE_Q6_K, mmq_x);
2097+ constexpr int rows_per_warp = granularity;
2098+ constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.
2099+
2100+ y += (threadIdx .y % ntx) * (tile_B::I*MMQ_TILE_Y_K);
2101+
20702102 const int * x_qs = (const int *) x;
20712103 const float * x_df = (const float *) x_qs + MMQ_TILE_NE_K*2 ;
20722104 const int * x_sc = (const int *) x_df + MMQ_TILE_NE_K/QI6_K;
20732105 const int * y_qs = (const int *) y + 4 ;
20742106 const float * y_df = (const float *) y;
20752107
2076- const int i0 = threadIdx .y * tile_A::I ;
2108+ const int i0 = ( threadIdx .y / ntx) * rows_per_warp ;
20772109
20782110 for (int k01 = 0 ; k01 < MMQ_TILE_NE_K; k01 += 4 ) {
20792111 const int k0 = k00 + k01;
20802112
2081- tile_A A;
2082- load_ldmatrix (A, x_qs + i0*MMQ_MMA_TILE_X_K_Q6_K + k0, MMQ_MMA_TILE_X_K_Q6_K);
2113+ tile_A A[ntx];
2114+ #pragma unroll
2115+ for (int n = 0 ; n < ntx; ++n) {
2116+ load_ldmatrix (A[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q6_K + k0, MMQ_MMA_TILE_X_K_Q6_K);
2117+ }
20832118
2084- for (int j0 = 0 ; j0 < mmq_x; j0 += tile_C::J) {
2119+ for (int j0 = 0 ; j0 < mmq_x; j0 += ntx* tile_C::J) {
20852120 tile_B B;
20862121 load_ldmatrix (B, y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K);
20872122
20882123 float dB;
20892124 const int j = j0 + tile_C::get_j (0 );
20902125 dB = y_df[j*MMQ_TILE_Y_K + k01/QI8_1];
20912126
2092- tile_C C;
2093- mma (C, A, B);
2127+ #pragma unroll
2128+ for (int n = 0 ; n < ntx; ++n) {
2129+ tile_C C;
2130+ mma (C, A[n], B);
20942131
2095- for (int l = 0 ; l < tile_C::ne; ++l) {
2096- const int8_t * sc = (const int8_t *) (x_sc + (i0 + tile_C::get_i (l))*MMQ_MMA_TILE_X_K_Q6_K + k00/16 );
2097- sum[(j0/tile_C::J)*tile_C::ne + l] += C.x [l] * sc[k01/4 ] * x_df[(i0 + tile_C::get_i (l))*MMQ_MMA_TILE_X_K_Q6_K] * dB;
2132+ #pragma unroll
2133+ for (int l = 0 ; l < tile_C::ne; ++l) {
2134+ const int i = i0 + n*tile_C::I + tile_C::get_i (l);
2135+ const int8_t * sc = (const int8_t *) (x_sc + i*MMQ_MMA_TILE_X_K_Q6_K + k00/16 );
2136+ sum[(j0/tile_C::J + n)*tile_C::ne + l] += C.x [l] * sc[k01/4 ] * x_df[i*MMQ_MMA_TILE_X_K_Q6_K] * dB;
2137+ }
20982138 }
20992139 }
21002140 }
@@ -2618,7 +2658,8 @@ static __device__ __forceinline__ void mmq_write_back_mma(
26182658 constexpr int nwarps = get_mmq_nwarps_device (type);
26192659
26202660#if defined(AMD_MMA_AVAILABLE)
2621- typedef tile<granularity, granularity, int > tile_C;
2661+ constexpr int tileC_IJ = mmq_get_granularity_device (type, 0 );
2662+ typedef tile<tileC_IJ, tileC_IJ, int > tile_C;
26222663 constexpr int rows_per_warp = granularity;
26232664#else
26242665 typedef tile<16 , 8 , int > tile_C;
0 commit comments