Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
158 changes: 156 additions & 2 deletions ggml/src/iqk/iqk_gemm_1bit.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2515,6 +2515,154 @@ void mul_mat_iq1_s_q8_K(int n, const void * vx, size_t bx, const DataInfo& info,
}
}
}

// Requantize one 256-element super-block (8 groups of 32 values) into the
// q8_k_r8 interleaved layout.
//   d0     : target inverse scale (callers pass 1/126 so the maximum
//            requantized magnitude stays within int8 range)
//   qx     : qx[ib32] holds the 32 int8 quants of group ib32 as two 16-lane halves
//   scales : scales[2*ib32+0/1] are the signed per-16-value scales
//   block  : 32 bytes of caller-provided scratch
//   q8_k   : destination; consecutive 4-byte words of one row are written
//            8 words apart (8 rows are interleaved — see the qb[8*l] stores)
// Returns the scale factor dnew the caller must fold into the row scale.
inline float convert_to_q8_k_r8(float d0, const int8x16x2_t * qx, const int8_t * scales, uint32_t * block, uint32_t * q8_k) {
    auto max_i16 = vdupq_n_u16(0);
    int16x8x4_t q[8];
    for (int ib32 = 0; ib32 < 8; ++ib32) {
        auto scale_l = vdup_n_s8(scales[2*ib32+0]);
        auto scale_h = vdup_n_s8(scales[2*ib32+1]);
        // Widening multiply to int16: quant * per-16 scale, four 8-lane chunks.
        q[ib32].val[0] = vmull_s8(scale_l, vget_low_s8 (qx[ib32].val[0]));
        q[ib32].val[1] = vmull_s8(scale_l, vget_high_s8(qx[ib32].val[0]));
        q[ib32].val[2] = vmull_s8(scale_h, vget_low_s8 (qx[ib32].val[1]));
        q[ib32].val[3] = vmull_s8(scale_h, vget_high_s8(qx[ib32].val[1]));
        // Track the running max |value| over the whole super-block.
        max_i16 = vmaxq_u16(max_i16, vmaxq_u16(vabsq_s16(q[ib32].val[0]), vabsq_s16(q[ib32].val[1])));
        max_i16 = vmaxq_u16(max_i16, vmaxq_u16(vabsq_s16(q[ib32].val[2]), vabsq_s16(q[ib32].val[3])));
    }
    uint16_t imax = vmaxvq_u16(max_i16);
    if (!imax) {
        // Everything is zero: clear this row's interleaved word positions and
        // report a zero scale.
        for (int ib32 = 0; ib32 < 8; ++ib32) for (int l = 0; l < 8; ++l) q8_k[64*ib32 + 8*l] = 0;
        return 0.f;
    }
    float dnew = float(imax) * d0;
    //auto max_u32 = vmaxq_u32(vmovl_u16(vget_low_u16(max_i16)), vmovl_u16(vget_high_u16(max_i16)));
    //auto max_f32 = vcvtq_f32_u32(max_u32);
    //auto dnew = vmaxvq_f32(max_f32) * d0;
    bool needs_scaling = true;
    if (dnew <= 1.f) {
        // Values already fit int8; store them unscaled and report unit scale.
        dnew = 1.f; needs_scaling = false;
    }
    auto scale = vdupq_n_f32(1/dnew);
    for (int ib32 = 0; ib32 < 8; ++ib32) {
        if (needs_scaling) {
            for (int l = 0; l < 4; ++l) {
                // Rescale in f32 with round-to-nearest so the result fits int8.
                auto i1 = vcvtnq_s32_f32(vmulq_f32(scale, vcvtq_f32_s32(vmovl_s16(vget_low_s16 (q[ib32].val[l])))));
                auto i2 = vcvtnq_s32_f32(vmulq_f32(scale, vcvtq_f32_s32(vmovl_s16(vget_high_s16(q[ib32].val[l])))));
                q[ib32].val[l] = vcombine_s16(vmovn_s32(i1), vmovn_s32(i2));
            }
        }
        // Narrow to int8 and stage this group's 32 bytes in the scratch block.
        for (int l = 0; l < 2; ++l) {
            auto s8 = vcombine_s8(vmovn_s16(q[ib32].val[2*l+0]), vmovn_s16(q[ib32].val[2*l+1]));
            vst1q_s8((int8_t *)block + 16*l, s8);
        }
        // Scatter: word l of this row lands 8 words apart (8 interleaved rows).
        auto qb = q8_k + 64*ib32;
        for (int l = 0; l < 8; ++l) {
            qb[8*l] = block[l];
        }
    }
    return dnew;
}

