Skip to content

Commit ca20df1

Browse files
authored
Merge branch 'ikawrakow:main' into main
2 parents b407232 + 07673c6 commit ca20df1

File tree

2 files changed

+116
-36
lines changed

2 files changed

+116
-36
lines changed

ggml/src/iqk/iqk_gemm_1bit.cpp

Lines changed: 115 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99

1010
namespace {
1111

12-
#ifdef __AVX2__
1312
static const uint64_t iq1s_grid_us[2048] = {
1413
0x0000000000000000, 0x0000000000000002, 0x0000000000000101, 0x0000000000000200,
1514
0x0000000000000202, 0x0000000000010001, 0x0000000000010101, 0x0000000000020000,
@@ -524,8 +523,8 @@ static const uint64_t iq1s_grid_us[2048] = {
524523
0x0202020202000002, 0x0202020202000200, 0x0202020202000202, 0x0202020202010101,
525524
0x0202020202020000, 0x0202020202020002, 0x0202020202020200, 0x0202020202020202,
526525
};
527-
#else
528-
static const uint32_t iq1s_grid_us[2048] = {
526+
#ifdef __aarch64__
527+
static const uint32_t iq1s_grid_us_neon[2048] = {
529528
0x00000000, 0x00000002, 0x00000101, 0x00000200, 0x00000202, 0x00010001, 0x00010101, 0x00020000,
530529
0x00020002, 0x00020200, 0x00020202, 0x01000101, 0x01010001, 0x01010100, 0x01010102, 0x01020101,
531530
0x02000000, 0x02000002, 0x02000200, 0x02000202, 0x02010101, 0x02020000, 0x02020002, 0x02020200,
@@ -2336,22 +2335,22 @@ static void mul_mat_iq1_s_r4_q8_1(int n, const void * vx, size_t bx, const DataI
23362335
auto signs = vreinterpret_s16_u16(vorr_u16(vceq_u16(vand_u16(sas, ms), ms), vdup_n_u16(1)));
23372336
signs = vadd_s16(vdup_n_s16(-8), signs);
23382337
auto delta4 = vmulq_f32(vdupq_n_f32(0.125f), vcvtq_f32_s32(vmull_s16(signs, scales4)));
2339-
qx[0] = vreinterpretq_u8_u32(uint32x4_t{iq1s_grid_us[x[4*ib+k].qs[ 0] | ((x[4*ib+k].qh[0] << 8) & 0x0700)],
2340-
iq1s_grid_us[x[4*ib+k].qs[ 1] | ((x[4*ib+k].qh[1] << 8) & 0x0700)],
2341-
iq1s_grid_us[x[4*ib+k].qs[ 2] | ((x[4*ib+k].qh[2] << 8) & 0x0700)],
2342-
iq1s_grid_us[x[4*ib+k].qs[ 3] | ((x[4*ib+k].qh[3] << 8) & 0x0700)]});
2343-
qx[2] = vreinterpretq_u8_u32(uint32x4_t{iq1s_grid_us[x[4*ib+k].qs[ 4] | ((x[4*ib+k].qh[0] << 5) & 0x0700)],
2344-
iq1s_grid_us[x[4*ib+k].qs[ 5] | ((x[4*ib+k].qh[1] << 5) & 0x0700)],
2345-
iq1s_grid_us[x[4*ib+k].qs[ 6] | ((x[4*ib+k].qh[2] << 5) & 0x0700)],
2346-
iq1s_grid_us[x[4*ib+k].qs[ 7] | ((x[4*ib+k].qh[3] << 5) & 0x0700)]});
2347-
qx[4] = vreinterpretq_u8_u32(uint32x4_t{iq1s_grid_us[x[4*ib+k].qs[ 8] | ((x[4*ib+k].qh[0] << 2) & 0x0700)],
2348-
iq1s_grid_us[x[4*ib+k].qs[ 9] | ((x[4*ib+k].qh[1] << 2) & 0x0700)],
2349-
iq1s_grid_us[x[4*ib+k].qs[10] | ((x[4*ib+k].qh[2] << 2) & 0x0700)],
2350-
iq1s_grid_us[x[4*ib+k].qs[11] | ((x[4*ib+k].qh[3] << 2) & 0x0700)]});
2351-
qx[6] = vreinterpretq_u8_u32(uint32x4_t{iq1s_grid_us[x[4*ib+k].qs[12] | ((x[4*ib+k].qh[0] >> 1) & 0x0700)],
2352-
iq1s_grid_us[x[4*ib+k].qs[13] | ((x[4*ib+k].qh[1] >> 1) & 0x0700)],
2353-
iq1s_grid_us[x[4*ib+k].qs[14] | ((x[4*ib+k].qh[2] >> 1) & 0x0700)],
2354-
iq1s_grid_us[x[4*ib+k].qs[15] | ((x[4*ib+k].qh[3] >> 1) & 0x0700)]});
2338+
qx[0] = vreinterpretq_u8_u32(uint32x4_t{iq1s_grid_us_neon[x[4*ib+k].qs[ 0] | ((x[4*ib+k].qh[0] << 8) & 0x0700)],
2339+
iq1s_grid_us_neon[x[4*ib+k].qs[ 1] | ((x[4*ib+k].qh[1] << 8) & 0x0700)],
2340+
iq1s_grid_us_neon[x[4*ib+k].qs[ 2] | ((x[4*ib+k].qh[2] << 8) & 0x0700)],
2341+
iq1s_grid_us_neon[x[4*ib+k].qs[ 3] | ((x[4*ib+k].qh[3] << 8) & 0x0700)]});
2342+
qx[2] = vreinterpretq_u8_u32(uint32x4_t{iq1s_grid_us_neon[x[4*ib+k].qs[ 4] | ((x[4*ib+k].qh[0] << 5) & 0x0700)],
2343+
iq1s_grid_us_neon[x[4*ib+k].qs[ 5] | ((x[4*ib+k].qh[1] << 5) & 0x0700)],
2344+
iq1s_grid_us_neon[x[4*ib+k].qs[ 6] | ((x[4*ib+k].qh[2] << 5) & 0x0700)],
2345+
iq1s_grid_us_neon[x[4*ib+k].qs[ 7] | ((x[4*ib+k].qh[3] << 5) & 0x0700)]});
2346+
qx[4] = vreinterpretq_u8_u32(uint32x4_t{iq1s_grid_us_neon[x[4*ib+k].qs[ 8] | ((x[4*ib+k].qh[0] << 2) & 0x0700)],
2347+
iq1s_grid_us_neon[x[4*ib+k].qs[ 9] | ((x[4*ib+k].qh[1] << 2) & 0x0700)],
2348+
iq1s_grid_us_neon[x[4*ib+k].qs[10] | ((x[4*ib+k].qh[2] << 2) & 0x0700)],
2349+
iq1s_grid_us_neon[x[4*ib+k].qs[11] | ((x[4*ib+k].qh[3] << 2) & 0x0700)]});
2350+
qx[6] = vreinterpretq_u8_u32(uint32x4_t{iq1s_grid_us_neon[x[4*ib+k].qs[12] | ((x[4*ib+k].qh[0] >> 1) & 0x0700)],
2351+
iq1s_grid_us_neon[x[4*ib+k].qs[13] | ((x[4*ib+k].qh[1] >> 1) & 0x0700)],
2352+
iq1s_grid_us_neon[x[4*ib+k].qs[14] | ((x[4*ib+k].qh[2] >> 1) & 0x0700)],
2353+
iq1s_grid_us_neon[x[4*ib+k].qs[15] | ((x[4*ib+k].qh[3] >> 1) & 0x0700)]});
23552354
qx[1] = vandq_u8(vshrq_n_u8(qx[0], 4), mask); qx[0] = vandq_u8(qx[0], mask);
23562355
qx[3] = vandq_u8(vshrq_n_u8(qx[2], 4), mask); qx[2] = vandq_u8(qx[2], mask);
23572356
qx[5] = vandq_u8(vshrq_n_u8(qx[4], 4), mask); qx[4] = vandq_u8(qx[4], mask);
@@ -2409,22 +2408,22 @@ static void mul_mat_iq1_m_r4_q8_0(int n, const void * vx, size_t bx, const DataI
24092408
auto idxh = uint32x4_t{qh[0], qh[0] >> 4, qh[1], qh[1] >> 4};
24102409
auto signs = vreinterpretq_s8_u8(vorrq_u8(vceqq_u8(vandq_u8(idxh, ms), ms), vdupq_n_u8(1)));
24112410
signs = vaddq_s8(signs, vdupq_n_s8(-8));
2412-
qx[0] = vreinterpretq_s8_u32(uint32x4_t{iq1s_grid_us[x[4*ib+k].qs[ 0] | ((x[4*ib+k].qh[0] << 8) & 0x0700)],
2413-
iq1s_grid_us[x[4*ib+k].qs[ 1] | ((x[4*ib+k].qh[1] << 8) & 0x0700)],
2414-
iq1s_grid_us[x[4*ib+k].qs[ 2] | ((x[4*ib+k].qh[2] << 8) & 0x0700)],
2415-
iq1s_grid_us[x[4*ib+k].qs[ 3] | ((x[4*ib+k].qh[3] << 8) & 0x0700)]});
2416-
qx[2] = vreinterpretq_s8_u32(uint32x4_t{iq1s_grid_us[x[4*ib+k].qs[ 4] | ((x[4*ib+k].qh[0] << 4) & 0x0700)],
2417-
iq1s_grid_us[x[4*ib+k].qs[ 5] | ((x[4*ib+k].qh[1] << 4) & 0x0700)],
2418-
iq1s_grid_us[x[4*ib+k].qs[ 6] | ((x[4*ib+k].qh[2] << 4) & 0x0700)],
2419-
iq1s_grid_us[x[4*ib+k].qs[ 7] | ((x[4*ib+k].qh[3] << 4) & 0x0700)]});
2420-
qx[4] = vreinterpretq_s8_u32(uint32x4_t{iq1s_grid_us[x[4*ib+k].qs[ 8] | ((x[4*ib+k].qh[4] << 8) & 0x0700)],
2421-
iq1s_grid_us[x[4*ib+k].qs[ 9] | ((x[4*ib+k].qh[5] << 8) & 0x0700)],
2422-
iq1s_grid_us[x[4*ib+k].qs[10] | ((x[4*ib+k].qh[6] << 8) & 0x0700)],
2423-
iq1s_grid_us[x[4*ib+k].qs[11] | ((x[4*ib+k].qh[7] << 8) & 0x0700)]});
2424-
qx[6] = vreinterpretq_s8_u32(uint32x4_t{iq1s_grid_us[x[4*ib+k].qs[12] | ((x[4*ib+k].qh[4] << 4) & 0x0700)],
2425-
iq1s_grid_us[x[4*ib+k].qs[13] | ((x[4*ib+k].qh[5] << 4) & 0x0700)],
2426-
iq1s_grid_us[x[4*ib+k].qs[14] | ((x[4*ib+k].qh[6] << 4) & 0x0700)],
2427-
iq1s_grid_us[x[4*ib+k].qs[15] | ((x[4*ib+k].qh[7] << 4) & 0x0700)]});
2411+
qx[0] = vreinterpretq_s8_u32(uint32x4_t{iq1s_grid_us_neon[x[4*ib+k].qs[ 0] | ((x[4*ib+k].qh[0] << 8) & 0x0700)],
2412+
iq1s_grid_us_neon[x[4*ib+k].qs[ 1] | ((x[4*ib+k].qh[1] << 8) & 0x0700)],
2413+
iq1s_grid_us_neon[x[4*ib+k].qs[ 2] | ((x[4*ib+k].qh[2] << 8) & 0x0700)],
2414+
iq1s_grid_us_neon[x[4*ib+k].qs[ 3] | ((x[4*ib+k].qh[3] << 8) & 0x0700)]});
2415+
qx[2] = vreinterpretq_s8_u32(uint32x4_t{iq1s_grid_us_neon[x[4*ib+k].qs[ 4] | ((x[4*ib+k].qh[0] << 4) & 0x0700)],
2416+
iq1s_grid_us_neon[x[4*ib+k].qs[ 5] | ((x[4*ib+k].qh[1] << 4) & 0x0700)],
2417+
iq1s_grid_us_neon[x[4*ib+k].qs[ 6] | ((x[4*ib+k].qh[2] << 4) & 0x0700)],
2418+
iq1s_grid_us_neon[x[4*ib+k].qs[ 7] | ((x[4*ib+k].qh[3] << 4) & 0x0700)]});
2419+
qx[4] = vreinterpretq_s8_u32(uint32x4_t{iq1s_grid_us_neon[x[4*ib+k].qs[ 8] | ((x[4*ib+k].qh[4] << 8) & 0x0700)],
2420+
iq1s_grid_us_neon[x[4*ib+k].qs[ 9] | ((x[4*ib+k].qh[5] << 8) & 0x0700)],
2421+
iq1s_grid_us_neon[x[4*ib+k].qs[10] | ((x[4*ib+k].qh[6] << 8) & 0x0700)],
2422+
iq1s_grid_us_neon[x[4*ib+k].qs[11] | ((x[4*ib+k].qh[7] << 8) & 0x0700)]});
2423+
qx[6] = vreinterpretq_s8_u32(uint32x4_t{iq1s_grid_us_neon[x[4*ib+k].qs[12] | ((x[4*ib+k].qh[4] << 4) & 0x0700)],
2424+
iq1s_grid_us_neon[x[4*ib+k].qs[13] | ((x[4*ib+k].qh[5] << 4) & 0x0700)],
2425+
iq1s_grid_us_neon[x[4*ib+k].qs[14] | ((x[4*ib+k].qh[6] << 4) & 0x0700)],
2426+
iq1s_grid_us_neon[x[4*ib+k].qs[15] | ((x[4*ib+k].qh[7] << 4) & 0x0700)]});
24282427
auto shuffle = shuffle0;
24292428
for (int j = 0; j < 4; ++j) {
24302429
auto s = vqtbl1q_s8(signs, shuffle);
@@ -2583,6 +2582,81 @@ void mul_mat_iq1_s_q8_K(int n, const void * vx, size_t bx, const DataInfo& info,
25832582
}
25842583
}
25852584

2585+
// ARM NEON kernel: multiplies nrc_x rows of IQ1_M-quantized data (block_iq1_m) with
// nrc_y Q8_K-quantized operands (block_q8_K) and stores one float per (row, operand) pair.
//
//   n      - number of elements per row; must be a multiple of QK_K (asserted below).
//   vx     - base pointer of the block_iq1_m rows; row ix starts at byte offset ix*bx.
//   bx     - row stride in bytes.
//   info   - handed to Q8<> to access the Q8_K operands (load_quants / scale) and used
//            to write the results via info.store(ix, iy, value).
//   nrc_x  - number of rows of vx to process.
//   nrc_y  - (template) number of Q8_K operands handled per call.
template <int nrc_y>
2586+
void mul_mat_iq1_m_q8_K(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
2587+
GGML_ASSERT(n%QK_K == 0);
2588+
Q8<nrc_y, block_q8_K> q8(info);
2589+
int8x16x2_t qx[8];
2590+
int32x4x4_t scales;
2591+
float32x4_t acc[nrc_y] = {};
2592+
// Byte-table indices for vqtbl1q_u8: each of the four 16-bit scale words (bytes 0-1, 2-3,
// 4-5, 6-7) is replicated into four consecutive 16-bit lanes.
uint8x16x2_t scale_shuffle = {vreinterpretq_u8_u64(uint64x2_t{0x0100010001000100, 0x0302030203020302}),
2593+
vreinterpretq_u8_u64(uint64x2_t{0x0504050405040504, 0x0706070607060706})};
2594+
// One selector bit per 64-bit lane: bits 3, 7, 11 and 15 of a 16-bit pair of qh bytes.
// Each 64-bit lane corresponds to one 8-byte grid entry loaded below.
uint64x2x2_t delta_mask = {uint64x2_t{0x0008, 0x0080}, uint64x2_t{0x0800, 0x8000}};
2595+
iq1m_scale_t block_scale;
2596+
for (int ix = 0; ix < nrc_x; ++ix) {
2597+
auto iq1m = (const block_iq1_m *)((const char *)vx + ix*bx);
2598+
for (int ibl = 0; ibl < n/QK_K; ++ibl) {
2599+
const uint16_t * sc = (const uint16_t *)iq1m[ibl].scales; // 4 x uint16_t, each containing 4 scales
2600+
// The top 4 bits of the four scale words are concatenated into a 16-bit pattern that is
// then read back as fp16 through block_scale.f16 to obtain the per-block scale d.
block_scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000);
2601+
float d = GGML_FP16_TO_FP32(block_scale.f16);
2602+
auto qs = iq1m[ibl].qs;
2603+
auto qh = iq1m[ibl].qh;
2604+
auto aux8 = vld1_u8(iq1m[ibl].scales);
2605+
auto aux16 = vcombine_u8(aux8, aux8);
2606+
uint16x8x2_t sc16 = { vreinterpretq_u16_u8(vqtbl1q_u8(aux16, scale_shuffle.val[0])), vreinterpretq_u16_u8(vqtbl1q_u8(aux16, scale_shuffle.val[1])) };
2607+
// Per 16-bit lane: keep one 3-bit field (masks 0x0007, 0x0038, 0x01c0, 0x0e00) and multiply
// by 0x0200, 0x0040, 0x0008, 0x0001 respectively so every field lands in bits 9..11.
// NOTE(review): vandq_u16/vmulq_u16 are given vdupq_n_u64 results here, and vget_low_s16 /
// vget_high_s16 below are given unsigned vectors, without vreinterpretq_* casts; this
// presumably relies on lax vector conversions being enabled - confirm build flags.
sc16.val[0] = vmulq_u16(vandq_u16(sc16.val[0], vdupq_n_u64(0x0e0001c000380007)), vdupq_n_u64(0x0001000800400200));
2608+
sc16.val[1] = vmulq_u16(vandq_u16(sc16.val[1], vdupq_n_u64(0x0e0001c000380007)), vdupq_n_u64(0x0001000800400200));
2609+
// Shifting right by 8 leaves the 3-bit field s in bits 1..3 (i.e. 2*s); adding 1 gives 2*s + 1.
sc16.val[0] = vaddq_u16(vshrq_n_u16(sc16.val[0], 8), vdupq_n_u16(1));
2610+
sc16.val[1] = vaddq_u16(vshrq_n_u16(sc16.val[1], 8), vdupq_n_u16(1));
2611+
// Widen to 32 bits: scales.val[j] holds the four sub-block scales derived from sc[j].
scales.val[0] = vmovl_s16(vget_low_s16 (sc16.val[0]));
2612+
scales.val[1] = vmovl_s16(vget_high_s16(sc16.val[0]));
2613+
scales.val[2] = vmovl_s16(vget_low_s16 (sc16.val[1]));
2614+
scales.val[3] = vmovl_s16(vget_high_s16(sc16.val[1]));
2615+
// Decode the quants 64 values at a time: 8 qs bytes and 4 qh bytes are consumed per iteration.
for (int ib64 = 0; ib64 < QK_K/64; ++ib64) {
2616+
// Each grid index is the 8-bit qs value OR-ed with 3 bits taken from the low (<< 8) or
// high (<< 4) nibble of a qh byte; every lookup yields one 64-bit entry of iq1s_grid_us.
qx[2*ib64+0] = {vreinterpretq_s8_u64(uint64x2_t{iq1s_grid_us[qs[0] | (((uint16_t)qh[0] << 8) & 0x700)],
2617+
iq1s_grid_us[qs[1] | (((uint16_t)qh[0] << 4) & 0x700)]}),
2618+
vreinterpretq_s8_u64(uint64x2_t{iq1s_grid_us[qs[2] | (((uint16_t)qh[1] << 8) & 0x700)],
2619+
iq1s_grid_us[qs[3] | (((uint16_t)qh[1] << 4) & 0x700)]})};
2620+
qx[2*ib64+1] = {vreinterpretq_s8_u64(uint64x2_t{iq1s_grid_us[qs[4] | (((uint16_t)qh[2] << 8) & 0x700)],
2621+
iq1s_grid_us[qs[5] | (((uint16_t)qh[2] << 4) & 0x700)]}),
2622+
vreinterpretq_s8_u64(uint64x2_t{iq1s_grid_us[qs[6] | (((uint16_t)qh[3] << 8) & 0x700)],
2623+
iq1s_grid_us[qs[7] | (((uint16_t)qh[3] << 4) & 0x700)]})};
2624+
// The fourth bit of each qh nibble selects the offset: the 64-bit compare yields all-ones
// (-1 per byte) when the bit is set and 0 otherwise; OR-ing with 1 gives -1 or +1, so
// delta = 8 - (-1) = 9 when the bit is set and 8 - 1 = 7 when it is clear.
// NOTE(review): qh is read through a uint16_t pointer here - assumes this access is
// acceptable for the block layout/alignment; confirm.
auto qh16 = (const uint16_t *)qh;
2625+
auto h1 = vdupq_n_u64(qh16[0]);
2626+
auto h2 = vdupq_n_u64(qh16[1]);
2627+
auto delta1 = vsubq_s8(vdupq_n_s8(8), vorrq_s8(vreinterpretq_s8_u64(vceqq_u64(vandq_u64(h1, delta_mask.val[0]), delta_mask.val[0])), vdupq_n_s8(1)));
2628+
auto delta2 = vsubq_s8(vdupq_n_s8(8), vorrq_s8(vreinterpretq_s8_u64(vceqq_u64(vandq_u64(h1, delta_mask.val[1]), delta_mask.val[1])), vdupq_n_s8(1)));
2629+
// Final quant value per byte: 8*grid - delta (grid values are shifted left by 3 first).
qx[2*ib64+0].val[0] = vsubq_s8(vshlq_n_s8(qx[2*ib64+0].val[0], 3), delta1);
2630+
qx[2*ib64+0].val[1] = vsubq_s8(vshlq_n_s8(qx[2*ib64+0].val[1], 3), delta2);
2631+
delta1 = vsubq_s8(vdupq_n_s8(8), vorrq_s8(vreinterpretq_s8_u64(vceqq_u64(vandq_u64(h2, delta_mask.val[0]), delta_mask.val[0])), vdupq_n_s8(1)));
2632+
delta2 = vsubq_s8(vdupq_n_s8(8), vorrq_s8(vreinterpretq_s8_u64(vceqq_u64(vandq_u64(h2, delta_mask.val[1]), delta_mask.val[1])), vdupq_n_s8(1)));
2633+
qx[2*ib64+1].val[0] = vsubq_s8(vshlq_n_s8(qx[2*ib64+1].val[0], 3), delta1);
2634+
qx[2*ib64+1].val[1] = vsubq_s8(vshlq_n_s8(qx[2*ib64+1].val[1], 3), delta2);
2635+
qs += 8;
2636+
qh += 4;
2637+
}
2638+
// The decoded block is reused against every Q8_K operand.
for (int iy = 0; iy < nrc_y; ++iy) {
2639+
auto sumi = vdupq_n_s32(0);
2640+
for (int j = 0; j < 4; ++j) {
2641+
auto y1 = q8.load_quants(iy, ibl, 2*j+0);
2642+
auto dot1 = ggml_vdotq_s32(vdupq_n_s32(0), qx[2*j+0].val[0], y1.val[0]);
2643+
auto dot2 = ggml_vdotq_s32(vdupq_n_s32(0), qx[2*j+0].val[1], y1.val[1]);
2644+
auto y2 = q8.load_quants(iy, ibl, 2*j+1);
2645+
auto dot3 = ggml_vdotq_s32(vdupq_n_s32(0), qx[2*j+1].val[0], y2.val[0]);
2646+
auto dot4 = ggml_vdotq_s32(vdupq_n_s32(0), qx[2*j+1].val[1], y2.val[1]);
2647+
// Two rounds of pairwise adds collapse each of dot1..dot4 into a single lane, giving one
// sum per 16-byte vector so that lane i can be multiplied by the i-th scale of scales.val[j].
auto dot = vpaddq_s32(vpaddq_s32(dot1, dot2), vpaddq_s32(dot3, dot4));
2648+
sumi = vmlaq_s32(sumi, dot, scales.val[j]);
2649+
}
2650+
acc[iy] = vfmaq_f32(acc[iy], vdupq_n_f32(d*q8.scale(iy, ibl)), vcvtq_f32_s32(sumi));
2651+
}
2652+
}
2653+
// Horizontal sum of the accumulator, scaled by 0.125 (presumably undoing the factor 8
// introduced by the << 3 above), then reset the accumulator for the next row.
for (int iy = 0; iy < nrc_y; ++iy) {
2654+
info.store(ix, iy, 0.125f*vaddvq_f32(acc[iy]));
2655+
acc[iy] = vdupq_n_f32(0.f);
2656+
}
2657+
}
2658+
}
2659+
25862660
inline float convert_to_q8_k_r8(float d0, const int8x16x2_t * qx, const int8_t * scales, uint32_t * block, uint32_t * q8_k) {
25872661
auto max_i16 = vdupq_n_u16(0);
25882662
int16x8x4_t q[8];
@@ -2774,6 +2848,12 @@ bool iqk_set_kernels_1bit(int ne00, int typeA, int typeB, std::array<mul_mat_t,
27742848
func16 = mul_mat_iq1_s_r4_q8_1<16>;
27752849
expected_Btype = GGML_TYPE_Q8_K128;
27762850
break;
2851+
case GGML_TYPE_IQ1_M:
2852+
if (ne00%QK_K != 0) return false;
2853+
IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_iq1_m_q8_K, funcs);
2854+
func16 = mul_mat_iq1_m_q8_K<16>;
2855+
expected_Btype = GGML_TYPE_Q8_K;
2856+
break;
27772857
case GGML_TYPE_IQ1_M_R4:
27782858
if (ne00%128 != 0) return false;
27792859
IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_iq1_m_r4_q8_0, funcs);

ggml/src/iqk/iqk_mul_mat.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -279,7 +279,7 @@ struct MulMat {
279279
case GGML_TYPE_Q5_K : return nrc_y >= 32 ? GGML_TYPE_Q8_1 : type;
280280
case GGML_TYPE_Q6_K : return nrc_y >= 64 ? GGML_TYPE_Q8_0_R8 : type;
281281
case GGML_TYPE_IQ1_S : return nrc_y >= 32 ? GGML_TYPE_Q8_K_R8 : type;
282-
case GGML_TYPE_IQ1_M : return nrc_y >= 8 ? GGML_TYPE_Q8_K_R8 : type;
282+
case GGML_TYPE_IQ1_M : return nrc_y >= 32 ? GGML_TYPE_Q8_K_R8 : type;
283283
case GGML_TYPE_IQ2_XXS: return nrc_y >= 32 ? GGML_TYPE_Q8_K_R8 : type;
284284
case GGML_TYPE_IQ2_XS : return nrc_y >= 32 ? GGML_TYPE_Q8_K_R8 : type;
285285
case GGML_TYPE_IQ2_S : return nrc_y >= 32 ? GGML_TYPE_Q8_K_R8 : type;

0 commit comments

Comments
 (0)