Skip to content

Commit ff60fa9

Browse files
committed
Perf: Fix Register Spilling Q6_K - Refactor kernel, launch_bounds
1 parent dad79b3 commit ff60fa9

File tree

2 files changed: +61 −50 lines

ggml/src/ggml-cuda/mma.cuh

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -235,14 +235,18 @@ namespace ggml_cuda_mma {
235235
template <typename T>
236236
static __device__ __forceinline__ void load_ldmatrix(
237237
tile<32, 4, T> & t, const T * __restrict__ xs0, const int stride) {
238-
#ifdef NEW_MMA_AVAILABLE
238+
#if defined(AMD_MMA_AVAILABLE)
239+
int64_t* xi = (int64_t*) t.x;
240+
const int64_t* xs = (int64_t*) ((const int*) xs0 + (threadIdx.x % t.I) * stride + 2 * (threadIdx.x / t.I));
241+
xi[0] = xs[0];
242+
#elif defined(NEW_MMA_AVAILABLE)
239243
GGML_UNUSED(t);
240244
GGML_UNUSED(xs0);
241245
GGML_UNUSED(stride);
242246
NO_DEVICE_CODE;
243247
#else
244248
load_generic(t, xs0, stride);
245-
#endif // NEW_MMA_AVAILABLE
249+
#endif // AMD_MMA_AVAILABLE
246250
}
247251

248252
template <typename T>

ggml/src/ggml-cuda/mmq.cuh

Lines changed: 55 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -1956,25 +1956,14 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_dp4a(
19561956
template <int mmq_x, int mmq_y>
19571957
static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(
19581958
const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
1959-
constexpr int nwarps = get_mmq_nwarps_device(GGML_TYPE_Q6_K);
1959+
#if defined(NEW_MMA_AVAILABLE)
19601960

1961-
#if defined(AMD_MMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
1962-
#if defined(AMD_MMA_AVAILABLE)
1963-
typedef tile<32, 4, int> tile_A;
1964-
typedef tile<32, 4, int> tile_B;
1965-
typedef tile<32, 32, int> tile_C;
1966-
#else
19671961
typedef tile<16, 4, int> tile_A;
19681962
typedef tile< 8, 4, int> tile_B;
19691963
typedef tile<16, 8, int> tile_C;
1970-
#endif
19711964

19721965
constexpr int granularity = mmq_get_granularity_device(GGML_TYPE_Q6_K, mmq_x);
1973-
#if defined(AMD_MMA_AVAILABLE)
1974-
constexpr int rows_per_warp = granularity; // 32
1975-
#else
19761966
constexpr int rows_per_warp = 2 * granularity;
1977-
#endif
19781967
constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.
19791968

19801969
y += (threadIdx.y % ntx) * (tile_B::I*MMQ_TILE_Y_K);
@@ -1988,13 +1977,8 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(
19881977
const int i0 = (threadIdx.y / ntx) * (ntx*tile_A::I);
19891978

19901979
tile_A A[ntx][8];
1991-
#if defined(AMD_MMA_AVAILABLE)
1992-
int scA[ntx][tile_C::ne][8];
1993-
float dA[ntx][tile_C::ne];
1994-
#else
19951980
int scA[ntx][tile_C::ne/2][8];
19961981
float dA[ntx][tile_C::ne/2];
1997-
#endif
19981982

19991983
#pragma unroll
20001984
for (int n = 0; n < ntx; ++n) {
@@ -2011,13 +1995,9 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(
20111995
const int k0 = k00 + k01;
20121996

20131997
#pragma unroll
2014-
#if defined(AMD_MMA_AVAILABLE)
2015-
for (int l = 0; l < tile_C::ne; ++l) {
2016-
const int i = i0 + n*tile_C::I + tile_C::get_i(l);
2017-
#else
20181998
for (int l = 0; l < tile_C::ne/2; ++l) {
20191999
const int i = i0 + n*tile_C::I + tile_C::get_i(2*l);
2020-
#endif
2000+
20212001
const int sc_packed = x_sc[i*MMQ_MMA_TILE_X_K_Q6_K + k0/16];
20222002
const int8_t * sc = (const int8_t *) &sc_packed;
20232003

@@ -2029,13 +2009,9 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(
20292009
}
20302010

20312011
#pragma unroll
2032-
#if defined(AMD_MMA_AVAILABLE)
2033-
for (int l = 0; l < tile_C::ne; ++l) {
2034-
const int i = i0 + n*tile_C::I + tile_C::get_i(l);
2035-
#else
20362012
for (int l = 0; l < tile_C::ne/2; ++l) {
20372013
const int i = i0 + n*tile_C::I + tile_C::get_i(2*l);
2038-
#endif
2014+
20392015
dA[n][l] = x_df[i*MMQ_MMA_TILE_X_K_Q6_K];
20402016
}
20412017
}
@@ -2047,29 +2023,18 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(
20472023
#pragma unroll
20482024
for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 8) {
20492025
tile_B B[2];
2026+
float dB[tile_C::ne/2];
20502027

2051-
#if defined(AMD_MMA_AVAILABLE)
2052-
load_ldmatrix(B[0], y_qs + j0*MMQ_TILE_Y_K + 0 + k01, MMQ_TILE_Y_K);
2053-
load_ldmatrix(B[1], y_qs + j0*MMQ_TILE_Y_K + tile_B::J + k01, MMQ_TILE_Y_K);
2054-
#else
20552028
// Here load_generic is faster than load_ldmatrix.
20562029
load_generic(B[0], y_qs + j0*MMQ_TILE_Y_K + 0 + k01, MMQ_TILE_Y_K);
20572030
load_generic(B[1], y_qs + j0*MMQ_TILE_Y_K + tile_B::J + k01, MMQ_TILE_Y_K);
2058-
#endif
20592031

2060-
#if defined(AMD_MMA_AVAILABLE)
2061-
float dB;
2062-
const int j = j0 + tile_C::get_j(0);
2063-
dB = y_df[j*MMQ_TILE_Y_K + k01/QI8_1];
2064-
#else
2065-
float dB[tile_C::ne/2];
20662032
#pragma unroll
20672033
for (int l = 0; l < tile_C::ne/2; ++l) {
20682034
const int j = j0 + tile_C::get_j(l);
20692035

20702036
dB[l] = y_df[j*MMQ_TILE_Y_K + k01/QI8_1];
20712037
}
2072-
#endif
20732038

20742039
#pragma unroll
20752040
for (int n = 0; n < ntx; ++n) {
@@ -2079,11 +2044,7 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(
20792044

20802045
#pragma unroll
20812046
for (int l = 0; l < tile_C::ne; ++l) {
2082-
#if defined(AMD_MMA_AVAILABLE)
2083-
tmp[n][l] += (C[0].x[l]*scA[n][l][k01/4 + 0] + C[1].x[l]*scA[n][l][k01/4 + 1])*dB;
2084-
#else
20852047
tmp[n][l] += (C[0].x[l]*scA[n][l/2][k01/4 + 0] + C[1].x[l]*scA[n][l/2][k01/4 + 1])*dB[l%2];
2086-
#endif
20872048
}
20882049
}
20892050
}
@@ -2092,11 +2053,55 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(
20922053
for (int n = 0; n < ntx; ++n) {
20932054
#pragma unroll
20942055
for (int l = 0; l < tile_C::ne; ++l) {
2095-
#if defined(AMD_MMA_AVAILABLE)
2096-
sum[(j0/tile_C::J + n)*tile_C::ne + l] += tmp[n][l]*dA[n][l];
2097-
#else
20982056
sum[(j0/tile_C::J + n)*tile_C::ne + l] += tmp[n][l]*dA[n][l/2];
2099-
#endif
2057+
}
2058+
}
2059+
}
2060+
#elif defined(AMD_MMA_AVAILABLE)
2061+
typedef tile<32, 4, int> tile_A;
2062+
typedef tile<32, 4, int> tile_B;
2063+
typedef tile<32, 32, int> tile_C;
2064+
2065+
const int * x_qs = (const int *) x;
2066+
const float * x_df = (const float *) x_qs + MMQ_TILE_NE_K*2;
2067+
const int * x_sc = (const int *) x_df + MMQ_TILE_NE_K/QI6_K;
2068+
const int * y_qs = (const int *) y + 4;
2069+
const float * y_df = (const float *) y;
2070+
2071+
const int i0 = threadIdx.y * tile_A::I;
2072+
2073+
int scA[tile_C::ne][2];
2074+
float dA[tile_C::ne];
2075+
2076+
#pragma unroll
2077+
for (int l = 0; l < tile_C::ne; ++l) {
2078+
const int i = i0 + tile_C::get_i(l);
2079+
scA[l][0] = x_sc[i*MMQ_MMA_TILE_X_K_Q6_K + k00/16 + 0];
2080+
scA[l][1] = x_sc[i*MMQ_MMA_TILE_X_K_Q6_K + k00/16 + 1];
2081+
dA[l] = x_df[i*MMQ_MMA_TILE_X_K_Q6_K];
2082+
}
2083+
2084+
for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 4) {
2085+
const int k0 = k00 + k01;
2086+
2087+
tile_A A;
2088+
load_ldmatrix(A, x_qs + i0*MMQ_MMA_TILE_X_K_Q6_K + k0, MMQ_MMA_TILE_X_K_Q6_K);
2089+
2090+
#pragma unroll
2091+
for (int j0 = 0; j0 < mmq_x; j0 += tile_C::J) {
2092+
tile_B B;
2093+
load_ldmatrix(B, y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K);
2094+
2095+
float dB;
2096+
const int j = j0 + tile_C::get_j(0);
2097+
dB = y_df[j*MMQ_TILE_Y_K + k01/QI8_1];
2098+
2099+
tile_C C;
2100+
mma(C, A, B);
2101+
2102+
for (int l = 0; l < tile_C::ne; ++l) {
2103+
const int8_t * sc = (const int8_t *) scA[l];
2104+
sum[(j0/tile_C::J)*tile_C::ne + l] += C.x[l] * sc[k01/4] * dA[l] * dB;
21002105
}
21012106
}
21022107
}
@@ -2883,7 +2888,9 @@ static __device__ __forceinline__ void mul_mat_q_process_tile(
28832888

28842889
template <ggml_type type, int mmq_x, int warp_size, bool need_check>
28852890
#if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
2886-
#if defined(RDNA4) || defined(RDNA3) || defined(RDNA2) || defined(CDNA) || defined(GCN)
2891+
#if defined(AMD_MMA_AVAILABLE)
2892+
__launch_bounds__(warp_size*get_mmq_nwarps_device(type), 1)
2893+
#elif defined(RDNA4) || defined(RDNA3) || defined(RDNA2) || defined(CDNA2) || defined(CDNA1) || defined(GCN)
28872894
__launch_bounds__(warp_size*get_mmq_nwarps_device(type), 2)
28882895
#endif // defined(RDNA4) || defined(RDNA3) || defined(RDNA2) || defined(CDNA) || defined(GCN)
28892896
#else

Comments (0)