// Convert IQ1_S rows to block_q8_k_r8, 8 interleaved rows at a time.
//   n     : row length in elements (must be a multiple of QK_K)
//   vx/bx : source rows / source row stride in bytes
//   vy    : destination block_q8_k_r8 array
//   nrc_x : number of rows (must be a multiple of 8)
void iqk_convert_iq1_s_q8_k_r8(int n, const void * vx, size_t bx, void * vy, int nrc_x) {
    GGML_ASSERT(n%QK_K == 0);
    GGML_ASSERT(nrc_x%8 == 0);

    int nb = n/QK_K;

    const block_iq1_s * x8[8];   // the 8 source rows being interleaved

    block_q8_k_r8 * y = (block_q8_k_r8 *)vy;

    int8_t ls[16];               // per-16-value scales passed to convert_to_q8_k_r8

    uint32_t block[8];           // 32-byte scratch for convert_to_q8_k_r8

    int8x16x2_t qx[8];           // integer-dequantized values of one super-block

    for (int ix = 0; ix < nrc_x; ix += 8) {
        for (int k = 0; k < 8; ++k) x8[k] = (const block_iq1_s *)((const char *)vx + (ix + k)*bx);
        for (int i = 0; i < nb; ++i) {
            for (int k = 0; k < 8; ++k) {
                // 0.125f compensates for the values below being built at 8x
                // nominal size ((g+1)<<3 plus a +/-1 delta instead of +/-1/8).
                float d = 0.125f * GGML_FP16_TO_FP32(x8[k][i].d);
                auto qs = x8[k][i].qs;
                auto qh = x8[k][i].qh;
                int8x16x2_t value;
                for (int ib32 = 0; ib32 < 8; ++ib32) {
                    // Group scale lives in bits 12..14 of qh; stored as 2*s+1.
                    ls[2*ib32 + 0] = ls[2*ib32 + 1] = (2*((qh[ib32] >> 12) & 7) + 1);
                    // Grid lookup: each 8-bit qs index is extended with 3 high
                    // bits taken from qh (shifted into bits 8..10).
                    value.val[0] = vreinterpretq_s8_u64(uint64x2_t{iq1s_grid[qs[0] | ((qh[ib32] << 8) & 0x700)], iq1s_grid[qs[1] | ((qh[ib32] << 5) & 0x700)]});
                    value.val[1] = vreinterpretq_s8_u64(uint64x2_t{iq1s_grid[qs[2] | ((qh[ib32] << 2) & 0x700)], iq1s_grid[qs[3] | ((qh[ib32] >> 1) & 0x700)]});
                    // Map grid values g in {-1,0,1} to (g+1)<<3 = {0,8,16}.
                    value.val[0] = vshlq_n_s8(vaddq_s8(value.val[0], vdupq_n_s8(1)), 3);
                    value.val[1] = vshlq_n_s8(vaddq_s8(value.val[1], vdupq_n_s8(1)), 3);
                    // Bit 15 of qh selects the delta sign: -9 gives 8*g - 1,
                    // -7 gives 8*g + 1.
                    auto delta = vdupq_n_s8(qh[ib32] & 0x8000 ? -9 : -7);
                    qx[ib32].val[0] = vaddq_s8(value.val[0], delta);
                    qx[ib32].val[1] = vaddq_s8(value.val[1], delta);
                    qs += 4;
                }
                // d0 = 1/126 keeps the requantized values inside int8 range;
                // +k selects this row's slot in the 8-row interleaved output.
                float dnew = convert_to_q8_k_r8(1.f/126, qx, ls, block, (uint32_t *)y[i].qs + k);
                y[i].d[k] = GGML_FP32_TO_FP16(d*dnew);
            }
        }
        y += nb;
    }
}

