Skip to content

Commit a161900

Browse files
committed
Perf: Throughput Increase 4k->6.9k t/s
1 parent e8eeb34 commit a161900

File tree

1 file changed

+10
-30
lines changed

1 file changed

+10
-30
lines changed

ggml/src/ggml-cuda/mmq.cuh

Lines changed: 10 additions & 30 deletions
Original file line number | Diff line number | Diff line change
@@ -909,13 +909,6 @@ static __device__ __forceinline__ void vec_dot_q8_1_q8_1_mma(
909909
tile_A A;
910910
load_ldmatrix(A, x_qs + i0*MMQ_MMA_TILE_X_K_Q8_1 + k0, MMQ_MMA_TILE_X_K_Q8_1);
911911

912-
float2 dmA[tile_C::ne];
913-
#pragma unroll
914-
for (int l = 0; l < tile_C::ne; ++l) {
915-
const int i = i0 + tile_C::get_i(l);
916-
dmA[l] = __half22float2(x_dm[i*MMQ_MMA_TILE_X_K_Q8_1 + k0/QI8_1]);
917-
}
918-
919912
#pragma unroll
920913
for (int j0 = 0; j0 < mmq_x; j0 += tile_C::J) {
921914
tile_B B;
@@ -929,8 +922,9 @@ static __device__ __forceinline__ void vec_dot_q8_1_q8_1_mma(
929922
mma(C, A, B);
930923

931924
for (int l = 0; l < tile_C::ne; ++l) {
932-
sum[(j0/tile_C::J)*tile_C::ne + l] += dmA[l].x*dsB.x*C.x[l];
933-
sum[(j0/tile_C::J)*tile_C::ne + l] += dmA[l].y*dsB.y;
925+
float2 dmA = __half22float2(x_dm[(i0 + tile_C::get_i(l))*MMQ_MMA_TILE_X_K_Q8_1 + k0/QI8_1]);
926+
sum[(j0/tile_C::J)*tile_C::ne + l] += dmA.x*dsB.x*C.x[l];
927+
sum[(j0/tile_C::J)*tile_C::ne + l] += dmA.y*dsB.y;
934928
}
935929
}
936930
}
@@ -2081,24 +2075,12 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(
20812075

20822076
const int i0 = threadIdx.y * tile_A::I;
20832077

2084-
int scA[tile_C::ne][2];
2085-
float dA[tile_C::ne];
2086-
2087-
#pragma unroll
2088-
for (int l = 0; l < tile_C::ne; ++l) {
2089-
const int i = i0 + tile_C::get_i(l);
2090-
scA[l][0] = x_sc[i*MMQ_MMA_TILE_X_K_Q6_K + k00/16 + 0];
2091-
scA[l][1] = x_sc[i*MMQ_MMA_TILE_X_K_Q6_K + k00/16 + 1];
2092-
dA[l] = x_df[i*MMQ_MMA_TILE_X_K_Q6_K];
2093-
}
2094-
20952078
for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 4) {
20962079
const int k0 = k00 + k01;
20972080

20982081
tile_A A;
20992082
load_ldmatrix(A, x_qs + i0*MMQ_MMA_TILE_X_K_Q6_K + k0, MMQ_MMA_TILE_X_K_Q6_K);
21002083

2101-
#pragma unroll
21022084
for (int j0 = 0; j0 < mmq_x; j0 += tile_C::J) {
21032085
tile_B B;
21042086
load_ldmatrix(B, y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K);
@@ -2111,8 +2093,8 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(
21112093
mma(C, A, B);
21122094

21132095
for (int l = 0; l < tile_C::ne; ++l) {
2114-
const int8_t * sc = (const int8_t *) scA[l];
2115-
sum[(j0/tile_C::J)*tile_C::ne + l] += C.x[l] * sc[k01/4] * dA[l] * dB;
2096+
const int8_t * sc = (const int8_t *) (x_sc + (i0 + tile_C::get_i(l))*MMQ_MMA_TILE_X_K_Q6_K + k00/16);
2097+
sum[(j0/tile_C::J)*tile_C::ne + l] += C.x[l] * sc[k01/4] * x_df[(i0 + tile_C::get_i(l))*MMQ_MMA_TILE_X_K_Q6_K] * dB;
21162098
}
21172099
}
21182100
}
@@ -2858,9 +2840,9 @@ static __device__ __forceinline__ void mul_mat_q_process_tile(
28582840
const int * by0 = y + ncols_y*(kb0*(qk*sizeof(block_q8_1_mmq) / (4*QK8_1*sizeof(int))) + 0*sizeof(block_q8_1_mmq)/sizeof(int));
28592841
#pragma unroll
28602842
for (int l0 = 0; l0 < mmq_x*MMQ_TILE_Y_K; l0 += nwarps*warp_size) {
2861-
int l = l0 + threadIdx.y*warp_size + threadIdx.x;
2843+
int l = (l0 + threadIdx.y*warp_size + threadIdx.x) % (mmq_x*MMQ_TILE_Y_K);
28622844

2863-
if (l < mmq_x*MMQ_TILE_Y_K) tile_y[l] = by0[l];
2845+
tile_y[l] = by0[l];
28642846
}
28652847
}
28662848

@@ -2874,9 +2856,9 @@ static __device__ __forceinline__ void mul_mat_q_process_tile(
28742856
const int * by0 = y + ncols_y*(kb0*(qk*sizeof(block_q8_1_mmq) / (4*QK8_1*sizeof(int))) + 1*sizeof(block_q8_1_mmq)/sizeof(int));
28752857
#pragma unroll
28762858
for (int l0 = 0; l0 < mmq_x*MMQ_TILE_Y_K; l0 += nwarps*warp_size) {
2877-
int l = l0 + threadIdx.y*warp_size + threadIdx.x;
2859+
int l = (l0 + threadIdx.y*warp_size + threadIdx.x) % (mmq_x*MMQ_TILE_Y_K);
28782860

2879-
if (l < mmq_x*MMQ_TILE_Y_K) tile_y[l] = by0[l];
2861+
tile_y[l] = by0[l];
28802862
}
28812863
}
28822864

@@ -2899,9 +2881,7 @@ static __device__ __forceinline__ void mul_mat_q_process_tile(
28992881

29002882
template <ggml_type type, int mmq_x, int warp_size, bool need_check>
29012883
#if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
2902-
#if defined(AMD_MMA_AVAILABLE)
2903-
__launch_bounds__(warp_size*get_mmq_nwarps_device(type), 1)
2904-
#elif defined(RDNA4) || defined(RDNA3) || defined(RDNA2) || defined(CDNA2) || defined(CDNA1) || defined(GCN)
2884+
#if defined(RDNA4) || defined(RDNA3) || defined(RDNA2) || defined(CDNA3) || defined(CDNA2) || defined(CDNA1) || defined(GCN)
29052885
__launch_bounds__(warp_size*get_mmq_nwarps_device(type), 2)
29062886
#endif // defined(RDNA4) || defined(RDNA3) || defined(RDNA2) || defined(CDNA) || defined(GCN)
29072887
#else

0 commit comments

Comments
 (0)