Commit 4f97409

Authored by Iwan Kawrakow (ikawrakow)
Faster ARM_NEON GEMM implementation for legacy quants (#546)
* iq2_kt and iq3_kt work with new int trellis

  Much slower than the fp16 based trellis. I guess Apple doesn't have int8_t SIMD on the M2-Max GPU.

* q4_0: 83.6 t/s -> 128.4 t/s. q4_0_r8 is at 123.5 t/s.

* q5_0: 74.2 t/s -> 128.5 t/s. q5_0_r4 is at 111.4 t/s.

* q6_0: 74.2 t/s -> 128.8 t/s. q6_0_r4 is at 107.2 t/s.

* q8_0: 84.5 t/s -> 128.7 t/s. q8_0_r8 is at 131 t/s.

* iq4_nl: 84.5 t/s -> 128.1 t/s. iq4_nl_r4 is at 120.4 t/s.

* q4_1: 74.4 t/s -> 115.4 t/s. There is no repacked variant.

* q5_1: 64.2 t/s -> 114.9 t/s. There is no repacked variant.

---------

Co-authored-by: Iwan Kawrakow <[email protected]>
1 parent a98b767 commit 4f97409

File tree

4 files changed: +260 / -34 lines

ggml/src/ggml-metal.metal
ggml/src/iqk/iqk_gemm_legacy_quants.cpp
ggml/src/iqk/iqk_mul_mat.cpp
src/llama.cpp


ggml/src/ggml-metal.metal

Lines changed: 13 additions & 16 deletions
@@ -6598,31 +6598,31 @@ void kernel_mul_mv_iq2_k_f32_impl(
 
 struct Trellis3 {
     constexpr constant static uint32_t kmask = 0x3f3f3f3f;
-    constexpr constant static uint32_t ka = 89226354;
-    constexpr constant static uint32_t kb = 64248484;
+    constexpr constant static uint32_t ka = 0xCBAC1FED;
     constexpr constant static uint32_t ka1 = ka*ka;
-    constexpr constant static uint32_t kb1 = kb*ka+kb;
     constexpr constant static uint32_t ka2 = ka1*ka;
-    constexpr constant static uint32_t kb2 = kb1*ka+kb;
     constexpr constant static uint32_t ka3 = ka2*ka;
-    constexpr constant static uint32_t kb3 = kb2*ka+kb;
     static inline char4 gen4(uint32_t val) {
-        thread uint32_t aux[4] = {(ka*val + kb) & kmask, (ka1*val + kb1) & kmask, (ka2*val + kb2) & kmask, (ka3*val + kb3) & kmask};
+        thread uint32_t aux[4] = {(ka*val) & kmask, (ka1*val) & kmask, (ka2*val) & kmask, (ka3*val) & kmask};
         thread const int8_t * a8 = (thread const int8_t *)aux;
         char4 result;
         for (int i = 0; i < 4; ++i) result[i] = -126 + a8[4*i+0] + a8[4*i+1] + a8[4*i+2] + a8[4*i+3];
         return result;
     }
     template <typename T4>
     static inline void gen8(uint32_t val, thread T4& v1, thread T4& v2) {
-        thread uint32_t aux[4] = {ka*val + kb, ka1*val + kb1, ka2*val + kb2, ka3*val + kb3};
+        thread uint32_t aux[4] = {ka*val, ka1*val, ka2*val, ka3*val};
         uint32_t aux32[2];
         thread const int8_t * a8 = (thread const int8_t *)aux32;
+        //thread const char4 * a8 = (thread const char4 *)aux32;
         for (int i = 0; i < 4; ++i) {
             aux32[0] = aux[i] & kmask;
-            aux32[1] = (ka3*aux[i] + kb3) & kmask;
+            aux32[1] = (ka3*aux[i]) & kmask;
             v1[i] = -126 + a8[0] + a8[1] + a8[2] + a8[3];
             v2[i] = -126 + a8[4] + a8[5] + a8[6] + a8[7];
+            // Much slower:
+            //v1[i] = -126 + a8[0][0] + a8[0][1] + a8[0][2] + a8[0][3];
+            //v2[i] = -126 + a8[1][0] + a8[1][1] + a8[1][2] + a8[1][3];
         }
     }
 };
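Note: the new trellis drops the additive constants entirely and keeps only powers of a single multiplier ka = 0xCBAC1FED (mod 2^32); per the commit message it is aimed at int8_t SIMD, which the M2-Max GPU apparently lacks. Below is a minimal scalar C++ sketch of what gen4 computes, using only the constants visible in the diff above (an illustration, not the Metal kernel itself; little-endian byte order is assumed, as in the kernel):

```cpp
#include <cstdint>
#include <cstring>
#include <cstdio>

// Scalar illustration of Trellis3::gen4 above: four int8 values from one 32-bit seed.
static void trellis3_gen4(uint32_t val, int8_t out[4]) {
    constexpr uint32_t kmask = 0x3f3f3f3f;                       // keep 6 bits of every byte
    constexpr uint32_t ka  = 0xCBAC1FED;
    constexpr uint32_t ka1 = ka*ka, ka2 = ka1*ka, ka3 = ka2*ka;   // successive powers, wrapping mod 2^32
    const uint32_t aux[4] = {(ka*val) & kmask, (ka1*val) & kmask, (ka2*val) & kmask, (ka3*val) & kmask};
    int8_t a8[16];
    std::memcpy(a8, aux, sizeof aux);                             // view the four words as 16 small bytes
    for (int i = 0; i < 4; ++i)                                   // each byte is 0..63, so the result is in [-126, 126]
        out[i] = int8_t(-126 + a8[4*i+0] + a8[4*i+1] + a8[4*i+2] + a8[4*i+3]);
}

int main() {
    int8_t v[4];
    trellis3_gen4(4096, v);   // the kernels above add 4096 to the stored 16-bit index before generating
    std::printf("%d %d %d %d\n", v[0], v[1], v[2], v[3]);
    return 0;
}
```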
@@ -6837,7 +6837,7 @@ void kernel_mul_mv_iq3_kt_f32_impl(
     float drow[N_DST];
     for (int row = 0; row < N_DST; ++row) {
         device const float * dptr = (device const float *)(cx + row*row_size);
-        drow[row] = dptr[0] * 31.75f * 1.01f;
+        drow[row] = dptr[0] * 1.01f;
     }
 
     device const block_iq3_kt * x = (device const block_iq3_kt *)(cx + sizeof(float));
@@ -6854,15 +6854,15 @@ void kernel_mul_mv_iq3_kt_f32_impl(
             const float ls = drow[row] * ((sc[(it/2)%4] >> 4*(it/8)) & 0xf);
             const uint8_t mask = 1 << (it/2);
 
-            Trellis::gen8(q2[2*it+0]+4096, v[0], v[1]);
+            Trellis3::gen8(q2[2*it+0]+4096, v[0], v[1]);
             for (int j = 0; j < 8; ++j) {
                 u32[j] &= 0x7fffffff;
                 u32[j] |= qh[j+0] & mask ? 0x80000000 : 0;
             }
 
             auto sum = v[0]*y4[0] + v[1]*y4[1];
 
-            Trellis::gen8(q2[2*it+1]+4096, v[0], v[1]);
+            Trellis3::gen8(q2[2*it+1]+4096, v[0], v[1]);
             for (int j = 0; j < 8; ++j) {
                 u32[j] &= 0x7fffffff;
                 u32[j] |= qh[j+8] & mask ? 0x80000000 : 0;
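Note: the loops above inject the sign directly into the float bit pattern: the trellis value is forced positive and the packed qh bit becomes the IEEE sign bit. A standalone scalar C++ sketch of that trick, for illustration only:

```cpp
#include <cstdint>
#include <cstring>
#include <cstdio>

// Scalar version of the sign handling in the loop above: clear the sign bit of the
// trellis value, then set it from the packed qh bit (mask = 1 << (it/2) in the kernel).
static inline float apply_packed_sign(float v, uint8_t qh_byte, uint8_t mask) {
    uint32_t u;
    std::memcpy(&u, &v, sizeof u);
    u &= 0x7fffffffu;                        // |v|
    if (qh_byte & mask) u |= 0x80000000u;    // copy the stored sign bit
    std::memcpy(&v, &u, sizeof u);
    return v;
}

int main() {
    // prints "-1.5 1.5": the qh bit flips the sign, the magnitude is untouched
    std::printf("%g %g\n", apply_packed_sign(1.5f, 0x04, 0x04), apply_packed_sign(1.5f, 0x00, 0x04));
    return 0;
}
```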
@@ -8593,17 +8593,14 @@ template <typename type4x4>
 void dequantize_iq3_kt(device const block_iq3_kt * x, short il, thread type4x4 & reg) {
     // il is 0...15 for QK_K = 256
     int ib32 = il/2;
-    half scale = (half)((x->scales[ib32%4] >> 4*(ib32/4)) & 0xf) * 31.75h * 1.01h;
+    half scale = (half)((x->scales[ib32%4] >> 4*(ib32/4)) & 0xf) * 1.01h;
     device const uint16_t * q2 = (device const uint16_t *)x->ql + 4*ib32 + 2*(il%2);
     device const uint8_t * qh = x->qh + 16*(il%2);
     const uint8_t mask = 1 << ib32;
 
     half4 v1, v2;
     for (int i = 0; i < 2; ++i) {
-        Trellis::gen8(q2[i]+4096, v1, v2);
-        //v1 *= scale; v2 *= scale;
-        //for (int j = 0; j < 4; ++j) reg[2*i+0][j] = qh[8*i+0+j] & mask ? -abs(v1[j]) : abs(v1[j]);
-        //for (int j = 0; j < 4; ++j) reg[2*i+1][j] = qh[8*i+4+j] & mask ? -abs(v2[j]) : abs(v2[j]);
+        Trellis3::gen8(q2[i]+4096, v1, v2);
         v1 = abs(v1)*scale; v2 = abs(v2)*scale;
         for (int j = 0; j < 4; ++j) reg[2*i+0][j] = qh[8*i+0+j] & mask ? -v1[j] : v1[j];
         for (int j = 0; j < 4; ++j) reg[2*i+1][j] = qh[8*i+4+j] & mask ? -v2[j] : v2[j];

ggml/src/iqk/iqk_gemm_legacy_quants.cpp

Lines changed: 235 additions & 14 deletions
@@ -2782,29 +2782,247 @@ void mul_mat_q8_0_r8_q8_0(int n, const void * vx, size_t bx, const DataInfo& inf
     }
 }
 
+typedef struct {
+    ggml_half d[16];
+    int8_t qs[256];
+} block_q8_1_r8;
+
+template <int nrc_y>
+void mul_mat_q8_1_r8_q8_1(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+    GGML_ASSERT(nrc_x%8 == 0);
+    Q8<nrc_y, block_q8_1_x4> q8(info);
+    int nb = n / QK8_0;
+    float32x4_t acc[2*nrc_y] = {};
+    int8x16_t qx[16];
+    float d8[8*nrc_y];
+    for (int ix = 0; ix < nrc_x; ix += 8) {
+        const block_q8_1_r8 * iq8 = (const block_q8_1_r8 *)((const char *)vx + ix*bx);
+        for (int ib4 = 0; ib4 < nb/4; ++ib4) {
+            for (int iy = 0; iy < nrc_y; ++iy) {
+                vst1q_f32(d8+8*iy+0, vcvt_f32_f16(vld1_f16((const float16_t *)q8.y[iy][ib4].d+0)));
+                vst1q_f32(d8+8*iy+4, vcvt_f32_f16(vld1_f16((const float16_t *)q8.y[iy][ib4].d+4)));
+            }
+            for (int k = 0; k < 4; ++k) {
+                auto scales16 = vld1q_f16((const float16_t *)iq8[4*ib4+k].d);
+                auto scales1 = vcvt_f32_f16(vget_low_f16 (scales16));
+                auto scales2 = vcvt_f32_f16(vget_high_f16(scales16));
+                auto m16 = vld1q_f16((const float16_t *)iq8[4*ib4+k].d+8);
+                auto m1 = vcvt_f32_f16(vget_low_f16 (m16));
+                auto m2 = vcvt_f32_f16(vget_high_f16(m16));
+                for (int j = 0; j < 16; ++j) qx[j] = vld1q_s8(iq8[4*ib4+k].qs + 16*j);
+                int32x4_t sumi1, sumi2;
+                for (int iy = 0; iy < nrc_y; ++iy) {
+                    qx_0_q8_0_dot(qx, q8.y[iy][ib4].qs+32*k, sumi1, sumi2);
+                    auto dy = vdupq_n_f32(d8[8*iy+k]);
+                    acc[2*iy+0] = vfmaq_f32(acc[2*iy+0], vmulq_f32(scales1, dy), vcvtq_f32_s32(sumi1));
+                    acc[2*iy+1] = vfmaq_f32(acc[2*iy+1], vmulq_f32(scales2, dy), vcvtq_f32_s32(sumi2));
+                    auto my = vdupq_n_f32(d8[8*iy+k+4]);
+                    acc[2*iy+0] = vfmaq_f32(acc[2*iy+0], m1, my);
+                    acc[2*iy+1] = vfmaq_f32(acc[2*iy+1], m2, my);
+                }
+            }
+        }
+        for (int ib = 4*(nb/4); ib < nb; ++ib) {
+            auto scales16 = vld1q_f16((const float16_t *)iq8[ib].d);
+            auto scales1 = vcvt_f32_f16(vget_low_f16 (scales16));
+            auto scales2 = vcvt_f32_f16(vget_high_f16(scales16));
+            auto m16 = vld1q_f16((const float16_t *)iq8[ib].d+8);
+            auto m1 = vcvt_f32_f16(vget_low_f16 (m16));
+            auto m2 = vcvt_f32_f16(vget_high_f16(m16));
+            for (int j = 0; j < 16; ++j) qx[j] = vld1q_s8(iq8[ib].qs + 16*j);
+            int32x4_t sumi1, sumi2;
+            for (int iy = 0; iy < nrc_y; ++iy) {
+                auto qy = (const block_q8_1 *)q8.y[iy];
+                qx_0_q8_0_dot(qx, qy[ib].qs, sumi1, sumi2);
+                auto dy = vdupq_n_f32(GGML_FP16_TO_FP32(qy[ib].d));
+                acc[2*iy+0] = vfmaq_f32(acc[2*iy+0], vmulq_f32(scales1, dy), vcvtq_f32_s32(sumi1));
+                acc[2*iy+1] = vfmaq_f32(acc[2*iy+1], vmulq_f32(scales2, dy), vcvtq_f32_s32(sumi2));
+                auto my = vdupq_n_f32(GGML_FP16_TO_FP32(qy[ib].s));
+                acc[2*iy+0] = vfmaq_f32(acc[2*iy+0], m1, my);
+                acc[2*iy+1] = vfmaq_f32(acc[2*iy+1], m2, my);
+            }
+        }
+        for (int iy = 0; iy < nrc_y; ++iy) {
+            info.store(ix+0, iy, acc[2*iy+0]);
+            info.store(ix+4, iy, acc[2*iy+1]);
+            acc[2*iy] = acc[2*iy+1] = vdupq_n_f32(0.f);
+        }
+    }
 }
 
-bool iqk_convert_legacy_quants_q8_r8([[maybe_unused]] int type, [[maybe_unused]] int n, [[maybe_unused]] const void * vx, [[maybe_unused]] size_t bx, [[maybe_unused]] void * vy, [[maybe_unused]] int nrc_x) {
-    return false;
-    //switch (type) {
-    //    case GGML_TYPE_Q4_0 : iqk_convert_qX_q80_r8<block_q4_0, Q4_0_Dequantizer>(n, vx, bx, vy, nrc_x); break;
-    //    case GGML_TYPE_Q4_1 : iqk_convert_qX_1_q8_1_r8<block_q4_1, Q4_1_Dequantizer>(n, vx, bx, vy, nrc_x); break;
-    //    case GGML_TYPE_Q5_0 : iqk_convert_qX_q80_r8<block_q5_0, Q5_0_Dequantizer>(n, vx, bx, vy, nrc_x); break;
-    //    case GGML_TYPE_Q5_1 : iqk_convert_qX_1_q8_1_r8<block_q5_1, Q5_1_Dequantizer<block_q5_1>>(n, vx, bx, vy, nrc_x); break;
-    //    case GGML_TYPE_Q6_0 : iqk_convert_qX_q80_r8<block_q6_0, Q6_0_Dequantizer>(n, vx, bx, vy, nrc_x); break;
-    //    case GGML_TYPE_IQ4_NL: iqk_convert_qX_q80_r8<block_iq4_nl, IQ4_NL0_Dequantizer>(n, vx, bx, vy, nrc_x); break;
-    //    case GGML_TYPE_Q8_0 : iqk_convert_q80_q80_r8(n, vx, bx, vy, nrc_x); break;
-    //    default: return false;
-    //}
-    //return true;
+struct DeqQ40 {
+    const int8x16_t m8 = vdupq_n_s8(-8);
+    const uint8x16_t ml = vdupq_n_s8(0xf);
+    inline int8x16x2_t dequant(const block_q4_0& x) const {
+        auto bits = vld1q_u8(x.qs);
+        return { vaddq_s8(vreinterpretq_s8_u8(vandq_u8(bits, ml)), m8), vaddq_s8(vreinterpretq_s8_u8(vshrq_n_u8(bits, 4)), m8) };
+    }
+};
+
+struct DeqQ41 {
+    const uint8x16_t ml = vdupq_n_s8(0xf);
+    inline int8x16x2_t dequant(const block_q4_1& x) const {
+        auto bits = vld1q_u8(x.qs);
+        return { vreinterpretq_s8_u8(vandq_u8(bits, ml)), vreinterpretq_s8_u8(vshrq_n_u8(bits, 4)) };
+    }
+};
+
+struct DeqIQ4NL {
+    const int8x16_t mt = load_values();
+    const uint8x16_t ml = vdupq_n_s8(0xf);
+    inline int8x16x2_t dequant(const block_iq4_nl& x) const {
+        auto bits = vld1q_u8(x.qs);
+        return { vqtbl1q_s8(mt, vandq_u8(bits, ml)), vqtbl1q_s8(mt, vshrq_n_u8(bits, 4)) };
+    }
+    static inline int8x16_t load_values() { return vld1q_s8(iq4k_values); }
+};
+
+struct DeqQ50 {
+
+    inline int8x16x2_t dequant(const block_q5_0& x) const {
+        int8x16x2_t r;
+        bits.prepare1(x.qs, r.val);
+        auto qh = x.qh;
+        r.val[0] = vreinterpretq_s8_u8(vorrq_u8(vreinterpretq_u8_s8(r.val[0]), vandq_u8(mh, hbits.to_negated_bytes(qh+0))));
+        r.val[1] = vreinterpretq_s8_u8(vorrq_u8(vreinterpretq_u8_s8(r.val[1]), vandq_u8(mh, hbits.to_negated_bytes(qh+2))));
+        return r;
+    }
+
+    Q4LegacyBits bits;
+    HighBit5Legacy hbits;
+    const uint8x16_t mh = vdupq_n_u8(0xf0);
+};
+
+struct DeqQ51 {
+
+    inline int8x16x2_t dequant(const block_q5_1& x) const {
+        int8x16x2_t r;
+        bits.prepare1(x.qs, r.val);
+        auto qh = x.qh;
+        r.val[0] = vreinterpretq_s8_u8(vorrq_u8(vreinterpretq_u8_s8(r.val[0]), vandq_u8(mh, hbits.to_bytes(qh+0))));
+        r.val[1] = vreinterpretq_s8_u8(vorrq_u8(vreinterpretq_u8_s8(r.val[1]), vandq_u8(mh, hbits.to_bytes(qh+2))));
+        return r;
+    }
+
+    Q4LegacyBits bits;
+    HighBit5Legacy hbits;
+    const uint8x16_t mh = vdupq_n_u8(0x10);
+};
+
+struct DeqQ60 {
+
+    inline int8x16x2_t dequant(const block_q6_0& x) const {
+        int8x16x2_t r;
+        bits.prepare1(x.qs, r.val);
+        auto qh8 = vld1_u8(x.qh);
+        auto qh = vcombine_u8(vshl_n_u8(qh8, 4), qh8);
+        r.val[0] = vaddq_s8(vorrq_u8(r.val[0], vandq_u8(qh, hmask)), m32);
+        r.val[1] = vaddq_s8(vorrq_u8(r.val[1], vandq_u8(vshrq_n_u8(qh, 2), hmask)), m32);
+        return r;
+    }
+
+    Q4LegacyBits bits;
+    const int8x16_t m32 = vdupq_n_s8(-32);
+    const uint8x16_t hmask = vdupq_n_u8(0x30);
+};
+
+struct DeqQ80 {
+    inline int8x16x2_t dequant(const block_q8_0& x) const {
+        return vld1q_s8_x2(x.qs);
+    }
+};
+
+template <typename Block, typename Dequantizer>
+void iqk_convert_qX_q80_r8(int n, const void * vx, size_t bx, void * vy, int nrc_x) {
+    GGML_ASSERT(n%QK4_0 == 0);
+    GGML_ASSERT(nrc_x%8 == 0);
+
+    const int nb = n/QK8_0;
+
+    block_q8_0_r8 * y = (block_q8_0_r8 *)vy;
+
+    const Block * x8[8];
+
+    uint32_t block[8];
+
+    Dequantizer deq;
+
+    for (int ix = 0; ix < nrc_x; ix += 8) {
+
+        for (int k = 0; k < 8; ++k) x8[k] = (const Block *)((const char *)vx + (ix + k)*bx);
+
+        for (int i = 0; i < nb; ++i) {
+            for (int k = 0; k < 8; ++k) {
+                y[i].d[k] = x8[k][i].d;
+                vst1q_s8_x2((int8_t *)block, deq.dequant(x8[k][i]));
+                auto qs = (uint32_t *)y[i].qs;
+                for (int l = 0; l < 4; ++l) {
+                    qs[8*l + k + 0] = block[l + 0];
+                    qs[8*l + k + 32] = block[l + 4];
+                }
+            }
+        }
+        y += nb;
+    }
+}
+
+template <typename Block, typename Dequantizer>
+void iqk_convert_qX_1_q8_1_r8(int n, const void * vx, size_t bx, void * vy, int nrc_x) {
+    GGML_ASSERT(n%QK4_0 == 0);
+    GGML_ASSERT(nrc_x%8 == 0);
+
+    const int nb = n/QK8_0;
+
+    block_q8_1_r8 * y = (block_q8_1_r8 *)vy;
+
+    const Block * x8[8];
+
+    uint32_t block[8];
+
+    Dequantizer deq;
+
+    for (int ix = 0; ix < nrc_x; ix += 8) {
+
+        for (int k = 0; k < 8; ++k) x8[k] = (const Block *)((const char *)vx + (ix + k)*bx);
+
+        for (int i = 0; i < nb; ++i) {
+            for (int k = 0; k < 8; ++k) {
+                y[i].d[k+0] = x8[k][i].d;
+                y[i].d[k+8] = x8[k][i].m;
+                vst1q_s8_x2((int8_t *)block, deq.dequant(x8[k][i]));
+                auto qs = (uint32_t *)y[i].qs;
+                for (int l = 0; l < 4; ++l) {
+                    qs[8*l + k + 0] = block[l + 0];
+                    qs[8*l + k + 32] = block[l + 4];
+                }
+            }
+        }
+        y += nb;
+    }
+}
+
+}
+
+bool iqk_convert_legacy_quants_q8_r8(int type, int n, const void * vx, size_t bx, void * vy, int nrc_x) {
+    switch (type) {
+        case GGML_TYPE_Q4_0 : iqk_convert_qX_q80_r8<block_q4_0, DeqQ40>(n, vx, bx, vy, nrc_x); break;
+        case GGML_TYPE_Q4_1 : iqk_convert_qX_1_q8_1_r8<block_q4_1, DeqQ41>(n, vx, bx, vy, nrc_x); break;
+        case GGML_TYPE_Q5_0 : iqk_convert_qX_q80_r8<block_q5_0, DeqQ50>(n, vx, bx, vy, nrc_x); break;
+        case GGML_TYPE_Q5_1 : iqk_convert_qX_1_q8_1_r8<block_q5_1, DeqQ51>(n, vx, bx, vy, nrc_x); break;
+        case GGML_TYPE_Q6_0 : iqk_convert_qX_q80_r8<block_q6_0, DeqQ60>(n, vx, bx, vy, nrc_x); break;
+        case GGML_TYPE_IQ4_NL: iqk_convert_qX_q80_r8<block_iq4_nl, DeqIQ4NL>(n, vx, bx, vy, nrc_x); break;
+        case GGML_TYPE_Q8_0 : iqk_convert_qX_q80_r8<block_q8_0, DeqQ80>(n, vx, bx, vy, nrc_x); break;
+        default: return false;
+    }
+    return true;
 }
 
 bool iqk_set_kernels_legacy_quants(int ne00, int typeA, int typeB, std::array<mul_mat_t, IQK_MAX_NY>& kernels, mul_mat_t& func16) {
 
     if (ne00%QK8_0 != 0) return false;
 
     auto etypeA = ggml_type(typeA);
-    auto expected_typeB = etypeA == GGML_TYPE_Q4_1 || etypeA == GGML_TYPE_Q5_1 ? GGML_TYPE_Q8_1_X4 : GGML_TYPE_Q8_0_X4;
+    auto expected_typeB = etypeA == GGML_TYPE_Q4_1 || etypeA == GGML_TYPE_Q5_1 || etypeA == GGML_TYPE_Q8_1 ? GGML_TYPE_Q8_1_X4 : GGML_TYPE_Q8_0_X4;
     if (ggml_type(typeB) != expected_typeB) return false;
 
     func16 = nullptr;
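Note: block_q8_1_r8 packs the same 32 columns of 8 consecutive rows into one block: d[0..7] hold the eight row scales, d[8..15] the eight row minima, and qs interleaves the rows in 4-byte groups exactly as iqk_convert_qX_1_q8_1_r8 above writes them. A hypothetical scalar accessor, purely as a reading aid for that layout (the helper names and the fp16 conversion callback are illustrative, not part of the patch):

```cpp
#include <cstdint>

// Scalar view of the block_q8_1_r8 layout produced by iqk_convert_qX_1_q8_1_r8 above.
// For row k (0..7) and column j (0..31), the quant byte sits at
//   half = j/16, word = (j%16)/4, byte = j%4  ->  qs[128*half + 32*word + 4*k + byte].
struct Q8_1_R8_View {
    const uint16_t * d;   // 16 fp16 values: d[0..7] = row scales, d[8..15] = row minima
    const int8_t   * qs;  // 256 bytes: 8 rows x 32 quants, interleaved in 4-byte groups
};

static inline int8_t q8_1_r8_quant(const Q8_1_R8_View& b, int k, int j) {
    const int half = j / 16, word = (j % 16) / 4, byte = j % 4;
    return b.qs[128*half + 32*word + 4*k + byte];
}

// Dequantized weight of row k, column j: scale*quant + min (fp16 conversion left abstract).
static inline float q8_1_r8_value(const Q8_1_R8_View& b, int k, int j,
                                  float (*fp16_to_float)(uint16_t)) {
    return fp16_to_float(b.d[k]) * q8_1_r8_quant(b, k, j) + fp16_to_float(b.d[k + 8]);
}
```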
@@ -2843,6 +3061,9 @@ bool iqk_set_kernels_legacy_quants(int ne00, int typeA, int typeB, std::array<mu
         case GGML_TYPE_Q8_0_R8:
             IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_q8_0_r8_q8_0, kernels);
             break;
+        case GGML_TYPE_Q8_1:
+            IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_q8_1_r8_q8_1, kernels);
+            break;
         case GGML_TYPE_IQ4_NL_R4:
             IQK_SET_MUL_MAT_FUNCTIONS_T(mul_mat_qx_r4_q8_0, IQ4_NL_R4_Dequantizer, kernels);
             break;
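Note: with the minima stored per row, the per-block arithmetic done by mul_mat_q8_1_r8_q8_1 is the standard q8_1 identity. A scalar sketch of the quantity it accumulates for one (weight row, activation column) pair, assuming the usual ggml convention that block_q8_1 stores s = d * sum(qs) (names here are illustrative):

```cpp
#include <cstdint>

// Per-block contribution in scalar form: sum_j (d_x*q_x[j] + m_x) * (d_y*q_y[j])
//                                        = d_x*d_y * sum_j q_x[j]*q_y[j] + m_x * s_y,
// where s_y = d_y * sum_j q_y[j] is precomputed in the activation block.
static inline float q8_1_block_dot(float d_x, float m_x, const int8_t * q_x,   // weight row: scale, min, 32 quants
                                   float d_y, float s_y, const int8_t * q_y) { // activations: scale, d_y*sum(q_y), 32 quants
    int sumi = 0;
    for (int j = 0; j < 32; ++j) sumi += int(q_x[j]) * int(q_y[j]);
    return d_x * d_y * float(sumi) + m_x * s_y;
}
```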

ggml/src/iqk/iqk_mul_mat.cpp

Lines changed: 11 additions & 3 deletions
@@ -271,9 +271,16 @@ struct MulMat {
         }
 #else
         switch (type) {
-            case GGML_TYPE_IQ2_KT: return nrc_y >= 32 ? GGML_TYPE_Q8_0_R8 : type;
-            case GGML_TYPE_IQ3_KT: return nrc_y >= 32 ? GGML_TYPE_Q8_0_R8 : type;
-            case GGML_TYPE_IQ4_KT: return nrc_y >= 32 ? GGML_TYPE_Q8_0_R8 : type;
+            case GGML_TYPE_Q4_0   : return nrc_y >= 32 ? GGML_TYPE_Q8_0_R8 : type;
+            case GGML_TYPE_Q4_1   : return nrc_y >= 32 ? GGML_TYPE_Q8_1    : type;
+            case GGML_TYPE_Q5_0   : return nrc_y >= 32 ? GGML_TYPE_Q8_0_R8 : type;
+            case GGML_TYPE_Q5_1   : return nrc_y >= 32 ? GGML_TYPE_Q8_1    : type;
+            case GGML_TYPE_Q6_0   : return nrc_y >= 32 ? GGML_TYPE_Q8_0_R8 : type;
+            case GGML_TYPE_Q8_0   : return nrc_y >= 32 ? GGML_TYPE_Q8_0_R8 : type;
+            case GGML_TYPE_IQ4_NL : return nrc_y >= 32 ? GGML_TYPE_Q8_0_R8 : type;
+            case GGML_TYPE_IQ2_KT : return nrc_y >= 32 ? GGML_TYPE_Q8_0_R8 : type;
+            case GGML_TYPE_IQ3_KT : return nrc_y >= 32 ? GGML_TYPE_Q8_0_R8 : type;
+            case GGML_TYPE_IQ4_KT : return nrc_y >= 32 ? GGML_TYPE_Q8_0_R8 : type;
             default: break;
         }
 #endif
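Note: the switch above is the batch-size gate for on-the-fly repacking: with fewer than 32 right-hand-side columns the native kernels are kept, otherwise the legacy quants are converted to q8_0_r8 (or to the q8_1 layout for the two types that carry a per-block min). A standalone paraphrase of that decision, with string names purely for illustration (the real code uses the ggml_type enum):

```cpp
#include <cstdio>
#include <cstring>

// Paraphrase of the dispatch added above: which repacked type a legacy quant is promoted to
// once the batch is large enough (nrc_y >= 32). Illustrative only.
static const char * repacked_type(const char * type, int nrc_y) {
    static const char * to_q8_1[]    = {"q4_1", "q5_1"};                       // types carrying a per-block min
    static const char * to_q8_0_r8[] = {"q4_0", "q5_0", "q6_0", "q8_0", "iq4_nl", "iq2_kt", "iq3_kt", "iq4_kt"};
    if (nrc_y < 32) return type;                                               // small batches keep the native kernels
    for (auto t : to_q8_1)    if (!std::strcmp(type, t)) return "q8_1";
    for (auto t : to_q8_0_r8) if (!std::strcmp(type, t)) return "q8_0_r8";
    return type;                                                               // everything else is handled elsewhere
}

int main() {
    std::printf("%s\n", repacked_type("q4_1", 64));  // -> q8_1
    std::printf("%s\n", repacked_type("q4_0", 64));  // -> q8_0_r8
    std::printf("%s\n", repacked_type("q4_0", 8));   // -> q4_0
    return 0;
}
```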
@@ -913,6 +920,7 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& m, int /*Ny*/) {
         case GGML_TYPE_Q5_0_R4:
         case GGML_TYPE_Q6_0_R4:
         case GGML_TYPE_Q8_0_R8:
+        case GGML_TYPE_Q8_1:
         case GGML_TYPE_IQ4_NL_R4:
             return iqk_set_kernels_legacy_quants(ne00, typeA, typeB, m.funcs, m.func16);
         case GGML_TYPE_IQ1_BN:

src/llama.cpp

Lines changed: 1 addition & 1 deletion
@@ -18722,7 +18722,7 @@ static std::pair<ggml_type, int> interleaved_properties(ggml_type type) {
         { GGML_TYPE_IQ5_KS_R4, { GGML_TYPE_IQ5_KS, 4} },
         { GGML_TYPE_IQ5_K_R4,  { GGML_TYPE_IQ5_K,  4} },
         { GGML_TYPE_Q8_KV_R8,  { GGML_TYPE_Q8_KV,  8} },
-        { GGML_TYPE_Q8_K_R8,   { GGML_TYPE_Q8_K,   8} },
+        { GGML_TYPE_Q8_K_R8,   { GGML_TYPE_Q8_0,   8} },
         { GGML_TYPE_BF16_R16,  { GGML_TYPE_BF16,  16} },
     };
     if (auto it = k_map.find(type); it != k_map.end()) return it->second;
