Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions ggml/src/ggml-cuda/mmq.cu
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,9 @@ void ggml_cuda_op_mul_mat_q(
case GGML_TYPE_IQ3_KS:
mul_mat_q_case<GGML_TYPE_IQ3_KS>(ctx, args, stream);
break;
case GGML_TYPE_IQ4_KSS:
mul_mat_q_case<GGML_TYPE_IQ4_KSS>(ctx, args, stream);
break;
case GGML_TYPE_IQ4_KS:
mul_mat_q_case<GGML_TYPE_IQ4_KS>(ctx, args, stream);
break;
Expand Down Expand Up @@ -209,6 +212,7 @@ bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11) {
case GGML_TYPE_IQ4_NL:
case GGML_TYPE_IQ2_KL:
case GGML_TYPE_IQ3_KS:
case GGML_TYPE_IQ4_KSS:
case GGML_TYPE_IQ4_KS:
case GGML_TYPE_IQ4_KS_R4:
case GGML_TYPE_IQ5_KS:
Expand Down
4 changes: 4 additions & 0 deletions ggml/src/ggml-cuda/mmq.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@ static mmq_q8_1_ds_layout mmq_get_q8_1_ds_layout(const ggml_type type_x) {
case GGML_TYPE_IQ2_KL:
case GGML_TYPE_IQ3_KS:
case GGML_TYPE_IQ3_K_R4:
case GGML_TYPE_IQ4_KSS:
case GGML_TYPE_IQ4_KS:
case GGML_TYPE_IQ4_KS_R4:
case GGML_TYPE_IQ4_K:
Expand Down Expand Up @@ -205,6 +206,7 @@ static constexpr __host__ __device__ tile_x_sizes mmq_get_dp4a_tile_x_sizes(ggml
case GGML_TYPE_IQ4_NL : return MMQ_DP4A_TXS_Q8_0;
case GGML_TYPE_IQ2_KL : return MMQ_DP4A_TXS_Q8_0;
case GGML_TYPE_IQ3_KS : return MMQ_DP4A_TXS_Q8_0;
case GGML_TYPE_IQ4_KSS : return MMQ_DP4A_TXS_Q8_0;
case GGML_TYPE_IQ4_KS : return MMQ_DP4A_TXS_Q8_0;
case GGML_TYPE_IQ4_KS_R4 : return MMQ_DP4A_TXS_Q8_0;
case GGML_TYPE_IQ5_KS : return MMQ_DP4A_TXS_Q8_0;
Expand Down Expand Up @@ -263,6 +265,7 @@ static constexpr __host__ __device__ int mmq_get_mma_tile_x_k(ggml_type type) {
case GGML_TYPE_IQ4_NL : return MMQ_MMA_TILE_X_K_Q8_0;
case GGML_TYPE_IQ2_KL : return MMQ_MMA_TILE_X_K_Q8_0;
case GGML_TYPE_IQ3_KS : return MMQ_MMA_TILE_X_K_Q8_0;
case GGML_TYPE_IQ4_KSS : return MMQ_MMA_TILE_X_K_Q8_0;
case GGML_TYPE_IQ4_KS : return MMQ_MMA_TILE_X_K_Q8_0;
case GGML_TYPE_IQ4_KS_R4 : return MMQ_MMA_TILE_X_K_Q8_0;
case GGML_TYPE_IQ5_KS : return MMQ_MMA_TILE_X_K_Q8_0;
Expand Down Expand Up @@ -4164,6 +4167,7 @@ extern DECL_MMQ_CASE(GGML_TYPE_IQ4_NL);
extern DECL_MMQ_CASE(GGML_TYPE_IQ4_XS);
extern DECL_MMQ_CASE(GGML_TYPE_IQ2_KL);
extern DECL_MMQ_CASE(GGML_TYPE_IQ3_KS);
extern DECL_MMQ_CASE(GGML_TYPE_IQ4_KSS);
extern DECL_MMQ_CASE(GGML_TYPE_IQ4_KS);
extern DECL_MMQ_CASE(GGML_TYPE_IQ4_KS_R4);
extern DECL_MMQ_CASE(GGML_TYPE_IQ5_KS_R4);
Expand Down
68 changes: 68 additions & 0 deletions ggml/src/ggml-cuda/template-instances/mmq-instance-iq4_kss.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
#include "../mmq.cuh"

// Loads one MMQ x-tile of IQ4_KSS-quantized data, dequantizing the packed
// 4-bit indices to int8 via table lookup and writing the per-sub-block float
// scales alongside.  Each row of x starts with one float (row scale) followed
// by the block_iq4_kss blocks; kbx0 selects the starting block, i_max clamps
// the row index for partial tiles, stride is the row stride in bytes.
template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_iq4_kss(
const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {

#ifdef INT8_MMA_AVAILABLE
int * x_qs = (int *) x_tile;
float * x_df = (float *) (x_qs + WARP_SIZE*2);
#else
// NOTE(review): tile sizes are queried for IQ4_XS here; per mmq.cuh both
// IQ4_XS and IQ4_KSS map to MMQ_DP4A_TXS_Q8_0, so the result is identical,
// but passing GGML_TYPE_IQ4_KSS would be more self-consistent.
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ4_XS, mmq_y);
int * x_qs = (int *) x_tile;
float * x_df = (float *) (x_qs + txs.qs);
#endif // INT8_MMA_AVAILABLE

// 4 consecutive lanes cooperate on one row; kqsx picks the 32-value
// sub-block handled by this lane group (0..7 for a 32-lane warp).
const int kqsx = threadIdx.x / 4;

// Scratch holding the unpacked low/high nibbles; a8 aliases it byte-wise
// for the table lookups below.
uint32_t aux32[2];
auto a8 = (const uint8_t *)aux32;

#pragma unroll
for (int i0 = 0; i0 < mmq_y; i0 += 4*nwarps) {
int i = i0 + 4*threadIdx.y + threadIdx.x%4;

if (need_check) {
i = min(i, i_max);
}

// Row layout: one float scale, then the quantized blocks.
const float * dptr = (const float *)(x + i*stride);
const block_iq4_kss * bxi = (const block_iq4_kss *)(dptr + 1) + kbx0;
const uint32_t * q4 = bxi->qs + 4*kqsx;
// The low bit of each 16-bit half of the four 32-bit words carries one bit
// of the 8-bit sub-block scale; gather all 8 bits into ls.
uint32_t s32 = (q4[0] & 0x00010001) | ((q4[1] & 0x00010001) << 2) | ((q4[2] & 0x00010001) << 4) | ((q4[3] & 0x00010001) << 6);
uint8_t ls = (s32 | (s32 >> 15)) & 0xff;

// Bit 0 of the scale byte selects one of two 256-entry dequant tables.
auto values = iq4k_table + ((ls & 1) << 8);

#pragma unroll
for (int j = 0; j < 4; ++j) {
// Drop the embedded scale bits, then undo the val ^ (val >> 1) encoding
// to recover the packed 4-bit quant indices (low and high nibbles).
uint32_t val = q4[j] & 0xfffefffe;
val = val ^ (val >> 1);
aux32[0] = (val >> 0) & 0x0f0f0f0f;
aux32[1] = (val >> 4) & 0x0f0f0f0f;
#ifdef INT8_MMA_AVAILABLE
x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + j + 0] = int_from_table_x(a8+0, values);
x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + j + 4] = int_from_table_x(a8+4, values);
#else
x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + j + 0] = int_from_table_x(a8+0, values);
x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + j + 4] = int_from_table_x(a8+4, values);
#endif // INT8_MMA_AVAILABLE
}
// Sub-block scale = row scale * (scale byte with bit 0 cleared, biased by -127).
// All 4 lanes of the group write the same value; presumably benign redundancy — confirm.
#ifdef INT8_MMA_AVAILABLE
x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx] = dptr[0] * ((ls & 254) - 127);
#else
x_df[i*(WARP_SIZE/4) + i/4 + kqsx] = dptr[0] * ((ls & 254) - 127);
#endif // INT8_MMA_AVAILABLE
}

}


