Skip to content

Commit 4f2f1b7

Browse files
committed
Slightly better q8_0_q8_1 kernel and iqk_ks tile loading
And "minor" update for iq4_ks
1 parent 47d2663 commit 4f2f1b7

File tree

1 file changed

+114
-66
lines changed

1 file changed

+114
-66
lines changed

ggml/src/ggml-cuda/mmq.cuh

Lines changed: 114 additions & 66 deletions
Original file line number | Diff line number | Diff line change
@@ -862,61 +862,61 @@ static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma(
862862
const float * y_df = (const float *) y;
863863
const half2 * y_ds = (const half2 *) y;
864864

865-
tile_A A[ntx][WARP_SIZE/QI8_0];
866-
float dA[ntx][tile_C::ne/2][WARP_SIZE/QI8_0];
865+
tile_A A[ntx];
866+
float dA[ntx][tile_C::ne/2];
867867

868868
const int i0 = (threadIdx.y/ntx)*rows_per_warp;
869869

870-
#pragma unroll
871-
for (int n = 0; n < ntx; ++n) {
872-
#pragma unroll
873-
for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_0) {
874-
const int k0 = k00 + k01;
875-
876-
load_ldmatrix(A[n][k01/QI8_0], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q8_0 + k0, MMQ_MMA_TILE_X_K_Q8_0);
877-
}
878-
879-
#pragma unroll
870+
#pragma unroll
871+
for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_0) {
872+
const int k0 = k00 + k01;
873+
tile_B B;
874+
float dB[tile_C::ne/2];
875+
load_generic(B, y_qs + k01, MMQ_TILE_Y_K); // faster than load_ldmatrix
876+
#pragma unroll
880877
for (int l = 0; l < tile_C::ne/2; ++l) {
881-
const int i = i0 + n*tile_A::I + tile_C::get_i(2*l);
882-
883-
#pragma unroll
884-
for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_0) {
885-
const int k0 = k00 + k01;
886-
887-
dA[n][l][k01/QI8_0] = x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + k0/QI8_0];
878+
const int j = tile_C::get_j(l);
879+
if constexpr (ds_layout == MMQ_Q8_1_DS_LAYOUT_D4) {
880+
dB[l] = y_df[j*MMQ_TILE_Y_K + k01/QI8_1];
881+
} else {
882+
dB[l] = __low2float(y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]);
888883
}
889884
}
890-
}
891-
892-
#pragma unroll
893-
for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
894-
#pragma unroll
895-
for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_0) {
896-
tile_B B;
897-
float dB[tile_C::ne/2];
898-
885+
#pragma unroll
886+
for (int n = 0; n < ntx; ++n) {
887+
load_ldmatrix(A[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q8_0 + k0, MMQ_MMA_TILE_X_K_Q8_0);
888+
#pragma unroll
889+
for (int l = 0; l < tile_C::ne/2; ++l) {
890+
const int i = i0 + n*tile_A::I + tile_C::get_i(2*l);
891+
dA[n][l] = x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + k0/QI8_0];
892+
}
893+
tile_C C;
894+
mma(C, A[n], B);
895+
#pragma unroll
896+
for (int l = 0; l < tile_C::ne; ++l) {
897+
sum[(n)*tile_C::ne + l] += C.x[l]*dA[n][l/2]*dB[l%2];
898+
}
899+
}
900+
#pragma unroll
901+
for (int j0 = ntx*tile_C::J; j0 < mmq_x; j0 += ntx*tile_C::J) {
899902
load_generic(B, y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K); // faster than load_ldmatrix
900-
901-
#pragma unroll
903+
#pragma unroll
902904
for (int l = 0; l < tile_C::ne/2; ++l) {
903905
const int j = j0 + tile_C::get_j(l);
904-
905-
if (ds_layout == MMQ_Q8_1_DS_LAYOUT_D4) {
906-
dB[l] = y_df[j*MMQ_TILE_Y_K + k01/QI8_1];
906+
if constexpr (ds_layout == MMQ_Q8_1_DS_LAYOUT_D4) {
907+
dB[l] = y_df[j*MMQ_TILE_Y_K + k01/QI8_1];
907908
} else {
908909
dB[l] = __low2float(y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]);
909910
}
910911
}
912+
#pragma unroll
911913

912-
#pragma unroll
913914
for (int n = 0; n < ntx; ++n) {
914915
tile_C C;
915-
mma(C, A[n][k01/QI8_0], B);
916-
917-
#pragma unroll
916+
mma(C, A[n], B);
917+
#pragma unroll
918918
for (int l = 0; l < tile_C::ne; ++l) {
919-
sum[(j0/tile_C::J + n)*tile_C::ne + l] += C.x[l]*dA[n][l/2][k01/QI8_0]*dB[l%2];
919+
sum[(j0/tile_C::J + n)*tile_C::ne + l] += C.x[l]*dA[n][l/2]*dB[l%2];
920920
}
921921
}
922922
}
@@ -2784,6 +2784,64 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
27842784
}
27852785
}
27862786

