Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
328 changes: 318 additions & 10 deletions ggml/src/iqk/iqk_gemm_iquants.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3348,20 +3348,328 @@ static void mul_mat_iq3_s_r4_q8_k(int n, const void * vx, size_t bx, const DataI
}
}

// Scale one super-block of 256 dequantized values (qx[0..7], 32 values each, with the
// per-16-element scales in `scales`) so they fit into int8_t, and scatter them into the
// 8-row-interleaved q8_k_r8 layout.
//   d0     - normalization constant (inverse of the maximum representable magnitude)
//   block  - 32-byte scratch buffer
//   q8_k   - this row's slot inside the interleaved output block (stride 8 uint32 words)
// Returns the super-block scale dnew such that stored_int8 * dnew reproduces the scaled
// values (1.0f when no rescaling was needed, 0.0f for an all-zero super-block).
inline float convert_to_q8_k_r8(float d0, const int8x16x2_t * qx, const int8_t * scales, uint32_t * block, uint32_t * q8_k) {
    auto max_i16 = vdupq_n_u16(0);
    int16x8x4_t q[8];
    for (int ib32 = 0; ib32 < 8; ++ib32) {
        auto scale_l = vdup_n_s8(scales[2*ib32+0]);
        auto scale_h = vdup_n_s8(scales[2*ib32+1]);
        q[ib32].val[0] = vmull_s8(scale_l, vget_low_s8 (qx[ib32].val[0]));
        q[ib32].val[1] = vmull_s8(scale_l, vget_high_s8(qx[ib32].val[0]));
        q[ib32].val[2] = vmull_s8(scale_h, vget_low_s8 (qx[ib32].val[1]));
        q[ib32].val[3] = vmull_s8(scale_h, vget_high_s8(qx[ib32].val[1]));
        // Track the largest |value| with an unsigned max. Reinterpret explicitly:
        // feeding int16x8_t into vmaxq_u16 is not valid under strict NEON typing
        // (it would compile only with -flax-vector-conversions).
        max_i16 = vmaxq_u16(max_i16, vmaxq_u16(vreinterpretq_u16_s16(vabsq_s16(q[ib32].val[0])),
                                               vreinterpretq_u16_s16(vabsq_s16(q[ib32].val[1]))));
        max_i16 = vmaxq_u16(max_i16, vmaxq_u16(vreinterpretq_u16_s16(vabsq_s16(q[ib32].val[2])),
                                               vreinterpretq_u16_s16(vabsq_s16(q[ib32].val[3]))));
    }
    uint16_t imax = vmaxvq_u16(max_i16);
    if (!imax) {
        // All-zero super-block: clear this row's interleaved slots and report scale 0.
        for (int ib32 = 0; ib32 < 8; ++ib32) for (int l = 0; l < 8; ++l) q8_k[64*ib32 + 8*l] = 0;
        return 0.f;
    }
    float dnew = float(imax) * d0;
    bool needs_scaling = true;
    if (dnew <= 1.f) {
        // Everything already fits into int8 range; store the values unmodified.
        dnew = 1.f; needs_scaling = false;
    }
    auto scale = vdupq_n_f32(1/dnew);
    for (int ib32 = 0; ib32 < 8; ++ib32) {
        if (needs_scaling) {
            for (int l = 0; l < 4; ++l) {
                // Widen to f32, scale, round-to-nearest back to i32, narrow to i16.
                auto i1 = vcvtnq_s32_f32(vmulq_f32(scale, vcvtq_f32_s32(vmovl_s16(vget_low_s16 (q[ib32].val[l])))));
                auto i2 = vcvtnq_s32_f32(vmulq_f32(scale, vcvtq_f32_s32(vmovl_s16(vget_high_s16(q[ib32].val[l])))));
                q[ib32].val[l] = vcombine_s16(vmovn_s32(i1), vmovn_s32(i2));
            }
        }
        // Narrow the 32 int16 values of this group to int8 into the scratch buffer...
        for (int l = 0; l < 2; ++l) {
            auto s8 = vcombine_s8(vmovn_s16(q[ib32].val[2*l+0]), vmovn_s16(q[ib32].val[2*l+1]));
            vst1q_s8((int8_t *)block + 16*l, s8);
        }
        // ...then scatter its 8 uint32 words into the 8-row-interleaved destination.
        auto qb = q8_k + 64*ib32;
        for (int l = 0; l < 8; ++l) {
            qb[8*l] = block[l];
        }
    }
    return dnew;
}