// Wires GGML_TYPE_IQ4_KSS into the generic MMQ kernel: the custom tile loader
// above plus the standard Q8_0 x Q8_1 dot kernels (the loader dequantizes the
// tile to a Q8_0-compatible layout, so the generic vec-dot paths apply).
template <int mmq_x, int mmq_y, int nwarps, bool need_check>
struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_IQ4_KSS> {
static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq4_kss<mmq_y, nwarps, need_check>;
static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, nwarps, MMQ_Q8_1_DS_LAYOUT_D4>;
static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
};

// Instantiates the mul_mat_q_case<GGML_TYPE_IQ4_KSS> entry declared extern in mmq.cuh.
DECL_MMQ_CASE(GGML_TYPE_IQ4_KSS);

107 changes: 107 additions & 0 deletions ggml/src/iqk/iqk_gemm_iqk_quants.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2706,6 +2706,60 @@ void iqk_convert_iq3_ks_q8_k_r8(int n, const void * vx, size_t bx, void * vy, in
}
}

// AVX2 path: converts nrc_x rows (a multiple of 8) of IQ4_KSS data to the
// 8-row-interleaved block_q8_k_r8 format.  n is the row length in elements
// (a multiple of QK_K), bx the input row stride in bytes.  Each input row
// starts with one float (row scale) followed by the block_iq4_kss blocks.
void iqk_convert_iq4_kss_q8_k_r8(int n, const void * vx, size_t bx, void * vy, int nrc_x) {
GGML_ASSERT(n%QK_K == 0);
GGML_ASSERT(nrc_x%8 == 0);

int nb = n/QK_K;

// Pointers to the 8 rows currently being interleaved.
const block_iq4_kss * x8[8];

block_q8_k_r8 * y = (block_q8_k_r8 *)vy;

// Two 16-entry dequant tables, each broadcast to both 128-bit lanes so
// _mm256_shuffle_epi8 performs the nibble->int8 lookup per lane.
__m256i values[2];
{
auto v1 = _mm_loadu_si128((const __m128i *)iq4k_values+0);
auto v2 = _mm_loadu_si128((const __m128i *)iq4k_values+1);
values[0] = MM256_SET_M128I(v1, v1);
values[1] = MM256_SET_M128I(v2, v2);
}

float drow[8];   // per-row scales
float dnew[8];   // per-row super-block scales returned by convert_to_q8_k_r8
int16_t ls[16];  // per-sub-block scales (duplicated per 16-value half)

__m256i xv[8];      // dequantized int8 values, one vector per 32-value sub-block
uint32_t block[8];  // scratch for convert_to_q8_k_r8

for (int ix = 0; ix < nrc_x; ix += 8) {
for (int k = 0; k < 8; ++k) {
const float * dptr = (const float *)((const char *)vx + (ix + k)*bx);
drow[k] = dptr[0];
x8[k] = (const block_iq4_kss *)(dptr + 1);
}
auto vd = _mm256_loadu_ps(drow);
for (int i = 0; i < nb; ++i) {
for (int k = 0; k < 8; ++k) {
for (int ib32 = 0; ib32 < 8; ++ib32) {
auto val = _mm_loadu_si128((const __m128i *)x8[k][i].qs+ib32);
// Strip the embedded scale bits and undo the v ^ (v >> 1) encoding
// to recover the packed 4-bit quant indices.
auto val_q = _mm_and_si128(val, _mm_set1_epi32(0xfffefffe));
val_q = _mm_xor_si128(val_q, _mm_srli_epi16(val_q, 1));
// Low nibbles in the low 128-bit lane, high nibbles in the high lane.
xv[ib32] = _mm256_and_si256(MM256_SET_M128I(_mm_srli_epi16(val_q, 4), val_q), _mm256_set1_epi8(0xf));
// Reassemble the 8-bit sub-block scale from the low bit of each
// 16-bit half of the four 32-bit words (scalar; same scheme as the
// CUDA loader).
auto q4 = x8[k][i].qs + 4*ib32;
uint32_t s32 = (q4[0] & 0x00010001) | ((q4[1] & 0x00010001) << 2) | ((q4[2] & 0x00010001) << 4) | ((q4[3] & 0x00010001) << 6);
uint8_t s8 = (s32 | (s32 >> 15)) & 0xff;
//auto val_s = _mm_madd_epi16(_mm_and_si128(val, _mm_set1_epi32(0x00010001)), _mm_set1_epi64x(0x0008000400020001));
ls[2*ib32+0] = ls[2*ib32+1] = ((s8 & 254) - 127);
// Bit 0 of the scale selects the dequant table.
xv[ib32] = _mm256_shuffle_epi8(values[s8 & 1], xv[ib32]);
}
dnew[k] = convert_to_q8_k_r8(k, 1.f/127, xv, ls, block, y[i].qs);
}
// Store the 8 combined scales (row scale * super-block scale) as fp16.
_mm_storeu_si128((__m128i *)y[i].d, _mm256_cvtps_ph(_mm256_mul_ps(vd, _mm256_loadu_ps(dnew)), _MM_ROUND_NEAREST));
}
y += nb;
}
}

