Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 13 additions & 16 deletions ggml/src/ggml-metal.metal
Original file line number Diff line number Diff line change
Expand Up @@ -6598,31 +6598,31 @@ void kernel_mul_mv_iq2_k_f32_impl(

// Value generator used by the *_kt kernels: a 32-bit seed is multiplied by powers of
// `ka`, each product is masked with `kmask` (low 6 bits of every byte), and the four
// masked bytes are summed and shifted by -126 to form one output value.
// NOTE(review): several statements below appear twice in slightly different forms
// (`ka`, the `aux` initialisers, the `aux32[1]` assignment). This looks like both sides
// of a change captured together - confirm which variant is the intended one.
struct Trellis3 {
// Keeps the low 6 bits of each of the four bytes in a 32-bit word.
constexpr constant static uint32_t kmask = 0x3f3f3f3f;
constexpr constant static uint32_t ka = 89226354;
constexpr constant static uint32_t kb = 64248484;
constexpr constant static uint32_t ka = 0xCBAC1FED;
// ka1..ka3 are ka^2..ka^4 (mod 2^32); kb1..kb3 are the matching additive terms obtained
// by composing the step x -> ka*x + kb with itself.
constexpr constant static uint32_t ka1 = ka*ka;
constexpr constant static uint32_t kb1 = kb*ka+kb;
constexpr constant static uint32_t ka2 = ka1*ka;
constexpr constant static uint32_t kb2 = kb1*ka+kb;
constexpr constant static uint32_t ka3 = ka2*ka;
constexpr constant static uint32_t kb3 = kb2*ka+kb;
// Produces 4 values from `val`: element i is -126 plus the sum of the four masked bytes
// of the i-th generated 32-bit word.
static inline char4 gen4(uint32_t val) {
thread uint32_t aux[4] = {(ka*val + kb) & kmask, (ka1*val + kb1) & kmask, (ka2*val + kb2) & kmask, (ka3*val + kb3) & kmask};
thread uint32_t aux[4] = {(ka*val) & kmask, (ka1*val) & kmask, (ka2*val) & kmask, (ka3*val) & kmask};
// Byte-wise view of the four masked words.
thread const int8_t * a8 = (thread const int8_t *)aux;
char4 result;
for (int i = 0; i < 4; ++i) result[i] = -126 + a8[4*i+0] + a8[4*i+1] + a8[4*i+2] + a8[4*i+3];
return result;
}
// Produces 8 values from `val`, written to v1[0..3] and v2[0..3]. For each i, v1[i] is
// derived from the masked i-th word and v2[i] from a further ka3 step applied to that
// (unmasked) word.
template <typename T4>
static inline void gen8(uint32_t val, thread T4& v1, thread T4& v2) {
thread uint32_t aux[4] = {ka*val + kb, ka1*val + kb1, ka2*val + kb2, ka3*val + kb3};
thread uint32_t aux[4] = {ka*val, ka1*val, ka2*val, ka3*val};
uint32_t aux32[2];
// Byte-wise view of aux32: bytes 0..3 feed v1[i], bytes 4..7 feed v2[i].
thread const int8_t * a8 = (thread const int8_t *)aux32;
//thread const char4 * a8 = (thread const char4 *)aux32;
for (int i = 0; i < 4; ++i) {
aux32[0] = aux[i] & kmask;
aux32[1] = (ka3*aux[i] + kb3) & kmask;
aux32[1] = (ka3*aux[i]) & kmask;
v1[i] = -126 + a8[0] + a8[1] + a8[2] + a8[3];
v2[i] = -126 + a8[4] + a8[5] + a8[6] + a8[7];
// Much slower:
//v1[i] = -126 + a8[0][0] + a8[0][1] + a8[0][2] + a8[0][3];
//v2[i] = -126 + a8[1][0] + a8[1][1] + a8[1][2] + a8[1][3];
}
}
};
Expand Down Expand Up @@ -6837,7 +6837,7 @@ void kernel_mul_mv_iq3_kt_f32_impl(
float drow[N_DST];
for (int row = 0; row < N_DST; ++row) {
device const float * dptr = (device const float *)(cx + row*row_size);
drow[row] = dptr[0] * 31.75f * 1.01f;
drow[row] = dptr[0] * 1.01f;
}

device const block_iq3_kt * x = (device const block_iq3_kt *)(cx + sizeof(float));
Expand All @@ -6854,15 +6854,15 @@ void kernel_mul_mv_iq3_kt_f32_impl(
const float ls = drow[row] * ((sc[(it/2)%4] >> 4*(it/8)) & 0xf);
const uint8_t mask = 1 << (it/2);

Trellis::gen8(q2[2*it+0]+4096, v[0], v[1]);
Trellis3::gen8(q2[2*it+0]+4096, v[0], v[1]);
for (int j = 0; j < 8; ++j) {
u32[j] &= 0x7fffffff;
u32[j] |= qh[j+0] & mask ? 0x80000000 : 0;
}

auto sum = v[0]*y4[0] + v[1]*y4[1];

Trellis::gen8(q2[2*it+1]+4096, v[0], v[1]);
Trellis3::gen8(q2[2*it+1]+4096, v[0], v[1]);
for (int j = 0; j < 8; ++j) {
u32[j] &= 0x7fffffff;
u32[j] |= qh[j+8] & mask ? 0x80000000 : 0;
Expand Down Expand Up @@ -8593,17 +8593,14 @@ template <typename type4x4>
void dequantize_iq3_kt(device const block_iq3_kt * x, short il, thread type4x4 & reg) {
// il is 0...15 for QK_K = 256
int ib32 = il/2;
half scale = (half)((x->scales[ib32%4] >> 4*(ib32/4)) & 0xf) * 31.75h * 1.01h;
half scale = (half)((x->scales[ib32%4] >> 4*(ib32/4)) & 0xf) * 1.01h;
device const uint16_t * q2 = (device const uint16_t *)x->ql + 4*ib32 + 2*(il%2);
device const uint8_t * qh = x->qh + 16*(il%2);
const uint8_t mask = 1 << ib32;

half4 v1, v2;
for (int i = 0; i < 2; ++i) {
Trellis::gen8(q2[i]+4096, v1, v2);
//v1 *= scale; v2 *= scale;
//for (int j = 0; j < 4; ++j) reg[2*i+0][j] = qh[8*i+0+j] & mask ? -abs(v1[j]) : abs(v1[j]);
//for (int j = 0; j < 4; ++j) reg[2*i+1][j] = qh[8*i+4+j] & mask ? -abs(v2[j]) : abs(v2[j]);
Trellis3::gen8(q2[i]+4096, v1, v2);
v1 = abs(v1)*scale; v2 = abs(v2)*scale;
for (int j = 0; j < 4; ++j) reg[2*i+0][j] = qh[8*i+0+j] & mask ? -v1[j] : v1[j];
for (int j = 0; j < 4; ++j) reg[2*i+1][j] = qh[8*i+4+j] & mask ? -v2[j] : v2[j];
Expand Down
249 changes: 235 additions & 14 deletions ggml/src/iqk/iqk_gemm_legacy_quants.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2782,29 +2782,247 @@ void mul_mat_q8_0_r8_q8_0(int n, const void * vx, size_t bx, const DataInfo& inf
}
}

// One 32-quant block from each of 8 rows, stored together.
typedef struct {
// The converter iqk_convert_qX_1_q8_1_r8 stores row k's `d` in d[k] and row k's `m`
// in d[k+8], for k = 0..7.
ggml_half d[16];
// 8 rows x 32 int8 quants. The converter writes them as 32-bit words: word l of row k
// goes to word index 8*l + k, and word l+4 of row k to word index 8*l + k + 32.
int8_t qs[256];
} block_q8_1_r8;

template <int nrc_y>
void mul_mat_q8_1_r8_q8_1(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
GGML_ASSERT(nrc_x%8 == 0);
Q8<nrc_y, block_q8_1_x4> q8(info);
int nb = n / QK8_0;
float32x4_t acc[2*nrc_y] = {};
int8x16_t qx[16];
float d8[8*nrc_y];
for (int ix = 0; ix < nrc_x; ix += 8) {
const block_q8_1_r8 * iq8 = (const block_q8_1_r8 *)((const char *)vx + ix*bx);
for (int ib4 = 0; ib4 < nb/4; ++ib4) {
for (int iy = 0; iy < nrc_y; ++iy) {
vst1q_f32(d8+8*iy+0, vcvt_f32_f16(vld1_f16((const float16_t *)q8.y[iy][ib4].d+0)));
vst1q_f32(d8+8*iy+4, vcvt_f32_f16(vld1_f16((const float16_t *)q8.y[iy][ib4].d+4)));
}
for (int k = 0; k < 4; ++k) {
auto scales16 = vld1q_f16((const float16_t *)iq8[4*ib4+k].d);
auto scales1 = vcvt_f32_f16(vget_low_f16 (scales16));
auto scales2 = vcvt_f32_f16(vget_high_f16(scales16));
auto m16 = vld1q_f16((const float16_t *)iq8[4*ib4+k].d+8);
auto m1 = vcvt_f32_f16(vget_low_f16 (m16));
auto m2 = vcvt_f32_f16(vget_high_f16(m16));
for (int j = 0; j < 16; ++j) qx[j] = vld1q_s8(iq8[4*ib4+k].qs + 16*j);
int32x4_t sumi1, sumi2;
for (int iy = 0; iy < nrc_y; ++iy) {
qx_0_q8_0_dot(qx, q8.y[iy][ib4].qs+32*k, sumi1, sumi2);
auto dy = vdupq_n_f32(d8[8*iy+k]);
acc[2*iy+0] = vfmaq_f32(acc[2*iy+0], vmulq_f32(scales1, dy), vcvtq_f32_s32(sumi1));
acc[2*iy+1] = vfmaq_f32(acc[2*iy+1], vmulq_f32(scales2, dy), vcvtq_f32_s32(sumi2));
auto my = vdupq_n_f32(d8[8*iy+k+4]);
acc[2*iy+0] = vfmaq_f32(acc[2*iy+0], m1, my);
acc[2*iy+1] = vfmaq_f32(acc[2*iy+1], m2, my);
}
}
}
for (int ib = 4*(nb/4); ib < nb; ++ib) {
auto scales16 = vld1q_f16((const float16_t *)iq8[ib].d);
auto scales1 = vcvt_f32_f16(vget_low_f16 (scales16));
auto scales2 = vcvt_f32_f16(vget_high_f16(scales16));
auto m16 = vld1q_f16((const float16_t *)iq8[ib].d+8);
auto m1 = vcvt_f32_f16(vget_low_f16 (m16));
auto m2 = vcvt_f32_f16(vget_high_f16(m16));
for (int j = 0; j < 16; ++j) qx[j] = vld1q_s8(iq8[ib].qs + 16*j);
int32x4_t sumi1, sumi2;
for (int iy = 0; iy < nrc_y; ++iy) {
auto qy = (const block_q8_1 *)q8.y[iy];
qx_0_q8_0_dot(qx, qy[ib].qs, sumi1, sumi2);
auto dy = vdupq_n_f32(GGML_FP16_TO_FP32(qy[ib].d));
acc[2*iy+0] = vfmaq_f32(acc[2*iy+0], vmulq_f32(scales1, dy), vcvtq_f32_s32(sumi1));
acc[2*iy+1] = vfmaq_f32(acc[2*iy+1], vmulq_f32(scales2, dy), vcvtq_f32_s32(sumi2));
auto my = vdupq_n_f32(GGML_FP16_TO_FP32(qy[ib].s));
acc[2*iy+0] = vfmaq_f32(acc[2*iy+0], m1, my);
acc[2*iy+1] = vfmaq_f32(acc[2*iy+1], m2, my);
}
}
for (int iy = 0; iy < nrc_y; ++iy) {
info.store(ix+0, iy, acc[2*iy+0]);
info.store(ix+4, iy, acc[2*iy+1]);
acc[2*iy] = acc[2*iy+1] = vdupq_n_f32(0.f);
}
}
}

bool iqk_convert_legacy_quants_q8_r8([[maybe_unused]] int type, [[maybe_unused]] int n, [[maybe_unused]] const void * vx, [[maybe_unused]] size_t bx, [[maybe_unused]] void * vy, [[maybe_unused]] int nrc_x) {
return false;
//switch (type) {
// case GGML_TYPE_Q4_0 : iqk_convert_qX_q80_r8<block_q4_0, Q4_0_Dequantizer>(n, vx, bx, vy, nrc_x); break;
// case GGML_TYPE_Q4_1 : iqk_convert_qX_1_q8_1_r8<block_q4_1, Q4_1_Dequantizer>(n, vx, bx, vy, nrc_x); break;
// case GGML_TYPE_Q5_0 : iqk_convert_qX_q80_r8<block_q5_0, Q5_0_Dequantizer>(n, vx, bx, vy, nrc_x); break;
// case GGML_TYPE_Q5_1 : iqk_convert_qX_1_q8_1_r8<block_q5_1, Q5_1_Dequantizer<block_q5_1>>(n, vx, bx, vy, nrc_x); break;
// case GGML_TYPE_Q6_0 : iqk_convert_qX_q80_r8<block_q6_0, Q6_0_Dequantizer>(n, vx, bx, vy, nrc_x); break;
// case GGML_TYPE_IQ4_NL: iqk_convert_qX_q80_r8<block_iq4_nl, IQ4_NL0_Dequantizer>(n, vx, bx, vy, nrc_x); break;
// case GGML_TYPE_Q8_0 : iqk_convert_q80_q80_r8(n, vx, bx, vy, nrc_x); break;
// default: return false;
//}
//return true;
// Unpacks the 32 4-bit quants of a Q4_0 block into signed bytes.
// Returns val[0] = (low nibbles - 8) and val[1] = (high nibbles - 8) of the 16 packed bytes.
struct DeqQ40 {
    const int8x16_t  m8 = vdupq_n_s8(-8);
    // `ml` is an unsigned vector, so it is built with the unsigned splat; initialising it
    // from vdupq_n_s8 relies on implicit signed->unsigned vector conversion, which is
    // rejected when lax vector conversions are disabled.
    const uint8x16_t ml = vdupq_n_u8(0xf);
    inline int8x16x2_t dequant(const block_q4_0& x) const {
        auto bits = vld1q_u8(x.qs);
        return { vaddq_s8(vreinterpretq_s8_u8(vandq_u8(bits, ml)), m8), vaddq_s8(vreinterpretq_s8_u8(vshrq_n_u8(bits, 4)), m8) };
    }
};

// Unpacks the 32 4-bit quants of a Q4_1 block into bytes without any offset.
// Returns val[0] = low nibbles and val[1] = high nibbles of the 16 packed bytes.
struct DeqQ41 {
    // Unsigned splat to match the uint8x16_t member type (vdupq_n_s8 would need an
    // implicit signed->unsigned vector conversion, which strict compilers reject).
    const uint8x16_t ml = vdupq_n_u8(0xf);
    inline int8x16x2_t dequant(const block_q4_1& x) const {
        auto bits = vld1q_u8(x.qs);
        return { vreinterpretq_s8_u8(vandq_u8(bits, ml)), vreinterpretq_s8_u8(vshrq_n_u8(bits, 4)) };
    }
};

// Unpacks an IQ4_NL block: every 4-bit quant is an index into the 16-entry table loaded
// from iq4k_values. Returns val[0] = table[low nibbles], val[1] = table[high nibbles].
struct DeqIQ4NL {
    const int8x16_t  mt = load_values();
    // Unsigned splat to match the uint8x16_t member type (vdupq_n_s8 would need an
    // implicit signed->unsigned vector conversion, which strict compilers reject).
    const uint8x16_t ml = vdupq_n_u8(0xf);
    inline int8x16x2_t dequant(const block_iq4_nl& x) const {
        auto bits = vld1q_u8(x.qs);
        return { vqtbl1q_s8(mt, vandq_u8(bits, ml)), vqtbl1q_s8(mt, vshrq_n_u8(bits, 4)) };
    }
    // Loads the 16 lookup values used by dequant().
    static inline int8x16_t load_values() { return vld1q_s8(iq4k_values); }
};

// Unpacks a Q5_0 block into two vectors of 16 signed bytes each.
// The 4-bit parts come from Q4LegacyBits::prepare1; the bytes produced by
// HighBit5Legacy::to_negated_bytes are masked with 0xf0 and OR-ed on top.
struct DeqQ50 {

    Q4LegacyBits   bits;
    HighBit5Legacy hbits;
    const uint8x16_t mh = vdupq_n_u8(0xf0);

    inline int8x16x2_t dequant(const block_q5_0& x) const {
        int8x16x2_t out;
        bits.prepare1(x.qs, out.val);
        const auto high = x.qh;
        // Half 0 uses the high-bit bytes at qh+0, half 1 those at qh+2.
        for (int half = 0; half < 2; ++half) {
            const auto low4 = vreinterpretq_u8_s8(out.val[half]);
            const auto top  = vandq_u8(mh, hbits.to_negated_bytes(high + 2*half));
            out.val[half] = vreinterpretq_s8_u8(vorrq_u8(low4, top));
        }
        return out;
    }
};

// Unpacks a Q5_1 block into two vectors of 16 bytes each.
// The 4-bit parts come from Q4LegacyBits::prepare1; the bytes produced by
// HighBit5Legacy::to_bytes are masked with 0x10 and OR-ed in as the fifth bit.
struct DeqQ51 {

    Q4LegacyBits   bits;
    HighBit5Legacy hbits;
    const uint8x16_t mh = vdupq_n_u8(0x10);

    inline int8x16x2_t dequant(const block_q5_1& x) const {
        int8x16x2_t out;
        bits.prepare1(x.qs, out.val);
        const auto high = x.qh;
        // Half 0 uses the high-bit bytes at qh+0, half 1 those at qh+2.
        for (int half = 0; half < 2; ++half) {
            const auto low4 = vreinterpretq_u8_s8(out.val[half]);
            const auto top  = vandq_u8(mh, hbits.to_bytes(high + 2*half));
            out.val[half] = vreinterpretq_s8_u8(vorrq_u8(low4, top));
        }
        return out;
    }
};

// Unpacks a Q6_0 block into two vectors of 16 signed bytes each.
// The low 4 bits come from Q4LegacyBits::prepare1, bits 4..5 are taken from x.qh, and
// 32 is subtracted from every resulting value.
struct DeqQ60 {

    inline int8x16x2_t dequant(const block_q6_0& x) const {
        int8x16x2_t r;
        bits.prepare1(x.qs, r.val);
        auto qh8 = vld1_u8(x.qh);
        // Lower 8 lanes carry qh << 4, upper 8 lanes carry qh unchanged.
        auto qh  = vcombine_u8(vshl_n_u8(qh8, 4), qh8);
        // Explicit reinterprets between the signed and unsigned views (as DeqQ50/DeqQ51 do)
        // so the code does not depend on implicit lax vector conversions.
        r.val[0] = vaddq_s8(vreinterpretq_s8_u8(vorrq_u8(vreinterpretq_u8_s8(r.val[0]), vandq_u8(qh, hmask))), m32);
        r.val[1] = vaddq_s8(vreinterpretq_s8_u8(vorrq_u8(vreinterpretq_u8_s8(r.val[1]), vandq_u8(vshrq_n_u8(qh, 2), hmask))), m32);
        return r;
    }

    Q4LegacyBits bits;
    const int8x16_t  m32   = vdupq_n_s8(-32);
    const uint8x16_t hmask = vdupq_n_u8(0x30);
};

// Q8_0 quants are already signed bytes: dequant() just loads the 32 of them
// as two 16-byte vectors.
struct DeqQ80 {
    inline int8x16x2_t dequant(const block_q8_0& x) const {
        const int8x16x2_t quants = vld1q_s8_x2(x.qs);
        return quants;
    }
};

// Repacks nrc_x rows of `Block` quants into block_q8_0_r8, 8 rows at a time.
//
//   n      values per row; must be a multiple of QK4_0
//   vx/bx  base pointer and byte stride of the source rows
//   vy     destination, written as consecutive block_q8_0_r8
//   nrc_x  number of rows; must be a multiple of 8
//
// For each block index the per-row scale `d` is copied unchanged and the 32 dequantized
// bytes of every row are interleaved in 4-byte words: word l of row r lands at word
// index 8*l + r, word l+4 at 8*l + r + 32.
template <typename Block, typename Dequantizer>
void iqk_convert_qX_q80_r8(int n, const void * vx, size_t bx, void * vy, int nrc_x) {
    GGML_ASSERT(n%QK4_0 == 0);
    GGML_ASSERT(nrc_x%8 == 0);

    const int nblock = n/QK8_0;

    auto dst = (block_q8_0_r8 *)vy;
    const auto src_base = (const char *)vx;

    const Block * rows[8];
    uint32_t words[8];      // one dequantized block (32 bytes) viewed as 8 words
    Dequantizer deq;

    for (int ix = 0; ix < nrc_x; ix += 8, dst += nblock) {

        for (int r = 0; r < 8; ++r) rows[r] = (const Block *)(src_base + (ix + r)*bx);

        for (int ib = 0; ib < nblock; ++ib) {
            auto out = (uint32_t *)dst[ib].qs;
            for (int r = 0; r < 8; ++r) {
                const Block& src = rows[r][ib];
                dst[ib].d[r] = src.d;
                vst1q_s8_x2((int8_t *)words, deq.dequant(src));
                for (int l = 0; l < 4; ++l) {
                    out[8*l + r +  0] = words[l + 0];
                    out[8*l + r + 32] = words[l + 4];
                }
            }
        }
    }
}

// Repacks nrc_x rows of `Block` quants (types carrying both `d` and `m`) into
// block_q8_1_r8, 8 rows at a time.
//
//   n      values per row; must be a multiple of QK4_0
//   vx/bx  base pointer and byte stride of the source rows
//   vy     destination, written as consecutive block_q8_1_r8
//   nrc_x  number of rows; must be a multiple of 8
//
// For each block index row r's `d` goes to d[r] and its `m` to d[r+8]; the 32
// dequantized bytes of every row are interleaved in 4-byte words: word l of row r lands
// at word index 8*l + r, word l+4 at 8*l + r + 32.
template <typename Block, typename Dequantizer>
void iqk_convert_qX_1_q8_1_r8(int n, const void * vx, size_t bx, void * vy, int nrc_x) {
    GGML_ASSERT(n%QK4_0 == 0);
    GGML_ASSERT(nrc_x%8 == 0);

    const int nblock = n/QK8_0;

    auto dst = (block_q8_1_r8 *)vy;
    const auto src_base = (const char *)vx;

    const Block * rows[8];
    uint32_t words[8];      // one dequantized block (32 bytes) viewed as 8 words
    Dequantizer deq;

    for (int ix = 0; ix < nrc_x; ix += 8, dst += nblock) {

        for (int r = 0; r < 8; ++r) rows[r] = (const Block *)(src_base + (ix + r)*bx);

        for (int ib = 0; ib < nblock; ++ib) {
            auto out = (uint32_t *)dst[ib].qs;
            for (int r = 0; r < 8; ++r) {
                const Block& src = rows[r][ib];
                dst[ib].d[r+0] = src.d;
                dst[ib].d[r+8] = src.m;
                vst1q_s8_x2((int8_t *)words, deq.dequant(src));
                for (int l = 0; l < 4; ++l) {
                    out[8*l + r +  0] = words[l + 0];
                    out[8*l + r + 32] = words[l + 4];
                }
            }
        }
    }
}

}

// Dispatches the row-interleaved repacking for the legacy quant types.
// Q4_1 and Q5_1 go through the q8_1_r8 converter; the other handled types go through
// the q8_0_r8 converter. Returns true when `type` was handled, false otherwise.
bool iqk_convert_legacy_quants_q8_r8(int type, int n, const void * vx, size_t bx, void * vy, int nrc_x) {
    switch (type) {
        case GGML_TYPE_Q4_0:
            iqk_convert_qX_q80_r8<block_q4_0, DeqQ40>(n, vx, bx, vy, nrc_x);
            return true;
        case GGML_TYPE_Q4_1:
            iqk_convert_qX_1_q8_1_r8<block_q4_1, DeqQ41>(n, vx, bx, vy, nrc_x);
            return true;
        case GGML_TYPE_Q5_0:
            iqk_convert_qX_q80_r8<block_q5_0, DeqQ50>(n, vx, bx, vy, nrc_x);
            return true;
        case GGML_TYPE_Q5_1:
            iqk_convert_qX_1_q8_1_r8<block_q5_1, DeqQ51>(n, vx, bx, vy, nrc_x);
            return true;
        case GGML_TYPE_Q6_0:
            iqk_convert_qX_q80_r8<block_q6_0, DeqQ60>(n, vx, bx, vy, nrc_x);
            return true;
        case GGML_TYPE_IQ4_NL:
            iqk_convert_qX_q80_r8<block_iq4_nl, DeqIQ4NL>(n, vx, bx, vy, nrc_x);
            return true;
        case GGML_TYPE_Q8_0:
            iqk_convert_qX_q80_r8<block_q8_0, DeqQ80>(n, vx, bx, vy, nrc_x);
            return true;
        default:
            return false;
    }
}

bool iqk_set_kernels_legacy_quants(int ne00, int typeA, int typeB, std::array<mul_mat_t, IQK_MAX_NY>& kernels, mul_mat_t& func16) {

if (ne00%QK8_0 != 0) return false;

auto etypeA = ggml_type(typeA);
auto expected_typeB = etypeA == GGML_TYPE_Q4_1 || etypeA == GGML_TYPE_Q5_1 ? GGML_TYPE_Q8_1_X4 : GGML_TYPE_Q8_0_X4;
auto expected_typeB = etypeA == GGML_TYPE_Q4_1 || etypeA == GGML_TYPE_Q5_1 || etypeA == GGML_TYPE_Q8_1 ? GGML_TYPE_Q8_1_X4 : GGML_TYPE_Q8_0_X4;
if (ggml_type(typeB) != expected_typeB) return false;

func16 = nullptr;
Expand Down Expand Up @@ -2843,6 +3061,9 @@ bool iqk_set_kernels_legacy_quants(int ne00, int typeA, int typeB, std::array<mu
case GGML_TYPE_Q8_0_R8:
IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_q8_0_r8_q8_0, kernels);
break;
case GGML_TYPE_Q8_1:
IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_q8_1_r8_q8_1, kernels);
break;
case GGML_TYPE_IQ4_NL_R4:
IQK_SET_MUL_MAT_FUNCTIONS_T(mul_mat_qx_r4_q8_0, IQ4_NL_R4_Dequantizer, kernels);
break;
Expand Down
14 changes: 11 additions & 3 deletions ggml/src/iqk/iqk_mul_mat.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -271,9 +271,16 @@ struct MulMat {
}
#else
switch (type) {
case GGML_TYPE_IQ2_KT: return nrc_y >= 32 ? GGML_TYPE_Q8_0_R8 : type;
case GGML_TYPE_IQ3_KT: return nrc_y >= 32 ? GGML_TYPE_Q8_0_R8 : type;
case GGML_TYPE_IQ4_KT: return nrc_y >= 32 ? GGML_TYPE_Q8_0_R8 : type;
case GGML_TYPE_Q4_0 : return nrc_y >= 32 ? GGML_TYPE_Q8_0_R8 : type;
case GGML_TYPE_Q4_1 : return nrc_y >= 32 ? GGML_TYPE_Q8_1 : type;
case GGML_TYPE_Q5_0 : return nrc_y >= 32 ? GGML_TYPE_Q8_0_R8 : type;
case GGML_TYPE_Q5_1 : return nrc_y >= 32 ? GGML_TYPE_Q8_1 : type;
case GGML_TYPE_Q6_0 : return nrc_y >= 32 ? GGML_TYPE_Q8_0_R8 : type;
case GGML_TYPE_Q8_0 : return nrc_y >= 32 ? GGML_TYPE_Q8_0_R8 : type;
case GGML_TYPE_IQ4_NL : return nrc_y >= 32 ? GGML_TYPE_Q8_0_R8 : type;
case GGML_TYPE_IQ2_KT : return nrc_y >= 32 ? GGML_TYPE_Q8_0_R8 : type;
case GGML_TYPE_IQ3_KT : return nrc_y >= 32 ? GGML_TYPE_Q8_0_R8 : type;
case GGML_TYPE_IQ4_KT : return nrc_y >= 32 ? GGML_TYPE_Q8_0_R8 : type;
default: break;
}
#endif
Expand Down Expand Up @@ -913,6 +920,7 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& m, int /*Ny*/) {
case GGML_TYPE_Q5_0_R4:
case GGML_TYPE_Q6_0_R4:
case GGML_TYPE_Q8_0_R8:
case GGML_TYPE_Q8_1:
case GGML_TYPE_IQ4_NL_R4:
return iqk_set_kernels_legacy_quants(ne00, typeA, typeB, m.funcs, m.func16);
case GGML_TYPE_IQ1_BN:
Expand Down
2 changes: 1 addition & 1 deletion src/llama.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18722,7 +18722,7 @@ static std::pair<ggml_type, int> interleaved_properties(ggml_type type) {
{ GGML_TYPE_IQ5_KS_R4, { GGML_TYPE_IQ5_KS, 4} },
{ GGML_TYPE_IQ5_K_R4, { GGML_TYPE_IQ5_K, 4} },
{ GGML_TYPE_Q8_KV_R8, { GGML_TYPE_Q8_KV, 8} },
{ GGML_TYPE_Q8_K_R8, { GGML_TYPE_Q8_K, 8} },
{ GGML_TYPE_Q8_K_R8, { GGML_TYPE_Q8_0, 8} },
{ GGML_TYPE_BF16_R16, { GGML_TYPE_BF16, 16} },
};
if (auto it = k_map.find(type); it != k_map.end()) return it->second;
Expand Down