// Repack nrc_x rows (a multiple of 8) of iq2_xxs quants into the 8-row-interleaved
// q8_k_r8 layout so the wide q8_k_r8 GEMM kernels can be used.
//   n     - elements per row (multiple of QK_K)
//   vx/bx - source rows / row stride in bytes
//   vy    - destination array of block_q8_k_r8 (one block covers 8 rows)
void iqk_convert_iq2_xxs_q8_k_r8(int n, const void * vx, size_t bx, void * vy, int nrc_x) {
    GGML_ASSERT(n%QK_K == 0);
    GGML_ASSERT(nrc_x%8 == 0);

    int nb = n/QK_K;   // super-blocks per row

    const block_iq2_xxs * x8[8];   // the 8 source rows currently being interleaved

    block_q8_k_r8 * y = (block_q8_k_r8 *)vy;

    int8_t ls[16];        // per-16-element scales for one super-block
    uint32_t block[8];    // scratch for convert_to_q8_k_r8
    uint32_t aux32[2];    // aux32[0]: four 8-bit grid indices; aux32[1]: signs + 4-bit scale
    const uint8_t * aux8 = (const uint8_t *)aux32;

    int8x16x2_t xv[8];    // dequantized (unscaled) values, 32 per group of 32 elements

    for (int ix = 0; ix < nrc_x; ix += 8) {
        for (int k = 0; k < 8; ++k) x8[k] = (const block_iq2_xxs *)((const char *)vx + (ix + k)*bx);
        for (int i = 0; i < nb; ++i) {
            for (int k = 0; k < 8; ++k) {
                // note the fixed 0.125f factor folded into the row scale
                float d = 0.125f * GGML_FP16_TO_FP32(x8[k][i].d);
                for (int ib32 = 0; ib32 < 8; ++ib32) {
                    std::memcpy(aux32, x8[k][i].qs + 4*ib32, 2*sizeof(uint32_t));
                    // 4-bit scale s in the top bits of aux32[1]; effective scale is 2*s+1
                    ls[2*ib32+0] = ls[2*ib32+1] = (2*(aux32[1] >> 28) + 1);
                    // four 64-bit grid entries give the 32 unsigned magnitudes
                    xv[ib32].val[0] = vreinterpretq_s8_u64(uint64x2_t{iq2xxs_grid[aux8[0]], iq2xxs_grid[aux8[1]]});
                    xv[ib32].val[1] = vreinterpretq_s8_u64(uint64x2_t{iq2xxs_grid[aux8[2]], iq2xxs_grid[aux8[3]]});
                    // apply the sign patterns encoded in the low bits of aux32[1]
                    apply_signs_2((uint8x16_t *)xv[ib32].val, keven_signs, aux32[1]);
                }
                float dnew = convert_to_q8_k_r8(1.f/124, xv, ls, block, (uint32_t *)y[i].qs + k);
                y[i].d[k] = GGML_FP32_TO_FP16(d*dnew);
            }
        }
        y += nb;   // advance past the nb interleaved blocks written for these 8 rows
    }
}