void iqk_convert_iq4_ks_q8_k_r8(int n, const void * vx, size_t bx, void * vy, int nrc_x) {
GGML_ASSERT(n%QK_K == 0);
GGML_ASSERT(nrc_x%8 == 0);
Expand Down Expand Up @@ -3132,6 +3186,7 @@ bool iqk_convert_iqk_quants_q80_r8(int type, int n, const void * vx, size_t bx,
case GGML_TYPE_IQ2_KL : iqk_convert_iq2_kl_q8_k_r8(n, vx, bx, vy, nrc_x); break;
case GGML_TYPE_IQ3_KS : iqk_convert_iq3_ks_q8_k_r8(n, vx, bx, vy, nrc_x); break;
case GGML_TYPE_IQ3_K : iqk_convert_iq3_k_q8_k_r8 (n, vx, bx, vy, nrc_x); break;
case GGML_TYPE_IQ4_KSS: iqk_convert_iq4_kss_q8_k_r8(n, vx, bx, vy, nrc_x); break;
case GGML_TYPE_IQ4_KS : iqk_convert_iq4_ks_q8_k_r8(n, vx, bx, vy, nrc_x); break;
case GGML_TYPE_IQ4_K : iqk_convert_iq4_k_q8_k_r8 (n, vx, bx, vy, nrc_x); break;
case GGML_TYPE_IQ5_KS : iqk_convert_iq5_ks_q8_k_r8(n, vx, bx, vy, nrc_x); break;
Expand Down Expand Up @@ -4718,6 +4773,57 @@ void iqk_convert_iq2_kl_q8_k_r8(int n, const void * vx, size_t bx, void * vy, in
}
}

// NEON path: converts nrc_x rows (a multiple of 8) of IQ4_KSS data to the
// 8-row-interleaved block_q8_k_r8 format.  Mirrors the AVX2 variant of the
// same name defined earlier in this file — presumably the two live in
// mutually exclusive architecture #ifdef sections; confirm against the
// surrounding file.
void iqk_convert_iq4_kss_q8_k_r8(int n, const void * vx, size_t bx, void * vy, int nrc_x) {
GGML_ASSERT(n%QK_K == 0);
GGML_ASSERT(nrc_x%8 == 0);

int nb = n/QK_K;

// Pointers to the 8 rows currently being interleaved.
const block_iq4_kss * x8[8];

block_q8_k_r8 * y = (block_q8_k_r8 *)vy;

// Two 16-entry dequant tables for the vqtbl1q_s8 nibble->int8 lookup.
auto values = vld1q_s8_x2(iq4k_values);

float drow[8];  // per-row scales
float dnew[8];  // per-row super-block scales returned by convert_to_q8_k_r8
int8_t ls[16];  // per-sub-block scales (duplicated per 16-value half)

int8x16x2_t xv[8];  // dequantized int8 values, one pair per 32-value sub-block
uint32_t block[8];  // scratch for convert_to_q8_k_r8

for (int ix = 0; ix < nrc_x; ix += 8) {
for (int k = 0; k < 8; ++k) {
// Each input row starts with one float (row scale) followed by the blocks.
const float * dptr = (const float *)((const char *)vx + (ix + k)*bx);
drow[k] = dptr[0];
x8[k] = (const block_iq4_kss *)(dptr + 1);
}
auto vd = vld1q_f32_x2(drow);
for (int i = 0; i < nb; ++i) {
for (int k = 0; k < 8; ++k) {
for (int ib32 = 0; ib32 < 8; ++ib32) {
auto q4 = x8[k][i].qs + 4*ib32;
// Reassemble the 8-bit sub-block scale from the low bit of each
// 16-bit half of the four 32-bit words (same scheme as the CUDA loader).
uint32_t s32 = (q4[0] & 0x00010001) | ((q4[1] & 0x00010001) << 2) | ((q4[2] & 0x00010001) << 4) | ((q4[3] & 0x00010001) << 6);
uint8_t s8 = (s32 | (s32 >> 15)) & 0xff;
ls[2*ib32+0] = ls[2*ib32+1] = ((s8 & 254) - 127);
// Strip the embedded scale bits and undo the v ^ (v >> 1) encoding
// to recover the packed 4-bit quant indices.
auto val16 = vandq_u16(vld1q_u16((const uint16_t *)q4), vdupq_n_u16(0xfffe));
auto val8 = vreinterpretq_u8_u16(veorq_u16(val16, vshrq_n_u16(val16, 1)));
// Bit 0 of the scale selects the dequant table; look up low and high nibbles.
auto& block_values = values.val[s8 & 1];
xv[ib32].val[0] = vqtbl1q_s8(block_values, vandq_u8(val8, vdupq_n_u8(0xf)));
xv[ib32].val[1] = vqtbl1q_s8(block_values, vshrq_n_u8(val8, 4));
}
dnew[k] = convert_to_q8_k_r8(1.f/127, xv, ls, block, (uint32_t *)y[i].qs + k);
}
// Store the 8 combined scales (row scale * super-block scale) as fp16.
auto d = vld1q_f32_x2(dnew);
d.val[0] = vmulq_f32(d.val[0], vd.val[0]);
d.val[1] = vmulq_f32(d.val[1], vd.val[1]);
vst1_f16((float16_t *)y[i].d + 0, vcvt_f16_f32(d.val[0]));
vst1_f16((float16_t *)y[i].d + 4, vcvt_f16_f32(d.val[1]));
}
y += nb;
}
}

void iqk_convert_iq4_ks_q8_k_r8(int n, const void * vx, size_t bx, void * vy, int nrc_x) {
GGML_ASSERT(n%QK_K == 0);
GGML_ASSERT(nrc_x%8 == 0);
Expand Down Expand Up @@ -5163,6 +5269,7 @@ bool iqk_convert_iqk_quants_q80_r8(int type, int n, const void * vx, size_t bx,
case GGML_TYPE_IQ2_KL : iqk_convert_iq2_kl_q8_k_r8(n, vx, bx, vy, nrc_x); break;
case GGML_TYPE_IQ3_KS : iqk_convert_iq3_ks_q8_k_r8(n, vx, bx, vy, nrc_x); break;
case GGML_TYPE_IQ3_K : iqk_convert_iq3_k_q8_k_r8 (n, vx, bx, vy, nrc_x); break;
case GGML_TYPE_IQ4_KSS: iqk_convert_iq4_kss_q8_k_r8(n, vx, bx, vy, nrc_x); break;
case GGML_TYPE_IQ4_KS : iqk_convert_iq4_ks_q8_k_r8(n, vx, bx, vy, nrc_x); break;
case GGML_TYPE_IQ4_K : iqk_convert_iq4_k_q8_k_r8 (n, vx, bx, vy, nrc_x); break;
case GGML_TYPE_IQ5_KS : iqk_convert_iq5_ks_q8_k_r8(n, vx, bx, vy, nrc_x); break;
Expand Down
2 changes: 2 additions & 0 deletions ggml/src/iqk/iqk_mul_mat.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -255,6 +255,7 @@ struct MulMat {
case GGML_TYPE_IQ3_KS : return nrc_y >= 32 ? GGML_TYPE_Q8_K_R8 : type;
case GGML_TYPE_IQ3_K : return nrc_y >= 32 ? GGML_TYPE_Q8_K_R8 : type;
case GGML_TYPE_IQ4_KS : return nrc_y >= 32 ? GGML_TYPE_Q8_K_R8 : type;
case GGML_TYPE_IQ4_KSS: return nrc_y >= 32 ? GGML_TYPE_Q8_K_R8 : type;
case GGML_TYPE_IQ4_K : return nrc_y >= 32 ? GGML_TYPE_Q8_K_R8 : type;
case GGML_TYPE_IQ5_KS : return nrc_y >= 32 ? GGML_TYPE_Q8_K_R8 : type;
case GGML_TYPE_IQ5_K : return nrc_y >= 32 ? GGML_TYPE_Q8_K_R8 : type;
Expand Down Expand Up @@ -301,6 +302,7 @@ struct MulMat {
case GGML_TYPE_IQ2_KS : return nrc_y >= 32 ? GGML_TYPE_Q8_K_R8 : type;
case GGML_TYPE_IQ2_KL : return nrc_y >= 32 ? GGML_TYPE_Q8_K_R8 : type;
case GGML_TYPE_IQ3_KS : return nrc_y >= 32 ? GGML_TYPE_Q8_K_R8 : type;
case GGML_TYPE_IQ4_KSS: return nrc_y >= 32 ? GGML_TYPE_Q8_K_R8 : type;
case GGML_TYPE_IQ4_KS : return nrc_y >= 32 ? GGML_TYPE_Q8_K_R8 : type;
case GGML_TYPE_IQ5_KS : return nrc_y >= 32 ? GGML_TYPE_Q8_K_R8 : type;
case GGML_TYPE_IQ2_K : return nrc_y >= 32 ? GGML_TYPE_Q8_K_R8 : type;
Expand Down
28 changes: 10 additions & 18 deletions ggml/src/iqk/iqk_quantize.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4229,25 +4229,17 @@ uint16_t prune_iq4ks(uint16_t v, const int8_t * values, const float * x, const f
q4[j] = q;
auto pc = popcount(q);
float diff0 = dl*iq4k_values[q] - x[j];
if (q > 0) {
uint8_t qm = q - 1u;
int pcm = popcount(qm);
if (pcm == pc-1 || pcm == pc+1) {
float diff1 = dl*values[qm] - x[j];
int qmin = std::max(int(q)-2, 0);
int qmax = std::min(int(q)+2, 15);
for (int iq = qmin; iq <= qmax; ++iq) {
uint8_t qq = iq;
if (qq == q) continue;
int pci = popcount(qq);
if (std::abs(pci - pc)%2) {
float diff1 = dl*values[qq] - x[j];
float score = w[j]*(diff1*diff1 - diff0*diff0);
if (score < best_score) {
best_score = score; jbest = j; bestq = qm;
}
}
}
if (q < 15) {
uint8_t qp = q + 1u;
int pcp = popcount(qp);
if (pcp == pc-1 || pcp == pc+1) {
float diff1 = dl*values[qp] - x[j];
float score = w[j]*(diff1*diff1 - diff0*diff0);
if (score < best_score) {
best_score = score; jbest = j; bestq = qp;
best_score = score; jbest = j; bestq = qq;
}
}
}
Expand Down Expand Up @@ -4468,7 +4460,7 @@ static void quantize_row_iq4_kss_impl(int n_per_row, const float * x, char * cy,
}
}
}
if (sumq2 > 0) *dptr = sumqx/sumq2;
if (sumq2 > 0) *dptr = sumqx/sumq2 * 1.01f;
}

void prune_iq4ks_to_iq4kss(int n_per_row, const uint16_t * table, const char * cx, const float * x, char *cy,
Expand Down