// Convert IQ1_M rows to block_q8_k_r8, 8 interleaved rows at a time.
//   n     : row length in elements (must be a multiple of QK_K)
//   vx/bx : source rows / source row stride in bytes
//   vy    : destination block_q8_k_r8 array
//   nrc_x : number of rows (must be a multiple of 8)
void iqk_convert_iq1_m_q8_k_r8(int n, const void * vx, size_t bx, void * vy, int nrc_x) {
    GGML_ASSERT(n%QK_K == 0);
    GGML_ASSERT(nrc_x%8 == 0);

    int nb = n/QK_K;

    const block_iq1_m * x8[8];   // the 8 source rows being interleaved

    block_q8_k_r8 * y = (block_q8_k_r8 *)vy;

    int8_t ls[16];               // per-16-value scales passed to convert_to_q8_k_r8

    uint32_t block[8];           // 32-byte scratch for convert_to_q8_k_r8

    int8x16x2_t qx[8];           // integer-dequantized values of one super-block

    // Masks selecting the per-8-value delta bits: 0x08/0x80 test the low/high
    // nibble of qh[0] (low 16 bits of aux) and qh[1] (high 16 bits of aux).
    uint32x4x2_t mask = {uint32x4_t{0x00000008, 0x00000008, 0x00000080, 0x00000080}, uint32x4_t {0x00080000, 0x00080000, 0x00800000, 0x00800000}};

    for (int ix = 0; ix < nrc_x; ix += 8) {
        for (int k = 0; k < 8; ++k) x8[k] = (const block_iq1_m *)((const char *)vx + (ix + k)*bx);
        for (int i = 0; i < nb; ++i) {
            for (int k = 0; k < 8; ++k) {
                // Reassemble the fp16 super-block scale from the top nibbles
                // of the four 16-bit scale words.
                const uint16_t * sc = (const uint16_t *)x8[k][i].scales;
                iq1m_scale_t scale;
                scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000);
                // 0.125f compensates for the values below being built at 8x
                // nominal size ((g+1)<<3 and a +/-1 delta instead of +/-1/8).
                float d = 0.125f * GGML_FP16_TO_FP32(scale.f16);
                auto qs = x8[k][i].qs;
                auto qh = x8[k][i].qh;
                int8x16x2_t value;
                for (int ib32 = 0; ib32 < 8; ++ib32) {
                    // Two 3-bit scales per qh word, stored as 2*s+1.
                    ls[2*ib32 + 0] = (2*((sc[ib32/2] >> (6*(ib32%2)+0)) & 0x7) + 1);
                    ls[2*ib32 + 1] = (2*((sc[ib32/2] >> (6*(ib32%2)+3)) & 0x7) + 1);
                    //value = _mm256_set_epi64x(iq1s_grid[qs[3] | ((qh[1] << 4) & 0x700)], iq1s_grid[qs[2] | ((qh[1] << 8) & 0x700)],
                    //                          iq1s_grid[qs[1] | ((qh[0] << 4) & 0x700)], iq1s_grid[qs[0] | ((qh[0] << 8) & 0x700)]);
                    // Grid lookup: 3 high index bits come from the qh nibbles.
                    value.val[0] = vreinterpretq_s8_u64(uint64x2_t{iq1s_grid[qs[0] | ((qh[0] << 8) & 0x700)], iq1s_grid[qs[1] | ((qh[0] << 4) & 0x700)]});
                    value.val[1] = vreinterpretq_s8_u64(uint64x2_t{iq1s_grid[qs[2] | ((qh[1] << 8) & 0x700)], iq1s_grid[qs[3] | ((qh[1] << 4) & 0x700)]});
                    // Map grid values g in {-1,0,1} to (g+1)<<3 = {0,8,16}.
                    value.val[0] = vshlq_n_s8(vaddq_s8(value.val[0], vdupq_n_s8(1)), 3);
                    value.val[1] = vshlq_n_s8(vaddq_s8(value.val[1], vdupq_n_s8(1)), 3);

                    // Delta per 8 values: 9 where the mask bit is set, else 7
                    // (vceqq yields all-ones, so 7 + (2 & mask) = 7 or 9);
                    // subtracting gives 8*g - 1 or 8*g + 1.
                    auto aux = vdupq_n_u32(qh[0] | qh[1] << 16);
                    uint32x4x2_t delta_mask{ vceqq_u32(vandq_u32(aux, mask.val[0]), mask.val[0]), vceqq_u32(vandq_u32(aux, mask.val[1]), mask.val[1]) };
                    uint8x16x2_t delta{ vaddq_s8(vdupq_n_s8(7), vandq_s8(vdupq_n_s8(2), vreinterpretq_s8_u32(delta_mask.val[0]))),
                                        vaddq_s8(vdupq_n_s8(7), vandq_s8(vdupq_n_s8(2), vreinterpretq_s8_u32(delta_mask.val[1]))) };
                    qx[ib32].val[0] = vsubq_s8(value.val[0], delta.val[0]);
                    qx[ib32].val[1] = vsubq_s8(value.val[1], delta.val[1]);

                    qs += 4;
                    qh += 2;
                }
                // d0 = 1/126 keeps the requantized values inside int8 range;
                // +k selects this row's slot in the 8-row interleaved output.
                float dnew = convert_to_q8_k_r8(1.f/126, qx, ls, block, (uint32_t *)y[i].qs + k);
                y[i].d[k] = GGML_FP32_TO_FP16(d*dnew);
            }
        }
        y += nb;
    }
}

}