// Repack nrc_x rows (a multiple of 8) of iq2_xs quants into the 8-row-interleaved
// q8_k_r8 layout so the wide q8_k_r8 GEMM kernels can be used.
void iqk_convert_iq2_xs_q8_k_r8(int n, const void * vx, size_t bx, void * vy, int nrc_x) {
    GGML_ASSERT(n%QK_K == 0);
    GGML_ASSERT(nrc_x%8 == 0);

    int nb = n/QK_K;   // super-blocks per row

    const block_iq2_xs * x8[8];   // the 8 source rows currently being interleaved

    block_q8_k_r8 * y = (block_q8_k_r8 *)vy;

    uint32_t block[8];    // scratch for convert_to_q8_k_r8

    int8x16x2_t xv[8];    // dequantized (unscaled) values for one super-block

    // lane-addressable view of the unpacked per-16-element scales
    union { int8x16_t vec; int8_t val[16]; } helper;

    for (int ix = 0; ix < nrc_x; ix += 8) {
        for (int k = 0; k < 8; ++k) x8[k] = (const block_iq2_xs *)((const char *)vx + (ix + k)*bx);
        for (int i = 0; i < nb; ++i) {
            for (int k = 0; k < 8; ++k) {
                // note the fixed 0.125f factor folded into the row scale
                float d = 0.125f * GGML_FP16_TO_FP32(x8[k][i].d);
                // unpack the 4-bit scales (low/high nibble per byte), interleave them
                // into block order, and map s -> 2*s+1
                auto aux = vld1_u8(x8[k][i].scales);
                auto scales_l = vand_u8(aux, vdup_n_u8(0xf));
                auto scales_h = vshr_n_u8(aux, 4);
                auto aux1 = vcombine_u8(vzip1_u8(scales_l, scales_h), vzip2_u8(scales_l, scales_h));
                helper.vec = vreinterpretq_s8_u8(vorrq_u8(vshlq_n_u8(aux1, 1), vdupq_n_u8(1)));
                for (int ib32 = 0; ib32 < 8; ++ib32) {
                    // each uint16 in qs: low 9 bits = grid index, high 7 bits = sign pattern
                    xv[ib32].val[0] = vreinterpretq_s8_u64(uint64x2_t{iq2xs_grid[x8[k][i].qs[4*ib32+0] & 511], iq2xs_grid[x8[k][i].qs[4*ib32+1] & 511]});
                    xv[ib32].val[1] = vreinterpretq_s8_u64(uint64x2_t{iq2xs_grid[x8[k][i].qs[4*ib32+2] & 511], iq2xs_grid[x8[k][i].qs[4*ib32+3] & 511]});
                    auto s1 = vreinterpretq_s8_u64(uint64x2_t{keven_signs[x8[k][i].qs[4*ib32+0] >> 9], keven_signs[x8[k][i].qs[4*ib32+1] >> 9]});
                    auto s2 = vreinterpretq_s8_u64(uint64x2_t{keven_signs[x8[k][i].qs[4*ib32+2] >> 9], keven_signs[x8[k][i].qs[4*ib32+3] >> 9]});
                    // signs are +/-1 vectors, so a plain multiply applies them
                    xv[ib32].val[0] = vmulq_s8(xv[ib32].val[0], s1);
                    xv[ib32].val[1] = vmulq_s8(xv[ib32].val[1], s2);
                }
                float dnew = convert_to_q8_k_r8(1.f/124, xv, helper.val, block, (uint32_t *)y[i].qs + k);
                y[i].d[k] = GGML_FP32_TO_FP16(d*dnew);
            }
        }
        y += nb;   // advance past the nb interleaved blocks written for these 8 rows
    }
}

