@@ -143,10 +143,13 @@ static constexpr __device__ int get_mmq_y_device() {
143143#endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
144144}
145145
146- // Decouple sizes from WARP_SIZE to allow for different warp sizes.
147- // MMQ_TILE_NE_K is the number of 32 bit elements in the K dimension
148- // which is treated as a single fundamental block. Bigger blocks are
149- // multiples of this size (excluding scales/padding).
146+ // Decouple shared memory tile sizes from WARP_SIZE to allow for different warp sizes.
147+ // The K dimension of the tiles has either
148+ // 1*MMQ_TILE_NE_K==32 (always for TILE_Y_K) or 2*MMQ_TILE_NE_K==64 (typically for TILE_X_K)
149+ // 32-bit elements for the quantized data (does not include scales).
150+ // In other words, the size of the quantized data in the K dimension is a multiple of MMQ_TILE_NE_K.
151+ // The final tile size in K direction is padded to avoid shared memory bank conflicts,
152+ // in terms of 32-bit elements, that means K % 2 == 1 for dp4a or K % 8 == 4 for mma.
150153#define MMQ_TILE_NE_K 32
151154
152155#define MMQ_DP4A_TXS_Q4_0 tile_x_sizes{mmq_y*MMQ_TILE_NE_K + mmq_y, mmq_y*MMQ_TILE_NE_K/QI4_0 + mmq_y/QI4_0, 0 }
@@ -220,7 +223,7 @@ static constexpr __host__ __device__ int mmq_get_mma_tile_x_k(ggml_type type) {
220223 }
221224}
222225
223- // block_q8_1_mmq has (128 8-bit ints == 32 32-bit ints + 4 32-bit factors )
226+ // block_q8_1_mmq has (128 8-bit ints == 32 32-bit ints + 4 32-bit scales )
224227#define MMQ_TILE_Y_K (MMQ_TILE_NE_K + MMQ_TILE_NE_K/QI8_1)
225228
226229static int mmq_get_granularity_host (ggml_type type, const int mmq_x, const int cc) {
@@ -238,6 +241,7 @@ static int mmq_get_granularity_host(ggml_type type, const int mmq_x, const int c
238241 // vec_dot_q8_1_q8_1_mma
239242 case GGML_TYPE_Q4_1:
240243 case GGML_TYPE_Q5_1:
244+ case GGML_TYPE_Q8_1:
241245 case GGML_TYPE_Q4_K:
242246 case GGML_TYPE_Q5_K:
243247 case GGML_TYPE_IQ1_S:
@@ -273,6 +277,7 @@ static constexpr __device__ int mmq_get_granularity_device(ggml_type type, const
273277 // vec_dot_q8_1_q8_1_mma
274278 case GGML_TYPE_Q4_1:
275279 case GGML_TYPE_Q5_1:
280+ case GGML_TYPE_Q8_1:
276281 case GGML_TYPE_Q4_K:
277282 case GGML_TYPE_Q5_K:
278283 case GGML_TYPE_IQ1_S:
@@ -873,7 +878,7 @@ static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma(
873878#pragma unroll
874879 for (int l = 0 ; l < tile_C::ne; ++l) {
875880 const int i = i0 + n*tile_A::I + tile_C::get_i (l);
876- float dA = x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + k0/QI8_0];
881+ const float dA = x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + k0/QI8_0];
877882 sum[(j0/tile_C::J + n)*tile_C::ne + l] += C.x [l]*dA*dB;
878883 }
879884 }
@@ -888,7 +893,7 @@ static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma(
888893 constexpr int rows_per_warp = 2 * granularity;
889894 constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.
890895
891- y += (threadIdx .y % ntx) * (tile_B::I *MMQ_TILE_Y_K);
896+ y += (threadIdx .y % ntx) * (tile_C::J *MMQ_TILE_Y_K);
892897
893898 const int * x_qs = (const int *) x;
894899 const float * x_df = (const float *) x_qs + 2 *MMQ_TILE_NE_K;
@@ -998,7 +1003,7 @@ static __device__ __forceinline__ void vec_dot_q8_1_q8_1_mma(
9981003 typedef tile<16 , 8 , int > tile_B;
9991004 typedef tile<16 , 16 , int > tile_C;
10001005
1001- constexpr int granularity = mmq_get_granularity_device (GGML_TYPE_Q4_K , mmq_x);
1006+ constexpr int granularity = mmq_get_granularity_device (GGML_TYPE_Q8_1 , mmq_x);
10021007 constexpr int rows_per_warp = granularity;
10031008 constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.
10041009
@@ -1048,7 +1053,7 @@ static __device__ __forceinline__ void vec_dot_q8_1_q8_1_mma(
10481053 typedef tile< 8 , 8 , int > tile_B;
10491054 typedef tile<16 , 8 , int > tile_C;
10501055
1051- constexpr int granularity = mmq_get_granularity_device (GGML_TYPE_Q8_0 , mmq_x);
1056+ constexpr int granularity = mmq_get_granularity_device (GGML_TYPE_Q8_1 , mmq_x);
10521057 constexpr int rows_per_warp = 2 * granularity;
10531058 constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.
10541059
@@ -1118,6 +1123,7 @@ static __device__ __forceinline__ void vec_dot_q8_1_q8_1_mma(
11181123#endif // defined(AMD_MFMA_AVAILABLE)
11191124}
11201125
1126+ // Used for Q3_K, IQ2_S, and IQ2_XS
11211127template <int mmq_x, int mmq_y>
11221128static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_dp4a (
11231129 const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
@@ -1152,6 +1158,7 @@ static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_dp4a(
11521158 }
11531159}
11541160
1161+ // Used for Q3_K, IQ2_S, and IQ2_XS
11551162template <int mmq_x, int mmq_y>
11561163static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_mma (
11571164 const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
@@ -1164,7 +1171,7 @@ static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_mma(
11641171 constexpr int rows_per_warp = granularity;
11651172 constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.
11661173
1167- y += (threadIdx .y % ntx) * (tile_B::I *MMQ_TILE_Y_K);
1174+ y += (threadIdx .y % ntx) * (tile_C::J *MMQ_TILE_Y_K);
11681175
11691176 const int * x_qs = (const int *) x;
11701177 const float * x_df = (const float *) x_qs + MMQ_TILE_NE_K*2 ;
@@ -1214,7 +1221,7 @@ static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_mma(
12141221 constexpr int rows_per_warp = 2 * granularity;
12151222 constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.
12161223
1217- y += (threadIdx .y % ntx) * (tile_B::I *MMQ_TILE_Y_K);
1224+ y += (threadIdx .y % ntx) * (tile_C::J *MMQ_TILE_Y_K);
12181225
12191226 const int * x_qs = (const int *) x;
12201227 const float * x_df = (const float *) x_qs + MMQ_TILE_NE_K*2 ;
@@ -1420,7 +1427,7 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma(
14201427 constexpr int rows_per_warp = granularity;
14211428 constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.
14221429
1423- y += (threadIdx .y % ntx) * (tile_B::I *MMQ_TILE_Y_K);
1430+ y += (threadIdx .y % ntx) * (tile_C::J *MMQ_TILE_Y_K);
14241431
14251432 const int * x_qs = (const int *) x;
14261433 const half2 * x_dm = (const half2 *) x_qs + MMQ_TILE_NE_K*2 ;
@@ -1487,7 +1494,7 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma(
14871494 constexpr int rows_per_warp = 2 * granularity;
14881495 constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.
14891496
1490- y += (threadIdx .y % ntx) * (tile_B::I *MMQ_TILE_Y_K);
1497+ y += (threadIdx .y % ntx) * (tile_C::J *MMQ_TILE_Y_K);
14911498
14921499 const int * x_qs = (const int *) x;
14931500 const half2 * x_dm = (const half2 *) x_qs + MMQ_TILE_NE_K*2 ;
@@ -1972,7 +1979,7 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
19721979
19731980 const half2 dm = bxi->dm * make_half2 (1 .0f , -1 .0f );
19741981
1975- #pragma unroll
1982+ #pragma unroll
19761983 for (int l = 0 ; l < int (sizeof (int )); ++l) {
19771984 x_dm[i*MMQ_MMA_TILE_X_K_Q8_1 + sizeof (int )*ksc + l] = dm*make_half2 (sc8[l], m8[l]);
19781985 }
@@ -2181,7 +2188,7 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(
21812188 constexpr int rows_per_warp = granularity;
21822189 constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.
21832190
2184- y += (threadIdx .y % ntx) * (tile_B::I *MMQ_TILE_Y_K);
2191+ y += (threadIdx .y % ntx) * (tile_C::J *MMQ_TILE_Y_K);
21852192
21862193 const int * x_qs = (const int *) x;
21872194 const float * x_df = (const float *) x_qs + MMQ_TILE_NE_K*2 ;
@@ -2232,7 +2239,7 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(
22322239 constexpr int rows_per_warp = 2 * granularity;
22332240 constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.
22342241
2235- y += (threadIdx .y % ntx) * (tile_B::I *MMQ_TILE_Y_K);
2242+ y += (threadIdx .y % ntx) * (tile_C::J *MMQ_TILE_Y_K);
22362243
22372244 const int * x_qs = (const int *) x;
22382245 const float * x_df = (const float *) x_qs + MMQ_TILE_NE_K*2 ;
@@ -2410,7 +2417,7 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
24102417
24112418 constexpr int threads_per_row = (MMQ_ITER_K / (4 * QR2_XXS)) / 2 ;
24122419 constexpr int nrows = warp_size / threads_per_row;
2413- const int kqsx = threadIdx .x % threads_per_row;
2420+ const int kqsx = warp_size > threads_per_row ? threadIdx .x % threads_per_row : threadIdx . x ;
24142421
24152422#pragma unroll
24162423 for (int i0 = 0 ; i0 < mmq_y; i0 += nwarps * nrows) {
0 commit comments