bool iqk_set_kernels_1bit(int ne00, int typeA, int typeB, std::array<mul_mat_t, IQK_MAX_NY>& funcs, mul_mat_t& func16) {
Expand Down Expand Up @@ -2573,8 +2721,14 @@ bool iqk_set_kernels_1bit(int ne00, int typeA, int typeB, std::array<mul_mat_t,

}

bool iqk_convert_1bit_q80_r8([[maybe_unused]] int type, [[maybe_unused]] int n, [[maybe_unused]] const void * vx, [[maybe_unused]] size_t bx, [[maybe_unused]] void * vy, [[maybe_unused]] int nrc_x) {
return false;
// Dispatch: repack a 1-bit quantized tensor (IQ1_S or IQ1_M) into the
// interleaved Q8_K_R8 layout. Returns false for any other type, or when the
// row length is not a multiple of QK_K or the row count not a multiple of 8.
bool iqk_convert_1bit_q80_r8(int type, int n, const void * vx, size_t bx, void * vy, int nrc_x) {
    if (n%QK_K != 0 || nrc_x%8 != 0) return false;
    const auto dtype = ggml_type(type);
    if (dtype == GGML_TYPE_IQ1_S) {
        iqk_convert_iq1_s_q8_k_r8(n, vx, bx, vy, nrc_x);
        return true;
    }
    if (dtype == GGML_TYPE_IQ1_M) {
        iqk_convert_iq1_m_q8_k_r8(n, vx, bx, vy, nrc_x);
        return true;
    }
    return false;
}

#endif
Expand Down
2 changes: 2 additions & 0 deletions ggml/src/iqk/iqk_mul_mat.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -276,6 +276,8 @@ struct MulMat {
case GGML_TYPE_Q4_K : return nrc_y >= 32 ? GGML_TYPE_Q8_1 : type;
case GGML_TYPE_Q5_K : return nrc_y >= 32 ? GGML_TYPE_Q8_1 : type;
case GGML_TYPE_Q6_K : return nrc_y >= 64 ? GGML_TYPE_Q8_0_R8 : type;
case GGML_TYPE_IQ1_S : return nrc_y >= 32 ? GGML_TYPE_Q8_K_R8 : type;
case GGML_TYPE_IQ1_M : return nrc_y >= 8 ? GGML_TYPE_Q8_K_R8 : type;
case GGML_TYPE_IQ2_XXS: return nrc_y >= 32 ? GGML_TYPE_Q8_K_R8 : type;
case GGML_TYPE_IQ2_XS : return nrc_y >= 32 ? GGML_TYPE_Q8_K_R8 : type;
case GGML_TYPE_IQ2_S : return nrc_y >= 32 ? GGML_TYPE_Q8_K_R8 : type;
Expand Down