// Repack nrc_x rows (a multiple of 8) of iq2_s quants into the 8-row-interleaved
// q8_k_r8 layout so the wide q8_k_r8 GEMM kernels can be used.
void iqk_convert_iq2_s_q8_k_r8(int n, const void * vx, size_t bx, void * vy, int nrc_x) {
    GGML_ASSERT(n%QK_K == 0);
    GGML_ASSERT(nrc_x%8 == 0);

    int nb = n/QK_K;   // super-blocks per row

    const block_iq2_s * x8[8];   // the 8 source rows currently being interleaved

    block_q8_k_r8 * y = (block_q8_k_r8 *)vy;

    uint32_t block[8];    // scratch for convert_to_q8_k_r8

    // lane-addressable view of the unpacked per-16-element scales
    union { int8x16_t vec; int8_t val[16]; } helper;
    int8x16x2_t xv[8];    // dequantized (unscaled) values for one super-block

    SignHelper sh;        // streams the explicit sign bytes stored after qs

    for (int ix = 0; ix < nrc_x; ix += 8) {
        for (int k = 0; k < 8; ++k) x8[k] = (const block_iq2_s *)((const char *)vx + (ix + k)*bx);
        for (int i = 0; i < nb; ++i) {
            for (int k = 0; k < 8; ++k) {
                // note the fixed 0.125f factor folded into the row scale
                float d = 0.125f * GGML_FP16_TO_FP32(x8[k][i].d);
                // unpack the 4-bit scales (low/high nibble per byte), interleave them
                // into block order, and map s -> 2*s+1
                auto aux = vld1_u8(x8[k][i].scales);
                auto scales_l = vand_u8(aux, vdup_n_u8(0xf));
                auto scales_h = vshr_n_u8(aux, 4);
                auto aux1 = vcombine_u8(vzip1_u8(scales_l, scales_h), vzip2_u8(scales_l, scales_h));
                helper.vec = vreinterpretq_s8_u8(vorrq_u8(vshlq_n_u8(aux1, 1), vdupq_n_u8(1)));
                // two halves of the super-block; each make4 fills 4 groups of 32 values
                // from the low index bytes (qs), high index bits (qh) and signs
                for (int j = 0; j < 2; ++j) {
                    auto qs = x8[k][i].qs + 16*j;
                    auto qh = x8[k][i].qh + 4*j;
                    auto signs16 = vld1q_u8(qs + QK_K/8);   // sign bytes live QK_K/8 past qs
                    sh.init();
                    DequantizerIQ2S::make4(sh, signs16, qs+0, qh+0, (uint8x16_t *)&xv[4*j+0]);
                    DequantizerIQ2S::make4(sh, signs16, qs+8, qh+2, (uint8x16_t *)&xv[4*j+2]);
                }
                float dnew = convert_to_q8_k_r8(1.f/124, xv, helper.val, block, (uint32_t *)y[i].qs + k);
                y[i].d[k] = GGML_FP32_TO_FP16(d*dnew);
            }
        }
        y += nb;   // advance past the nb interleaved blocks written for these 8 rows
    }
}