2787+
//template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_iq4_ks(
2788+
// const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
2789+
//
2790+
//#ifdef NEW_MMA_AVAILABLE
2791+
// int * x_qs = (int *) x_tile;
2792+
// float * x_df = (float *) (x_qs + WARP_SIZE*2);
2793+
//#else
2794+
// constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ4_XS, mmq_y);
2795+
// int * x_qs = (int *) x_tile;
2796+
// float * x_df = (float *) (x_qs + txs.qs);
2797+
//#endif // NEW_MMA_AVAILABLE
2798+
//
2799+
// const int kbx = 0; // threadIdx.x / QI4_XS
2800+
// const int kqsx = threadIdx.x; // threadIdx.x % QI4_XS
2801+
//
2802+
//#pragma unroll
2803+
// for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
2804+
// int i = i0 + threadIdx.y;
2805+
//
2806+
// if (need_check) {
2807+
// i = min(i, i_max);
2808+
// }
2809+
//
2810+
// const block_iq4_ks * bxi = (const block_iq4_ks *)(x + i*stride + sizeof(float)) + kbx0 + kbx;
2811+
//
2812+
// auto values = iq4k_values + ((bxi->scales[kqsx/4] & 1) << 4);
2813+
// const int aux_q4 = get_int_b4(bxi->qs, kqsx);
2814+
// const int2 v = get_int_from_table_16(aux_q4, values);
2815+
// const int k0 = 8 * (threadIdx.x / 4) + threadIdx.x % 4;
2816+
//#ifdef NEW_MMA_AVAILABLE
2817+
// x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 0] = v.x;
2818+
// x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 4] = v.y;
2819+
//#else
2820+
// x_qs[i*(2*WARP_SIZE + 1) + k0 + 0] = v.x;
2821+
// x_qs[i*(2*WARP_SIZE + 1) + k0 + 4] = v.y;
2822+
//#endif // NEW_MMA_AVAILABLE
2823+
// }
2824+
//
2825+
//#pragma unroll
2826+
// for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 4) {
2827+
// int i = i0 + threadIdx.y * 4 + threadIdx.x / (WARP_SIZE/4);
2828+
//
2829+
// if (need_check) {
2830+
// i = min(i, i_max);
2831+
// }
2832+
//
2833+
// const float * dptr = (const float *)(x + i*stride);
2834+
// const block_iq4_ks * bxi = (const block_iq4_ks *)(dptr + 1) + kbx0;
2835+
// const int ls = (bxi->scales[threadIdx.x % 8] & 254) - 127;
2836+
//
2837+
//#ifdef NEW_MMA_AVAILABLE
2838+
// x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + threadIdx.x % 8] = dptr[0] * ls;
2839+
//#else
2840+
// x_df[i*(WARP_SIZE/4) + i/4 + threadIdx.x % 8] = dptr[0] * ls;
2841+
//#endif // NEW_MMA_AVAILABLE
2842+
// }
2843+
//}
2844+
27872845
template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_iq4_ks(
27882846
const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
27892847

@@ -2796,50 +2854,40 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
27962854
float * x_df = (float *) (x_qs + txs.qs);
27972855
#endif // NEW_MMA_AVAILABLE
27982856

2799-
const int kbx = 0; // threadIdx.x / QI4_XS
2800-
const int kqsx = threadIdx.x; // threadIdx.x % QI4_XS
2857+
const int kqsx = threadIdx.x / 4;
28012858

28022859
#pragma unroll
2803-
for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
2804-
int i = i0 + threadIdx.y;
2860+
for (int i0 = 0; i0 < mmq_y; i0 += 4*nwarps) {
2861+
int i = i0 + 4*threadIdx.y + threadIdx.x%4;
28052862

28062863
if (need_check) {
28072864
i = min(i, i_max);
28082865
}
28092866

2810-
const block_iq4_ks * bxi = (const block_iq4_ks *)(x + i*stride + sizeof(float)) + kbx0 + kbx;
2867+
const float * dptr = (const float *)(x + i*stride);
2868+
const block_iq4_ks * bxi = (const block_iq4_ks *)(dptr + 1) + kbx0;
2869+
const int ls = (bxi->scales[kqsx] & 254) - 127;
2870+
auto values = iq4k_values + ((bxi->scales[kqsx] & 1) << 4);
28112871

2812-
auto values = iq4k_values + ((bxi->scales[kqsx/4] & 1) << 4);
2813-
const int aux_q4 = get_int_b4(bxi->qs, kqsx);
2814-
const int2 v = get_int_from_table_16(aux_q4, values);
2815-
const int k0 = 8 * (threadIdx.x / 4) + threadIdx.x % 4;
2872+
#pragma unroll
2873+
for (int j = 0; j < 4; ++j) {
2874+
const int aux_q4 = get_int_b4(bxi->qs, 4*kqsx+j);
2875+
const int2 v = get_int_from_table_16(aux_q4, values);
28162876
#ifdef NEW_MMA_AVAILABLE
2817-
x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 0] = v.x;
2818-
x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 4] = v.y;
2877+
x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + j + 0] = v.x;
2878+
x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + j + 4] = v.y;
28192879
#else
2820-
x_qs[i*(2*WARP_SIZE + 1) + k0 + 0] = v.x;
2821-
x_qs[i*(2*WARP_SIZE + 1) + k0 + 4] = v.y;
2880+
x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + j + 0] = v.x;
2881+
x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + j + 4] = v.y;
28222882
#endif // NEW_MMA_AVAILABLE
2823-
}
2824-
2825-
#pragma unroll
2826-
for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 4) {
2827-
int i = i0 + threadIdx.y * 4 + threadIdx.x / (WARP_SIZE/4);
2828-
2829-
if (need_check) {
2830-
i = min(i, i_max);
28312883
}
2832-
2833-
const float * dptr = (const float *)(x + i*stride);
2834-
const block_iq4_ks * bxi = (const block_iq4_ks *)(dptr + 1) + kbx0;
2835-
const int ls = (bxi->scales[threadIdx.x % 8] & 254) - 127;
2836-
28372884
#ifdef NEW_MMA_AVAILABLE
2838-
x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + threadIdx.x % 8] = dptr[0] * ls;
2885+
x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx] = dptr[0] * ls;
28392886
#else
2840-
x_df[i*(WARP_SIZE/4) + i/4 + threadIdx.x % 8] = dptr[0] * ls;
2887+
x_df[i*(WARP_SIZE/4) + i/4 + kqsx] = dptr[0] * ls;
28412888
#endif // NEW_MMA_AVAILABLE
28422889
}
2890+
28432891
}
28442892

28452893
template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_iq4_kt(

0 commit comments

Comments
 (0)