Commit a2a336b

refactor: PR cleanup

1 parent 279b51e
2 files changed: 30 additions & 22 deletions

ggml/src/ggml-cuda/mma.cuh
Lines changed: 6 additions & 5 deletions
@@ -12,7 +12,8 @@
 // The methods get_i and get_j can be used to get the physical 32 bit index of the lth element of a thread within a tile.
 // All matrix tiles have ne physical 32 bit elements per warp.
 //
-// As described in the documentation, all pointers for load_ldmatrix must be to shared memory and aligned to 16 bytes.
+// As described in the PTX documentation, all pointers for load_ldmatrix must be to shared memory and aligned to 16 bytes.
+// The API in this file also assumes that the pointers for load_generic are aligned to 16 bytes; unaligned pointers are considered undefined behavior.

 #include "common.cuh"

@@ -453,13 +454,13 @@ namespace ggml_cuda_mma {
             B.x[1],
             acc[0],
             0, 0, 0);
-#endif
+#endif // defined(CDNA3)
 #else
         GGML_UNUSED(D);
         GGML_UNUSED(A);
         GGML_UNUSED(B);
         NO_DEVICE_CODE;
-#endif // NEW_MMA_AVAILABLE
+#endif // AMD_MFMA_AVAILABLE
     }

     static __device__ __forceinline__ void mma(
@@ -481,12 +482,12 @@ namespace ggml_cuda_mma {
             B.x[1],
             acc[0],
             0, 0, 0);
-#endif
+#endif // defined(CDNA3)
 #else
         GGML_UNUSED(D);
         GGML_UNUSED(A);
         GGML_UNUSED(B);
         NO_DEVICE_CODE;
-#endif // NEW_MMA_AVAILABLE
+#endif // AMD_MFMA_AVAILABLE
     }
 }
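As an aside on the 16-byte alignment contract documented at the top of this file: a minimal, hypothetical call-site check (not part of this patch; is_aligned_16 is an invented name) could look like this:

#include <cstdint>

// Hypothetical helper: true iff p satisfies the 16-byte alignment that
// load_ldmatrix/load_generic are documented to assume for shared memory.
static __device__ __forceinline__ bool is_aligned_16(const void * p) {
    return (reinterpret_cast<std::uintptr_t>(p) % 16) == 0;
}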

ggml/src/ggml-cuda/mmq.cuh
Lines changed: 24 additions & 17 deletions
@@ -143,10 +143,13 @@ static constexpr __device__ int get_mmq_y_device() {
 #endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
 }

-// Decouple sizes from WARP_SIZE to allow for different warp sizes.
-// MMQ_TILE_NE_K is the number of 32 bit elements in the K dimension
-// which is treated as a single fundamental block. Bigger blocks are
-// multiples of this size (excluding scales/padding).
+// Decouple shared memory tile sizes from WARP_SIZE to allow for different warp sizes.
+// The K dimension of the tiles has either 1*MMQ_TILE_NE_K == 32 (always for TILE_Y_K)
+// or 2*MMQ_TILE_NE_K == 64 (typically for TILE_X_K) 32 bit elements
+// of quantized data (scales not included).
+// In other words, the size of the quantized data in the K dimension is a multiple of MMQ_TILE_NE_K.
+// The final tile size in the K direction is padded to avoid shared memory bank conflicts;
+// in terms of 32 bit elements that means K % 2 == 1 for dp4a or K % 8 == 4 for mma.
 #define MMQ_TILE_NE_K 32

 #define MMQ_DP4A_TXS_Q4_0 tile_x_sizes{mmq_y*MMQ_TILE_NE_K + mmq_y, mmq_y*MMQ_TILE_NE_K/QI4_0 + mmq_y/QI4_0, 0}
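To make the padding rule concrete, here is a sketch (the +1 padding matches the per-row size of the Q4_0 dp4a tile above; the +4 mma padding is an assumption for illustration):

// Sketch: the bank-conflict padding rule from the comment above.
constexpr int MMQ_TILE_NE_K_ = 32;
constexpr int k_dp4a = 1*MMQ_TILE_NE_K_ + 1; // 33 elements per row, one padding element
constexpr int k_mma  = 2*MMQ_TILE_NE_K_ + 4; // 68 elements per row, four padding elements (assumed)
static_assert(k_dp4a % 2 == 1, "dp4a tile rows need K % 2 == 1");
static_assert(k_mma  % 8 == 4, "mma tile rows need K % 8 == 4");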
@@ -220,7 +223,7 @@ static constexpr __host__ __device__ int mmq_get_mma_tile_x_k(ggml_type type) {
     }
 }

-// block_q8_1_mmq has (128 8-bit ints == 32 32-bit ints + 4 32-bit factors)
+// block_q8_1_mmq has (128 8-bit ints == 32 32-bit ints + 4 32-bit scales)
 #define MMQ_TILE_Y_K (MMQ_TILE_NE_K + MMQ_TILE_NE_K/QI8_1)

 static int mmq_get_granularity_host(ggml_type type, const int mmq_x, const int cc) {
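The corrected comment pins down the arithmetic behind MMQ_TILE_Y_K. As a sketch, assuming QI8_1 == 8 (the number of 32-bit ints spanned by one 32-value q8_1 block):

// Sketch of the MMQ_TILE_Y_K arithmetic; QI8_1 == 8 is assumed here.
constexpr int MMQ_TILE_NE_K_ = 32;
constexpr int QI8_1_         = 8;
// 32 ints hold the 128 8-bit quants, 32/8 == 4 ints hold the scales.
constexpr int MMQ_TILE_Y_K_  = MMQ_TILE_NE_K_ + MMQ_TILE_NE_K_/QI8_1_;
static_assert(MMQ_TILE_Y_K_ == 36, "one y tile row spans 36 32-bit elements");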
@@ -238,6 +241,7 @@ static int mmq_get_granularity_host(ggml_type type, const int mmq_x, const int cc) {
         // vec_dot_q8_1_q8_1_mma
         case GGML_TYPE_Q4_1:
         case GGML_TYPE_Q5_1:
+        case GGML_TYPE_Q8_1:
         case GGML_TYPE_Q4_K:
         case GGML_TYPE_Q5_K:
         case GGML_TYPE_IQ1_S:
@@ -273,6 +277,7 @@ static constexpr __device__ int mmq_get_granularity_device(ggml_type type, const int mmq_x) {
         // vec_dot_q8_1_q8_1_mma
         case GGML_TYPE_Q4_1:
         case GGML_TYPE_Q5_1:
+        case GGML_TYPE_Q8_1:
         case GGML_TYPE_Q4_K:
         case GGML_TYPE_Q5_K:
         case GGML_TYPE_IQ1_S:
@@ -873,7 +878,7 @@ static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma(
 #pragma unroll
             for (int l = 0; l < tile_C::ne; ++l) {
                 const int i = i0 + n*tile_A::I + tile_C::get_i(l);
-                float dA = x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + k0/QI8_0];
+                const float dA = x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + k0/QI8_0];
                 sum[(j0/tile_C::J + n)*tile_C::ne + l] += C.x[l]*dA*dB;
             }
         }
@@ -888,7 +893,7 @@ static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma(
     constexpr int rows_per_warp = 2 * granularity;
     constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.

-    y += (threadIdx.y % ntx) * (tile_B::I*MMQ_TILE_Y_K);
+    y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K);

     const int * x_qs = (const int *) x;
     const float * x_df = (const float *) x_qs + 2*MMQ_TILE_NE_K;
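This tile_B::I -> tile_C::J change (repeated in several hunks below) is a semantic cleanup: the y offset advances by whole y tiles, and the number of y rows consumed per x minitile is the J dimension of the C tile. A sketch with shapes mirrored from the typedefs in this diff (the mirrored names are stand-ins) shows why the value, and hence the generated code, is unchanged:

// Shapes as in the non-AMD path of this diff (tile_ is a stand-in name).
template <int I_, int J_> struct tile_ { static constexpr int I = I_, J = J_; };
using tile_B_ = tile_< 8, 8>;  // y data tile
using tile_C_ = tile_<16, 8>;  // accumulator tile
// tile_B::I happens to equal tile_C::J here, so only the derivation changes.
static_assert(tile_B_::I == tile_C_::J, "offset value unchanged for these shapes");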
@@ -998,7 +1003,7 @@ static __device__ __forceinline__ void vec_dot_q8_1_q8_1_mma(
     typedef tile<16, 8, int> tile_B;
     typedef tile<16, 16, int> tile_C;

-    constexpr int granularity = mmq_get_granularity_device(GGML_TYPE_Q4_K, mmq_x);
+    constexpr int granularity = mmq_get_granularity_device(GGML_TYPE_Q8_1, mmq_x);
     constexpr int rows_per_warp = granularity;
     constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.

@@ -1048,7 +1053,7 @@ static __device__ __forceinline__ void vec_dot_q8_1_q8_1_mma(
     typedef tile< 8, 8, int> tile_B;
     typedef tile<16, 8, int> tile_C;

-    constexpr int granularity = mmq_get_granularity_device(GGML_TYPE_Q8_0, mmq_x);
+    constexpr int granularity = mmq_get_granularity_device(GGML_TYPE_Q8_1, mmq_x);
     constexpr int rows_per_warp = 2 * granularity;
     constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.

@@ -1118,6 +1123,7 @@ static __device__ __forceinline__ void vec_dot_q8_1_q8_1_mma(
 #endif // defined(AMD_MFMA_AVAILABLE)
 }

+// Used for Q3_K, IQ2_S, and IQ2_XS
 template <int mmq_x, int mmq_y>
 static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_dp4a(
     const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
@@ -1152,6 +1158,7 @@ static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_dp4a(
         }
     }

+// Used for Q3_K, IQ2_S, and IQ2_XS
 template <int mmq_x, int mmq_y>
 static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_mma(
     const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
@@ -1164,7 +1171,7 @@ static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_mma(
     constexpr int rows_per_warp = granularity;
     constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.

-    y += (threadIdx.y % ntx) * (tile_B::I*MMQ_TILE_Y_K);
+    y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K);

     const int * x_qs = (const int *) x;
     const float * x_df = (const float *) x_qs + MMQ_TILE_NE_K*2;
@@ -1214,7 +1221,7 @@ static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_mma(
     constexpr int rows_per_warp = 2 * granularity;
     constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.

-    y += (threadIdx.y % ntx) * (tile_B::I*MMQ_TILE_Y_K);
+    y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K);

     const int * x_qs = (const int *) x;
     const float * x_df = (const float *) x_qs + MMQ_TILE_NE_K*2;
@@ -1420,7 +1427,7 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma(
     constexpr int rows_per_warp = granularity;
     constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.

-    y += (threadIdx.y % ntx) * (tile_B::I*MMQ_TILE_Y_K);
+    y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K);

     const int * x_qs = (const int *) x;
     const half2 * x_dm = (const half2 *) x_qs + MMQ_TILE_NE_K*2;
@@ -1487,7 +1494,7 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma(
     constexpr int rows_per_warp = 2 * granularity;
     constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.

-    y += (threadIdx.y % ntx) * (tile_B::I*MMQ_TILE_Y_K);
+    y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K);

     const int * x_qs = (const int *) x;
     const half2 * x_dm = (const half2 *) x_qs + MMQ_TILE_NE_K*2;
@@ -1972,7 +1979,7 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa

     const half2 dm = bxi->dm * make_half2(1.0f, -1.0f);

-    #pragma unroll
+#pragma unroll
     for (int l = 0; l < int(sizeof(int)); ++l) {
         x_dm[i*MMQ_MMA_TILE_X_K_Q8_1 + sizeof(int)*ksc + l] = dm*make_half2(sc8[l], m8[l]);
     }
@@ -2181,7 +2188,7 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(
     constexpr int rows_per_warp = granularity;
     constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.

-    y += (threadIdx.y % ntx) * (tile_B::I*MMQ_TILE_Y_K);
+    y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K);

     const int * x_qs = (const int *) x;
     const float * x_df = (const float *) x_qs + MMQ_TILE_NE_K*2;
@@ -2232,7 +2239,7 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(
     constexpr int rows_per_warp = 2 * granularity;
     constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.

-    y += (threadIdx.y % ntx) * (tile_B::I*MMQ_TILE_Y_K);
+    y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K);

     const int * x_qs = (const int *) x;
     const float * x_df = (const float *) x_qs + MMQ_TILE_NE_K*2;
@@ -2410,7 +2417,7 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa

     constexpr int threads_per_row = (MMQ_ITER_K / (4 * QR2_XXS)) / 2;
     constexpr int nrows = warp_size / threads_per_row;
-    const int kqsx = threadIdx.x % threads_per_row;
+    const int kqsx = warp_size > threads_per_row ? threadIdx.x % threads_per_row : threadIdx.x;

 #pragma unroll
     for (int i0 = 0; i0 < mmq_y; i0 += nwarps * nrows) {
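The kqsx change guards the modulo: when the warp is exactly one row wide (warp_size == threads_per_row), threadIdx.x already is the column index. Since both operands are compile-time constants, only one branch survives compilation; a sketch under an assumed 64-lane wavefront:

// Sketch, assuming warp_size == threads_per_row == 64 (e.g. a HIP wavefront).
constexpr int warp_size_       = 64;
constexpr int threads_per_row_ = 64;
__device__ int column_index_sketch() {
    // Resolved at compile time: with equal constants the integer modulo
    // (comparatively expensive on GPUs) is never emitted.
    return warp_size_ > threads_per_row_ ? (int)(threadIdx.x % threads_per_row_)
                                         : (int)threadIdx.x;
}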