// Repack nrc_x rows (a multiple of 8) of iq3_xxs quants into the 8-row-interleaved
// q8_k_r8 layout so the wide q8_k_r8 GEMM kernels can be used.
void iqk_convert_iq3_xxs_q8_k_r8(int n, const void * vx, size_t bx, void * vy, int nrc_x) {
    GGML_ASSERT(n%QK_K == 0);
    GGML_ASSERT(nrc_x%8 == 0);

    int nb = n/QK_K;   // super-blocks per row

    const block_iq3_xxs * x8[8];   // the 8 source rows currently being interleaved

    block_q8_k_r8 * y = (block_q8_k_r8 *)vy;

    int8_t ls[16];        // per-16-element scales for one super-block
    int8x16x2_t xv[8];    // dequantized (unscaled) values for one super-block
    uint32_t block[8];    // scratch for convert_to_q8_k_r8
    uint32_t aux32;       // packed signs + 4-bit scale for one group of 32

    for (int ix = 0; ix < nrc_x; ix += 8) {
        for (int k = 0; k < 8; ++k) x8[k] = (const block_iq3_xxs *)((const char *)vx + (ix + k)*bx);
        for (int i = 0; i < nb; ++i) {
            for (int k = 0; k < 8; ++k) {
                // note the fixed 0.25f factor folded into the row scale
                float d = 0.25f * GGML_FP16_TO_FP32(x8[k][i].d);
                auto qs = x8[k][i].qs;         // 8-bit grid indices, 8 per group of 32
                auto sas = qs + QK_K/4;        // sign/scale words stored after the indices
                for (int ib32 = 0; ib32 < 8; ++ib32) {
                    std::memcpy(&aux32, sas + 4*ib32, sizeof(uint32_t));
                    // 4-bit scale s in the top bits; effective scale is 2*s+1
                    ls[2*ib32 + 0] = ls[2*ib32 + 1] = (2*(aux32 >> 28) + 1);
                    // eight 32-bit grid entries give the 32 unsigned magnitudes
                    xv[ib32].val[0] = vreinterpretq_s8_u32(uint32x4_t{iq3xxs_grid[qs[0]], iq3xxs_grid[qs[1]], iq3xxs_grid[qs[2]], iq3xxs_grid[qs[3]]});
                    xv[ib32].val[1] = vreinterpretq_s8_u32(uint32x4_t{iq3xxs_grid[qs[4]], iq3xxs_grid[qs[5]], iq3xxs_grid[qs[6]], iq3xxs_grid[qs[7]]});
                    // apply the sign patterns encoded in the low bits of aux32
                    apply_signs_2((uint8x16_t *)xv[ib32].val, keven_signs, aux32);
                    qs += 8;
                }
                float dnew = convert_to_q8_k_r8(1.f/124, xv, ls, block, (uint32_t *)y[i].qs + k);
                y[i].d[k] = GGML_FP32_TO_FP16(d*dnew);
            }
        }
        y += nb;   // advance past the nb interleaved blocks written for these 8 rows
    }
}

//struct DequantizerIQ3S final : public BaseDequantizer<block_iq3_s> {
// DequantizerIQ3S(const void * vx, size_t bx, int nrc) : BaseDequantizer(vx, bx, nrc) {}
//
// constexpr static int num_blocks() { return 8; }
// constexpr static bool should_scale_quants() { return false; }
//
// template <typename Q8>
// inline int32x4x2_t new_block(int i, const Q8& /*q8*/, float32x4_t * /*acc*/) {
// d = GGML_FP16_TO_FP32(x[i].d);
// uint32_t scales32[2];
// std::memcpy(scales32, x[i].scales, 4);
// scales32[1] = (((scales32[0] >> 4) & 0x0f0f0f0f) << 1) | 0x01010101;
// scales32[0] = ((scales32[0] & 0x0f0f0f0f) << 1) | 0x01010101;
// auto scales8 = vld1_u8((const uint8_t *)scales32); // 0, 2, 4, 6, 1, 3, 5, 7
// scales8 = vtbl1_u8(scales8, vreinterpret_u8_u64(vdup_n_u64(0x0703060205010400)));
// auto scales16 = vreinterpretq_s16_u16(vmovl_u8(scales8));
// int32x4x2_t scales;
// scales.val[0] = vmovl_s16(vget_low_s16(scales16));
// scales.val[1] = vmovl_s16(vget_high_s16(scales16));
// return scales;
// }
//
// static inline void make2(SignHelper& sh, const uint8x16_t& signs16, const uint16x8_t& idx_l, uint8_t qh,
// const int8x16_t& hshift, uint8x16_t * b) {
// auto vindex = vorrq_u16(idx_l, vandq_u16(vshlq_u16(vdupq_n_u16(qh), hshift), vdupq_n_u16(256)));
// const uint16_t * idx = (const uint16_t *)&vindex;
// b[0] = vreinterpretq_u8_u32(uint32x4_t{iq3s_grid[idx[0]], iq3s_grid[idx[1]], iq3s_grid[idx[2]], iq3s_grid[idx[3]]});
// b[1] = vreinterpretq_u8_u32(uint32x4_t{iq3s_grid[idx[4]], iq3s_grid[idx[5]], iq3s_grid[idx[6]], iq3s_grid[idx[7]]});
// sh.apply_signs_1(b+0, signs16);
// sh.apply_signs_1(b+1, signs16);
// }
// static inline void make4(SignHelper& sh, const uint8x16_t& signs16, const uint8_t * qs, const uint8_t * qh,
// const int8x16_t& hshift, uint8x16_t * b) {
// auto idx_l = vld1q_u8(qs);
// make2(sh, signs16, vmovl_u8(vget_low_u8 (idx_l)), qh[0], hshift, b+0);
// make2(sh, signs16, vmovl_u8(vget_high_u8(idx_l)), qh[1], hshift, b+2);
// }
//
// inline void prepare(int i, int j) {
//
// static const int16_t k_shift[8] = {8, 7, 6, 5, 4, 3, 2, 1};
// const auto hshift = vld1q_s16(k_shift);
//
// const auto * qs = x[i].qs + 32*j;
// const auto * qh = x[i].qh + 4*j;
// const auto signs16 = vld1q_u8(x[i].signs + 16*j);
//
// sh.init();
// make4(sh, signs16, qs+ 0, qh+0, hshift, bits.b1.val);
// make4(sh, signs16, qs+16, qh+2, hshift, bits.b2.val);
// }
//
// SimpleBits bits;
// SignHelper sh;
// uint32x4x2_t gas;
//
//};

