Commit e8eeb34

Perf: Refactor Q4_K, reduce register pressure
1 parent ff60fa9 commit e8eeb34

File tree: 2 files changed, +50 -35 lines changed

ggml/src/ggml-cuda/mma.cuh

Lines changed: 5 additions & 1 deletion
@@ -221,7 +221,11 @@ namespace ggml_cuda_mma {
     template <typename T>
     static __device__ __forceinline__ void load_ldmatrix(
             tile<16, 8, T> & t, const T * __restrict__ xs0, const int stride) {
-#ifdef NEW_MMA_AVAILABLE
+#if defined(AMD_MMA_AVAILABLE)
+        int64_t* xi = (int64_t*) t.x;
+        const int64_t* xs = (int64_t*) ((const int*) xs0 + (threadIdx.x % t.I) * stride + 2 * (threadIdx.x / t.I));
+        xi[0] = xs[0];
+#elif defined(NEW_MMA_AVAILABLE)
         int * xi = (int * ) t.x;
         const int * xs = (const int *) xs0 + (threadIdx.x % t.I) * stride + (threadIdx.x / t.I) * (t.J / 2);
         asm volatile("ldmatrix.sync.aligned.m8n8.x4.b16 {%0, %1, %2, %3}, [%4];"
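
For reference, the AMD branch above replaces the ldmatrix PTX path with a plain vectorized copy: each lane selects row (threadIdx.x % t.I) and one 2-int chunk (2 * (threadIdx.x / t.I)) and moves 8 bytes with a single int64_t load into t.x. The sketch below is not ggml code; it only illustrates that addressing pattern, assuming a 16x8 int tile, a 64-lane wavefront launched as one 64-thread block, and hypothetical names (copy_tile_16x8, d_src, d_dst).

// Minimal sketch of the per-lane addressing used by the AMD branch above (not ggml code).
#include <cstdio>
#include <cuda_runtime.h>

__global__ void copy_tile_16x8(const int * __restrict__ src, int * __restrict__ dst, const int stride) {
    const int lane = threadIdx.x;            // 0..63, one 64-lane wavefront (assumption)
    const int row  = lane % 16;              // tile row, t.I == 16
    const int col  = 2 * (lane / 16);        // first column of this lane's 2-int chunk
    // One 8-byte load per lane, mirroring the int64_t load in load_ldmatrix():
    const int64_t v = *(const int64_t *) (src + row * stride + col);
    // Write back row-major so the host can check that the 64 lanes cover the whole tile:
    *(int64_t *) (dst + row * 8 + col) = v;
}

int main() {
    const int stride = 8;                    // row stride of the source tile, in ints
    int h_src[16 * 8], h_dst[16 * 8] = {0};
    for (int i = 0; i < 16 * 8; ++i) h_src[i] = i;

    int *d_src, *d_dst;
    cudaMalloc(&d_src, sizeof(h_src));
    cudaMalloc(&d_dst, sizeof(h_dst));
    cudaMemcpy(d_src, h_src, sizeof(h_src), cudaMemcpyHostToDevice);

    copy_tile_16x8<<<1, 64>>>(d_src, d_dst, stride);
    cudaMemcpy(h_dst, d_dst, sizeof(h_dst), cudaMemcpyDeviceToHost);

    int mismatches = 0;
    for (int i = 0; i < 16 * 8; ++i) mismatches += (h_dst[i] != h_src[i]);
    printf("mismatches: %d\n", mismatches);  // expect 0: 64 lanes x 2 ints cover all 128 ints

    cudaFree(d_src);
    cudaFree(d_dst);
    return 0;
}

Each lane ends up holding two consecutive ints of its row, which is consistent with the single xi[0] = xs[0] store into t.x in the diff above.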

ggml/src/ggml-cuda/mmq.cuh

Lines changed: 45 additions & 34 deletions
@@ -891,23 +891,56 @@ static __device__ __forceinline__ void vec_dot_q8_1_q8_1_dp4a(
 template <int mmq_x, int mmq_y>
 static __device__ __forceinline__ void vec_dot_q8_1_q8_1_mma(
         const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
-    constexpr int nwarps = get_mmq_nwarps_device(GGML_TYPE_Q8_0);
-    // Tile definitions
-    typedef tile<16, 8, int> tile_A;
 #if defined(AMD_MMA_AVAILABLE)
+    typedef tile<16, 8, int> tile_A;
     typedef tile<16, 8, int> tile_B;
     typedef tile<16, 16, int> tile_C;
+
+    const int * x_qs = (const int *) x;
+    const half2 * x_dm = (const half2 *) x_qs + 2*MMQ_TILE_NE_K;
+    const int * y_qs = (const int *) y + 4;
+    const half2 * y_dm = (const half2 *) y;
+
+    const int i0 = threadIdx.y * tile_A::I;
+
+    for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QI8_1) {
+        const int k0 = k00 + k01;
+
+        tile_A A;
+        load_ldmatrix(A, x_qs + i0*MMQ_MMA_TILE_X_K_Q8_1 + k0, MMQ_MMA_TILE_X_K_Q8_1);
+
+        float2 dmA[tile_C::ne];
+#pragma unroll
+        for (int l = 0; l < tile_C::ne; ++l) {
+            const int i = i0 + tile_C::get_i(l);
+            dmA[l] = __half22float2(x_dm[i*MMQ_MMA_TILE_X_K_Q8_1 + k0/QI8_1]);
+        }
+
+#pragma unroll
+        for (int j0 = 0; j0 < mmq_x; j0 += tile_C::J) {
+            tile_B B;
+            load_ldmatrix(B, y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K);
+
+            float2 dsB;
+            const int j = j0 + tile_C::get_j(0);
+            dsB = __half22float2(y_dm[j*MMQ_TILE_Y_K + k01/QI8_1]);
+
+            tile_C C;
+            mma(C, A, B);
+
+            for (int l = 0; l < tile_C::ne; ++l) {
+                sum[(j0/tile_C::J)*tile_C::ne + l] += dmA[l].x*dsB.x*C.x[l];
+                sum[(j0/tile_C::J)*tile_C::ne + l] += dmA[l].y*dsB.y;
+            }
+        }
+    }
 #else
+    typedef tile<16, 8, int> tile_A;
     typedef tile< 8, 8, int> tile_B;
     typedef tile<16, 8, int> tile_C;
-#endif
 
     constexpr int granularity = mmq_get_granularity_device(GGML_TYPE_Q8_0, mmq_x);
-#if defined(AMD_MMA_AVAILABLE)
-    constexpr int rows_per_warp = granularity; // 16
-#else
     constexpr int rows_per_warp = 2 * granularity;
-#endif
     constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.
 
     y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K);
@@ -918,11 +951,7 @@ static __device__ __forceinline__ void vec_dot_q8_1_q8_1_mma(
     const half2 * y_dm = (const half2 *) y;
 
     tile_A A[ntx][MMQ_TILE_NE_K/QI8_1];
-#if defined(AMD_MMA_AVAILABLE)
-    float2 dmA[ntx][tile_C::ne][MMQ_TILE_NE_K/QI8_1];
-#else
     float2 dmA[ntx][tile_C::ne/2][MMQ_TILE_NE_K/QI8_1];
-#endif
 
     const int i0 = (threadIdx.y/ntx)*rows_per_warp;
 
@@ -936,13 +965,8 @@ static __device__ __forceinline__ void vec_dot_q8_1_q8_1_mma(
         }
 
 #pragma unroll
-#if defined(AMD_MMA_AVAILABLE)
-        for (int l = 0; l < tile_C::ne; ++l) {
-            const int i = i0 + n*tile_A::I + tile_C::get_i(l);
-#else
         for (int l = 0; l < tile_C::ne/2; ++l) {
             const int i = i0 + n*tile_A::I + tile_C::get_i(2*l);
-#endif
 
 #pragma unroll
             for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QI8_1) {
@@ -958,25 +982,16 @@ static __device__ __forceinline__ void vec_dot_q8_1_q8_1_mma(
 #pragma unroll
         for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QI8_1) {
             tile_B B;
-#if defined(AMD_MMA_AVAILABLE)
-            load_ldmatrix(B, y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K); // faster than load_ldmatrix
-#else
-            load_generic(B, y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K); // faster than load_ldmatrix
-#endif
-
-#if defined(AMD_MMA_AVAILABLE)
-            float2 dsB;
-            const int j = j0 + tile_C::get_j(0);
-            dsB = __half22float2(y_dm[j*MMQ_TILE_Y_K + k01/QI8_1]);
-#else
             float2 dsB[tile_C::ne/2];
+
+            load_generic(B, y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K); // faster than load_ldmatrix
+
 #pragma unroll
             for (int l = 0; l < tile_C::ne/2; ++l) {
                 const int j = j0 + tile_C::get_j(l);
 
                 dsB[l] = __half22float2(y_dm[j*MMQ_TILE_Y_K + k01/QI8_1]);
             }
-#endif
 
 #pragma unroll
             for (int n = 0; n < ntx; ++n) {
@@ -985,17 +1000,13 @@ static __device__ __forceinline__ void vec_dot_q8_1_q8_1_mma(
 
 #pragma unroll
                 for (int l = 0; l < tile_C::ne; ++l) {
-#if defined(AMD_MMA_AVAILABLE)
-                    sum[(j0/tile_C::J + n)*tile_C::ne + l] += dmA[n][l][k01/QI8_1].x*dsB.x*C.x[l];
-                    sum[(j0/tile_C::J + n)*tile_C::ne + l] += dmA[n][l][k01/QI8_1].y*dsB.y;
-#else
                     sum[(j0/tile_C::J + n)*tile_C::ne + l] += dmA[n][l/2][k01/QI8_1].x*dsB[l%2].x*C.x[l];
                     sum[(j0/tile_C::J + n)*tile_C::ne + l] += dmA[n][l/2][k01/QI8_1].y*dsB[l%2].y;
-#endif
                 }
             }
         }
     }
+#endif // AMD_MMA_AVAILABLE
 }
 
 template <int mmq_x, int mmq_y>
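
Taken together, the mmq.cuh change is mostly a loop restructuring: the AMD branch now processes one k-slice (QI8_1 ints) at a time, loading a single tile_A and its dmA scales, using them against every j0 tile, then discarding them, whereas the previous AMD code (like the remaining CUDA #else path) preloaded tile_A A[ntx][MMQ_TILE_NE_K/QI8_1] plus the matching dmA array and kept them live across the whole j0 loop. Keeping fewer values live at once is what reduces register pressure. The toy sketch below shows the same restructuring in miniature; keep_all, fused, K_SLICES, and J_TILES are made-up names for illustration, not ggml code.

// Toy illustration of the restructuring (not ggml code): same arithmetic, shorter lifetimes.
#include <cstdio>

constexpr int K_SLICES = 4;
constexpr int J_TILES  = 2;

// Old structure: cache every k-slice of A up front, so all of a[] stays live across the j loop.
float keep_all(const float * A, const float * B) {
    float a[K_SLICES];
    for (int k = 0; k < K_SLICES; ++k) a[k] = A[k];
    float sum = 0.0f;
    for (int j = 0; j < J_TILES; ++j)
        for (int k = 0; k < K_SLICES; ++k)
            sum += a[k] * B[j * K_SLICES + k];
    return sum;
}

// New structure: load one slice, use it against every j tile, then let it die.
float fused(const float * A, const float * B) {
    float sum = 0.0f;
    for (int k = 0; k < K_SLICES; ++k) {
        const float a = A[k];
        for (int j = 0; j < J_TILES; ++j)
            sum += a * B[j * K_SLICES + k];
    }
    return sum;
}

int main() {
    const float A[K_SLICES] = {1.0f, 2.0f, 3.0f, 4.0f};
    float B[J_TILES * K_SLICES];
    for (int i = 0; i < J_TILES * K_SLICES; ++i) B[i] = 0.5f * i;
    printf("%f %f\n", keep_all(A, B), fused(A, B)); // identical results
    return 0;
}

Both functions return the same value; only how long the cached A values stay live differs, which is the trade the refactor makes at warp scale.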