// Repack nrc_x rows (a multiple of 8) of iq3_s quants into the 8-row-interleaved
// q8_k_r8 layout so the wide q8_k_r8 GEMM kernels can be used.
void iqk_convert_iq3_s_q8_k_r8(int n, const void * vx, size_t bx, void * vy, int nrc_x) {
    GGML_ASSERT(n%QK_K == 0);
    GGML_ASSERT(nrc_x%8 == 0);

    int nb = n/QK_K;   // super-blocks per row

    const block_iq3_s * x8[8];   // the 8 source rows currently being interleaved

    block_q8_k_r8 * y = (block_q8_k_r8 *)vy;

    int8_t ls[16];        // per-16-element scales for one super-block
    SignHelper sh;        // streams the explicit sign bytes

    uint32_t block[8];    // scratch for convert_to_q8_k_r8
    int8x16x2_t xv[8];    // dequantized (unscaled) values for one super-block

    // per-lane shifts used to move each qh bit into the 9th index bit position
    static const int16_t k_shift[8] = {8, 7, 6, 5, 4, 3, 2, 1};
    const auto hshift = vld1q_s16(k_shift);

    for (int ix = 0; ix < nrc_x; ix += 8) {
        for (int k = 0; k < 8; ++k) x8[k] = (const block_iq3_s *)((const char *)vx + (ix + k)*bx);
        for (int i = 0; i < nb; ++i) {
            for (int k = 0; k < 8; ++k) {
                float d = GGML_FP16_TO_FP32(x8[k][i].d);
                // two halves of the super-block; each make4 fills 4 groups of 32 values
                // from the low index bytes (qs), high index bits (qh) and signs
                for (int j = 0; j < 2; ++j) {
                    const auto * qs = x8[k][i].qs + 32*j;
                    const auto * qh = x8[k][i].qh + 4*j;
                    const auto signs16 = vld1q_u8(x8[k][i].signs + 16*j);
                    sh.init();
                    DequantizerIQ3S::make4(sh, signs16, qs+ 0, qh+0, hshift, (uint8x16_t *)&xv[4*j+0]);
                    DequantizerIQ3S::make4(sh, signs16, qs+16, qh+2, hshift, (uint8x16_t *)&xv[4*j+2]);
                }
                for (int ib32 = 0; ib32 < 8; ++ib32) {
                    // 4-bit scale nibble per pair of groups; effective scale is 2*s+1
                    ls[2*ib32 + 0] = ls[2*ib32 + 1] = (2*((x8[k][i].scales[ib32/2] >> 4*(ib32%2)) & 0xf) + 1);
                }
                // NOTE(review): uses 1/127 while the iq2_*/iq3_xxs converters use 1/124 —
                // presumably intentional (different value ranges); confirm.
                float dnew = convert_to_q8_k_r8(1.f/127, xv, ls, block, (uint32_t *)y[i].qs + k);
                y[i].d[k] = GGML_FP32_TO_FP16(d*dnew);
            }
        }
        y += nb;   // advance past the nb interleaved blocks written for these 8 rows
    }
}


}

// Dispatch: repack rows of an i-quant type into the q8_k_r8 interleaved format so the
// wide GEMM kernels can be used. Returns false when the type is not handled or the
// sizes are incompatible (caller falls back to the native kernels).
//   n     - elements per row (must be a multiple of QK_K)
//   vx/bx - source rows / row stride in bytes
//   vy    - destination buffer
//   nrc_x - number of rows (must be a multiple of 8)
bool iqk_convert_iquants_q80_r8([[maybe_unused]] int type, int n, [[maybe_unused]] const void * vx, [[maybe_unused]] size_t bx, [[maybe_unused]] void * vy, int nrc_x) {
    if (n%QK_K != 0 || nrc_x%8 != 0) return false;
    // Previously an unconditional `return false;` here made the switch below
    // unreachable dead code; removed so the converters are actually dispatched.
    switch (ggml_type(type)) {
        case GGML_TYPE_IQ2_XXS: iqk_convert_iq2_xxs_q8_k_r8(n, vx, bx, vy, nrc_x); break;
        case GGML_TYPE_IQ2_XS : iqk_convert_iq2_xs_q8_k_r8 (n, vx, bx, vy, nrc_x); break;
        case GGML_TYPE_IQ2_S  : iqk_convert_iq2_s_q8_k_r8  (n, vx, bx, vy, nrc_x); break;
        case GGML_TYPE_IQ3_XXS: iqk_convert_iq3_xxs_q8_k_r8(n, vx, bx, vy, nrc_x); break;
        case GGML_TYPE_IQ3_S  : iqk_convert_iq3_s_q8_k_r8  (n, vx, bx, vy, nrc_x); break;
        default: return false;
    }
    return true;
}

bool iqk_set_kernels_iquants(int ne00, int typeA, int typeB, std::array<mul_mat_t, IQK_MAX_NY>& kernels, mul_mat_t& func16) {
Expand Down
5 changes: 5 additions & 0 deletions ggml/src/iqk/iqk_mul_mat.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -271,6 +271,11 @@ struct MulMat {
}
#else
switch (type) {
case GGML_TYPE_IQ2_XXS: return nrc_y >= 32 ? GGML_TYPE_Q8_K_R8 : type;
case GGML_TYPE_IQ2_XS : return nrc_y >= 32 ? GGML_TYPE_Q8_K_R8 : type;
case GGML_TYPE_IQ2_S : return nrc_y >= 32 ? GGML_TYPE_Q8_K_R8 : type;
case GGML_TYPE_IQ3_XXS: return nrc_y >= 32 ? GGML_TYPE_Q8_K_R8 : type;
case GGML_TYPE_IQ3_S : return nrc_y >= 32 ? GGML_TYPE_Q8_K_R8 : type;
case GGML_TYPE_Q4_0 : return nrc_y >= 32 ? GGML_TYPE_Q8_0_R8 : type;
case GGML_TYPE_Q4_1 : return nrc_y >= 32 ? GGML_TYPE_Q8_1 : type;
case GGML_TYPE_Q5_0 : return nrc_y >= 32 ? GGML_TYPE_Q8_0_R8 : type;
Expand Down