|
9 | 9 |
|
10 | 10 | namespace { |
11 | 11 |
|
12 | | -#ifdef __AVX2__ |
13 | 12 | static const uint64_t iq1s_grid_us[2048] = { |
14 | 13 | 0x0000000000000000, 0x0000000000000002, 0x0000000000000101, 0x0000000000000200, |
15 | 14 | 0x0000000000000202, 0x0000000000010001, 0x0000000000010101, 0x0000000000020000, |
@@ -524,8 +523,8 @@ static const uint64_t iq1s_grid_us[2048] = { |
524 | 523 | 0x0202020202000002, 0x0202020202000200, 0x0202020202000202, 0x0202020202010101, |
525 | 524 | 0x0202020202020000, 0x0202020202020002, 0x0202020202020200, 0x0202020202020202, |
526 | 525 | }; |
527 | | -#else |
528 | | -static const uint32_t iq1s_grid_us[2048] = { |
| 526 | +#ifdef __aarch64__ |
| 527 | +static const uint32_t iq1s_grid_us_neon[2048] = { |
529 | 528 | 0x00000000, 0x00000002, 0x00000101, 0x00000200, 0x00000202, 0x00010001, 0x00010101, 0x00020000, |
530 | 529 | 0x00020002, 0x00020200, 0x00020202, 0x01000101, 0x01010001, 0x01010100, 0x01010102, 0x01020101, |
531 | 530 | 0x02000000, 0x02000002, 0x02000200, 0x02000202, 0x02010101, 0x02020000, 0x02020002, 0x02020200, |
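Two encodings of the same 2048-entry IQ1_S grid: the uint64_t table stores one value per byte (eight bytes per grid point), while the aarch64-only uint32_t table packs the same eight values two per byte, so four grid points fit in one 16-byte NEON register. A scalar sketch of the packed layout, assuming the low nibbles hold positions 0-3 and the high nibbles positions 4-7 as implied by the vshrq_n_u8/vandq_u8 split further down (hypothetical helper, for illustration only):

// Expand one nibble-packed 32-bit entry into the eight values that the
// 64-bit table keeps one-per-byte.
static inline void iq1s_unpack_grid_entry(uint32_t packed, uint8_t values[8]) {
    for (int i = 0; i < 4; ++i) {
        const uint8_t b = (packed >> (8*i)) & 0xff;
        values[i + 0] = b & 0x0f;  // positions 0..3: low nibbles
        values[i + 4] = b >> 4;    // positions 4..7: high nibbles
    }
}
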
@@ -2336,22 +2335,22 @@ static void mul_mat_iq1_s_r4_q8_1(int n, const void * vx, size_t bx, const DataI |
2336 | 2335 | auto signs = vreinterpret_s16_u16(vorr_u16(vceq_u16(vand_u16(sas, ms), ms), vdup_n_u16(1))); |
2337 | 2336 | signs = vadd_s16(vdup_n_s16(-8), signs); |
2338 | 2337 | auto delta4 = vmulq_f32(vdupq_n_f32(0.125f), vcvtq_f32_s32(vmull_s16(signs, scales4))); |
2339 | | - qx[0] = vreinterpretq_u8_u32(uint32x4_t{iq1s_grid_us[x[4*ib+k].qs[ 0] | ((x[4*ib+k].qh[0] << 8) & 0x0700)], |
2340 | | - iq1s_grid_us[x[4*ib+k].qs[ 1] | ((x[4*ib+k].qh[1] << 8) & 0x0700)], |
2341 | | - iq1s_grid_us[x[4*ib+k].qs[ 2] | ((x[4*ib+k].qh[2] << 8) & 0x0700)], |
2342 | | - iq1s_grid_us[x[4*ib+k].qs[ 3] | ((x[4*ib+k].qh[3] << 8) & 0x0700)]}); |
2343 | | - qx[2] = vreinterpretq_u8_u32(uint32x4_t{iq1s_grid_us[x[4*ib+k].qs[ 4] | ((x[4*ib+k].qh[0] << 5) & 0x0700)], |
2344 | | - iq1s_grid_us[x[4*ib+k].qs[ 5] | ((x[4*ib+k].qh[1] << 5) & 0x0700)], |
2345 | | - iq1s_grid_us[x[4*ib+k].qs[ 6] | ((x[4*ib+k].qh[2] << 5) & 0x0700)], |
2346 | | - iq1s_grid_us[x[4*ib+k].qs[ 7] | ((x[4*ib+k].qh[3] << 5) & 0x0700)]}); |
2347 | | - qx[4] = vreinterpretq_u8_u32(uint32x4_t{iq1s_grid_us[x[4*ib+k].qs[ 8] | ((x[4*ib+k].qh[0] << 2) & 0x0700)], |
2348 | | - iq1s_grid_us[x[4*ib+k].qs[ 9] | ((x[4*ib+k].qh[1] << 2) & 0x0700)], |
2349 | | - iq1s_grid_us[x[4*ib+k].qs[10] | ((x[4*ib+k].qh[2] << 2) & 0x0700)], |
2350 | | - iq1s_grid_us[x[4*ib+k].qs[11] | ((x[4*ib+k].qh[3] << 2) & 0x0700)]}); |
2351 | | - qx[6] = vreinterpretq_u8_u32(uint32x4_t{iq1s_grid_us[x[4*ib+k].qs[12] | ((x[4*ib+k].qh[0] >> 1) & 0x0700)], |
2352 | | - iq1s_grid_us[x[4*ib+k].qs[13] | ((x[4*ib+k].qh[1] >> 1) & 0x0700)], |
2353 | | - iq1s_grid_us[x[4*ib+k].qs[14] | ((x[4*ib+k].qh[2] >> 1) & 0x0700)], |
2354 | | - iq1s_grid_us[x[4*ib+k].qs[15] | ((x[4*ib+k].qh[3] >> 1) & 0x0700)]}); |
| 2338 | + qx[0] = vreinterpretq_u8_u32(uint32x4_t{iq1s_grid_us_neon[x[4*ib+k].qs[ 0] | ((x[4*ib+k].qh[0] << 8) & 0x0700)], |
| 2339 | + iq1s_grid_us_neon[x[4*ib+k].qs[ 1] | ((x[4*ib+k].qh[1] << 8) & 0x0700)], |
| 2340 | + iq1s_grid_us_neon[x[4*ib+k].qs[ 2] | ((x[4*ib+k].qh[2] << 8) & 0x0700)], |
| 2341 | + iq1s_grid_us_neon[x[4*ib+k].qs[ 3] | ((x[4*ib+k].qh[3] << 8) & 0x0700)]}); |
| 2342 | + qx[2] = vreinterpretq_u8_u32(uint32x4_t{iq1s_grid_us_neon[x[4*ib+k].qs[ 4] | ((x[4*ib+k].qh[0] << 5) & 0x0700)], |
| 2343 | + iq1s_grid_us_neon[x[4*ib+k].qs[ 5] | ((x[4*ib+k].qh[1] << 5) & 0x0700)], |
| 2344 | + iq1s_grid_us_neon[x[4*ib+k].qs[ 6] | ((x[4*ib+k].qh[2] << 5) & 0x0700)], |
| 2345 | + iq1s_grid_us_neon[x[4*ib+k].qs[ 7] | ((x[4*ib+k].qh[3] << 5) & 0x0700)]}); |
| 2346 | + qx[4] = vreinterpretq_u8_u32(uint32x4_t{iq1s_grid_us_neon[x[4*ib+k].qs[ 8] | ((x[4*ib+k].qh[0] << 2) & 0x0700)], |
| 2347 | + iq1s_grid_us_neon[x[4*ib+k].qs[ 9] | ((x[4*ib+k].qh[1] << 2) & 0x0700)], |
| 2348 | + iq1s_grid_us_neon[x[4*ib+k].qs[10] | ((x[4*ib+k].qh[2] << 2) & 0x0700)], |
| 2349 | + iq1s_grid_us_neon[x[4*ib+k].qs[11] | ((x[4*ib+k].qh[3] << 2) & 0x0700)]}); |
| 2350 | + qx[6] = vreinterpretq_u8_u32(uint32x4_t{iq1s_grid_us_neon[x[4*ib+k].qs[12] | ((x[4*ib+k].qh[0] >> 1) & 0x0700)], |
| 2351 | + iq1s_grid_us_neon[x[4*ib+k].qs[13] | ((x[4*ib+k].qh[1] >> 1) & 0x0700)], |
| 2352 | + iq1s_grid_us_neon[x[4*ib+k].qs[14] | ((x[4*ib+k].qh[2] >> 1) & 0x0700)], |
| 2353 | + iq1s_grid_us_neon[x[4*ib+k].qs[15] | ((x[4*ib+k].qh[3] >> 1) & 0x0700)]}); |
2355 | 2354 | qx[1] = vandq_u8(vshrq_n_u8(qx[0], 4), mask); qx[0] = vandq_u8(qx[0], mask); |
2356 | 2355 | qx[3] = vandq_u8(vshrq_n_u8(qx[2], 4), mask); qx[2] = vandq_u8(qx[2], mask); |
2357 | 2356 | qx[5] = vandq_u8(vshrq_n_u8(qx[4], 4), mask); qx[4] = vandq_u8(qx[4], mask); |
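Each grid index here is 11 bits: the qs byte supplies the low eight bits and three bits taken from qh supply bits 8-10, which is what the 0x0700 masks together with the <<8/<<5/<<2/>>1 shifts select for the four 32-value groups. A scalar sketch of the lookup for group j and interleaved row i (hypothetical helper; the width of the qh field is an assumption inferred from those shifts):

// Build the 11-bit grid index for group j (0..3), row i (0..3), and gather four
// packed entries into one NEON register, as qx[0]/qx[2]/qx[4]/qx[6] do above.
static inline uint8x16_t iq1s_r4_load_group(const uint8_t * qs, const uint16_t * qh, int j) {
    uint32_t g[4];
    for (int i = 0; i < 4; ++i) {
        const uint32_t idx = qs[4*j + i] | (((qh[i] >> (3*j)) & 0x7) << 8);
        g[i] = iq1s_grid_us_neon[idx];
    }
    return vreinterpretq_u8_u32(uint32x4_t{g[0], g[1], g[2], g[3]});
}
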
@@ -2409,22 +2408,22 @@ static void mul_mat_iq1_m_r4_q8_0(int n, const void * vx, size_t bx, const DataI |
2409 | 2408 | auto idxh = uint32x4_t{qh[0], qh[0] >> 4, qh[1], qh[1] >> 4}; |
2410 | 2409 | auto signs = vreinterpretq_s8_u8(vorrq_u8(vceqq_u8(vandq_u8(idxh, ms), ms), vdupq_n_u8(1))); |
2411 | 2410 | signs = vaddq_s8(signs, vdupq_n_s8(-8)); |
2412 | | - qx[0] = vreinterpretq_s8_u32(uint32x4_t{iq1s_grid_us[x[4*ib+k].qs[ 0] | ((x[4*ib+k].qh[0] << 8) & 0x0700)], |
2413 | | - iq1s_grid_us[x[4*ib+k].qs[ 1] | ((x[4*ib+k].qh[1] << 8) & 0x0700)], |
2414 | | - iq1s_grid_us[x[4*ib+k].qs[ 2] | ((x[4*ib+k].qh[2] << 8) & 0x0700)], |
2415 | | - iq1s_grid_us[x[4*ib+k].qs[ 3] | ((x[4*ib+k].qh[3] << 8) & 0x0700)]}); |
2416 | | - qx[2] = vreinterpretq_s8_u32(uint32x4_t{iq1s_grid_us[x[4*ib+k].qs[ 4] | ((x[4*ib+k].qh[0] << 4) & 0x0700)], |
2417 | | - iq1s_grid_us[x[4*ib+k].qs[ 5] | ((x[4*ib+k].qh[1] << 4) & 0x0700)], |
2418 | | - iq1s_grid_us[x[4*ib+k].qs[ 6] | ((x[4*ib+k].qh[2] << 4) & 0x0700)], |
2419 | | - iq1s_grid_us[x[4*ib+k].qs[ 7] | ((x[4*ib+k].qh[3] << 4) & 0x0700)]}); |
2420 | | - qx[4] = vreinterpretq_s8_u32(uint32x4_t{iq1s_grid_us[x[4*ib+k].qs[ 8] | ((x[4*ib+k].qh[4] << 8) & 0x0700)], |
2421 | | - iq1s_grid_us[x[4*ib+k].qs[ 9] | ((x[4*ib+k].qh[5] << 8) & 0x0700)], |
2422 | | - iq1s_grid_us[x[4*ib+k].qs[10] | ((x[4*ib+k].qh[6] << 8) & 0x0700)], |
2423 | | - iq1s_grid_us[x[4*ib+k].qs[11] | ((x[4*ib+k].qh[7] << 8) & 0x0700)]}); |
2424 | | - qx[6] = vreinterpretq_s8_u32(uint32x4_t{iq1s_grid_us[x[4*ib+k].qs[12] | ((x[4*ib+k].qh[4] << 4) & 0x0700)], |
2425 | | - iq1s_grid_us[x[4*ib+k].qs[13] | ((x[4*ib+k].qh[5] << 4) & 0x0700)], |
2426 | | - iq1s_grid_us[x[4*ib+k].qs[14] | ((x[4*ib+k].qh[6] << 4) & 0x0700)], |
2427 | | - iq1s_grid_us[x[4*ib+k].qs[15] | ((x[4*ib+k].qh[7] << 4) & 0x0700)]}); |
| 2411 | + qx[0] = vreinterpretq_s8_u32(uint32x4_t{iq1s_grid_us_neon[x[4*ib+k].qs[ 0] | ((x[4*ib+k].qh[0] << 8) & 0x0700)], |
| 2412 | + iq1s_grid_us_neon[x[4*ib+k].qs[ 1] | ((x[4*ib+k].qh[1] << 8) & 0x0700)], |
| 2413 | + iq1s_grid_us_neon[x[4*ib+k].qs[ 2] | ((x[4*ib+k].qh[2] << 8) & 0x0700)], |
| 2414 | + iq1s_grid_us_neon[x[4*ib+k].qs[ 3] | ((x[4*ib+k].qh[3] << 8) & 0x0700)]}); |
| 2415 | + qx[2] = vreinterpretq_s8_u32(uint32x4_t{iq1s_grid_us_neon[x[4*ib+k].qs[ 4] | ((x[4*ib+k].qh[0] << 4) & 0x0700)], |
| 2416 | + iq1s_grid_us_neon[x[4*ib+k].qs[ 5] | ((x[4*ib+k].qh[1] << 4) & 0x0700)], |
| 2417 | + iq1s_grid_us_neon[x[4*ib+k].qs[ 6] | ((x[4*ib+k].qh[2] << 4) & 0x0700)], |
| 2418 | + iq1s_grid_us_neon[x[4*ib+k].qs[ 7] | ((x[4*ib+k].qh[3] << 4) & 0x0700)]}); |
| 2419 | + qx[4] = vreinterpretq_s8_u32(uint32x4_t{iq1s_grid_us_neon[x[4*ib+k].qs[ 8] | ((x[4*ib+k].qh[4] << 8) & 0x0700)], |
| 2420 | + iq1s_grid_us_neon[x[4*ib+k].qs[ 9] | ((x[4*ib+k].qh[5] << 8) & 0x0700)], |
| 2421 | + iq1s_grid_us_neon[x[4*ib+k].qs[10] | ((x[4*ib+k].qh[6] << 8) & 0x0700)], |
| 2422 | + iq1s_grid_us_neon[x[4*ib+k].qs[11] | ((x[4*ib+k].qh[7] << 8) & 0x0700)]}); |
| 2423 | + qx[6] = vreinterpretq_s8_u32(uint32x4_t{iq1s_grid_us_neon[x[4*ib+k].qs[12] | ((x[4*ib+k].qh[4] << 4) & 0x0700)], |
| 2424 | + iq1s_grid_us_neon[x[4*ib+k].qs[13] | ((x[4*ib+k].qh[5] << 4) & 0x0700)], |
| 2425 | + iq1s_grid_us_neon[x[4*ib+k].qs[14] | ((x[4*ib+k].qh[6] << 4) & 0x0700)], |
| 2426 | + iq1s_grid_us_neon[x[4*ib+k].qs[15] | ((x[4*ib+k].qh[7] << 4) & 0x0700)]}); |
2428 | 2427 | auto shuffle = shuffle0; |
2429 | 2428 | for (int j = 0; j < 4; ++j) { |
2430 | 2429 | auto s = vqtbl1q_s8(signs, shuffle); |
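In the IQ1_M layouts, bit 3 of every qh nibble is deliberately excluded from the grid index (the 0x0700 masks drop it): it selects the per-8-value delta instead, which is what the ms comparison here and the delta_mask comparisons in the kernel added below extract. In scalar form, assuming the same ±1/8 convention as the reference IQ1_M dequantization:

// Delta applied on top of each group of eight grid values: +1/8 when the
// nibble's bit 3 is clear, -1/8 when it is set.
static inline float iq1m_delta(uint8_t qh_nibble) {
    return (qh_nibble & 0x08) ? -0.125f : 0.125f;
}
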
@@ -2583,6 +2582,81 @@ void mul_mat_iq1_s_q8_K(int n, const void * vx, size_t bx, const DataInfo& info, |
2583 | 2582 | } |
2584 | 2583 | } |
2585 | 2584 |
|
| 2585 | +template <int nrc_y> |
| 2586 | +void mul_mat_iq1_m_q8_K(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { |
| 2587 | + GGML_ASSERT(n%QK_K == 0); |
| 2588 | + Q8<nrc_y, block_q8_K> q8(info); |
| 2589 | + int8x16x2_t qx[8]; |
| 2590 | + int32x4x4_t scales; |
| 2591 | + float32x4_t acc[nrc_y] = {}; |
| 2592 | + uint8x16x2_t scale_shuffle = {vreinterpretq_u8_u64(uint64x2_t{0x0100010001000100, 0x0302030203020302}), |
| 2593 | + vreinterpretq_u8_u64(uint64x2_t{0x0504050405040504, 0x0706070607060706})}; |
| 2594 | + uint64x2x2_t delta_mask = {uint64x2_t{0x0008, 0x0080}, uint64x2_t{0x0800, 0x8000}}; |
| 2595 | + iq1m_scale_t block_scale; |
| 2596 | + for (int ix = 0; ix < nrc_x; ++ix) { |
| 2597 | + auto iq1m = (const block_iq1_m *)((const char *)vx + ix*bx); |
| 2598 | + for (int ibl = 0; ibl < n/QK_K; ++ibl) { |
| 2599 | + const uint16_t * sc = (const uint16_t *)iq1m[ibl].scales; // 4 x uint16_t, each containing 4 scales |
| 2600 | + block_scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000); |
| 2601 | + float d = GGML_FP16_TO_FP32(block_scale.f16); |
| 2602 | + auto qs = iq1m[ibl].qs; |
| 2603 | + auto qh = iq1m[ibl].qh; |
| 2604 | + auto aux8 = vld1_u8(iq1m[ibl].scales); |
| 2605 | + auto aux16 = vcombine_u8(aux8, aux8); |
| 2606 | + uint16x8x2_t sc16 = { vreinterpretq_u16_u8(vqtbl1q_u8(aux16, scale_shuffle.val[0])), vreinterpretq_u16_u8(vqtbl1q_u8(aux16, scale_shuffle.val[1])) }; |
| 2607 | + sc16.val[0] = vmulq_u16(vandq_u16(sc16.val[0], vdupq_n_u64(0x0e0001c000380007)), vdupq_n_u64(0x0001000800400200)); |
| 2608 | + sc16.val[1] = vmulq_u16(vandq_u16(sc16.val[1], vdupq_n_u64(0x0e0001c000380007)), vdupq_n_u64(0x0001000800400200)); |
| 2609 | + sc16.val[0] = vaddq_u16(vshrq_n_u16(sc16.val[0], 8), vdupq_n_u16(1)); |
| 2610 | + sc16.val[1] = vaddq_u16(vshrq_n_u16(sc16.val[1], 8), vdupq_n_u16(1)); |
| 2611 | + scales.val[0] = vmovl_s16(vget_low_s16 (sc16.val[0])); |
| 2612 | + scales.val[1] = vmovl_s16(vget_high_s16(sc16.val[0])); |
| 2613 | + scales.val[2] = vmovl_s16(vget_low_s16 (sc16.val[1])); |
| 2614 | + scales.val[3] = vmovl_s16(vget_high_s16(sc16.val[1])); |
| 2615 | + for (int ib64 = 0; ib64 < QK_K/64; ++ib64) { |
| 2616 | + qx[2*ib64+0] = {vreinterpretq_s8_u64(uint64x2_t{iq1s_grid_us[qs[0] | (((uint16_t)qh[0] << 8) & 0x700)], |
| 2617 | + iq1s_grid_us[qs[1] | (((uint16_t)qh[0] << 4) & 0x700)]}), |
| 2618 | + vreinterpretq_s8_u64(uint64x2_t{iq1s_grid_us[qs[2] | (((uint16_t)qh[1] << 8) & 0x700)], |
| 2619 | + iq1s_grid_us[qs[3] | (((uint16_t)qh[1] << 4) & 0x700)]})}; |
| 2620 | + qx[2*ib64+1] = {vreinterpretq_s8_u64(uint64x2_t{iq1s_grid_us[qs[4] | (((uint16_t)qh[2] << 8) & 0x700)], |
| 2621 | + iq1s_grid_us[qs[5] | (((uint16_t)qh[2] << 4) & 0x700)]}), |
| 2622 | + vreinterpretq_s8_u64(uint64x2_t{iq1s_grid_us[qs[6] | (((uint16_t)qh[3] << 8) & 0x700)], |
| 2623 | + iq1s_grid_us[qs[7] | (((uint16_t)qh[3] << 4) & 0x700)]})}; |
| 2624 | + auto qh16 = (const uint16_t *)qh; |
| 2625 | + auto h1 = vdupq_n_u64(qh16[0]); |
| 2626 | + auto h2 = vdupq_n_u64(qh16[1]); |
| 2627 | + auto delta1 = vsubq_s8(vdupq_n_s8(8), vorrq_s8(vreinterpretq_s8_u64(vceqq_u64(vandq_u64(h1, delta_mask.val[0]), delta_mask.val[0])), vdupq_n_s8(1))); |
| 2628 | + auto delta2 = vsubq_s8(vdupq_n_s8(8), vorrq_s8(vreinterpretq_s8_u64(vceqq_u64(vandq_u64(h1, delta_mask.val[1]), delta_mask.val[1])), vdupq_n_s8(1))); |
| 2629 | + qx[2*ib64+0].val[0] = vsubq_s8(vshlq_n_s8(qx[2*ib64+0].val[0], 3), delta1); |
| 2630 | + qx[2*ib64+0].val[1] = vsubq_s8(vshlq_n_s8(qx[2*ib64+0].val[1], 3), delta2); |
| 2631 | + delta1 = vsubq_s8(vdupq_n_s8(8), vorrq_s8(vreinterpretq_s8_u64(vceqq_u64(vandq_u64(h2, delta_mask.val[0]), delta_mask.val[0])), vdupq_n_s8(1))); |
| 2632 | + delta2 = vsubq_s8(vdupq_n_s8(8), vorrq_s8(vreinterpretq_s8_u64(vceqq_u64(vandq_u64(h2, delta_mask.val[1]), delta_mask.val[1])), vdupq_n_s8(1))); |
| 2633 | + qx[2*ib64+1].val[0] = vsubq_s8(vshlq_n_s8(qx[2*ib64+1].val[0], 3), delta1); |
| 2634 | + qx[2*ib64+1].val[1] = vsubq_s8(vshlq_n_s8(qx[2*ib64+1].val[1], 3), delta2); |
| 2635 | + qs += 8; |
| 2636 | + qh += 4; |
| 2637 | + } |
| 2638 | + for (int iy = 0; iy < nrc_y; ++iy) { |
| 2639 | + auto sumi = vdupq_n_s32(0); |
| 2640 | + for (int j = 0; j < 4; ++j) { |
| 2641 | + auto y1 = q8.load_quants(iy, ibl, 2*j+0); |
| 2642 | + auto dot1 = ggml_vdotq_s32(vdupq_n_s32(0), qx[2*j+0].val[0], y1.val[0]); |
| 2643 | + auto dot2 = ggml_vdotq_s32(vdupq_n_s32(0), qx[2*j+0].val[1], y1.val[1]); |
| 2644 | + auto y2 = q8.load_quants(iy, ibl, 2*j+1); |
| 2645 | + auto dot3 = ggml_vdotq_s32(vdupq_n_s32(0), qx[2*j+1].val[0], y2.val[0]); |
| 2646 | + auto dot4 = ggml_vdotq_s32(vdupq_n_s32(0), qx[2*j+1].val[1], y2.val[1]); |
| 2647 | + auto dot = vpaddq_s32(vpaddq_s32(dot1, dot2), vpaddq_s32(dot3, dot4)); |
| 2648 | + sumi = vmlaq_s32(sumi, dot, scales.val[j]); |
| 2649 | + } |
| 2650 | + acc[iy] = vfmaq_f32(acc[iy], vdupq_n_f32(d*q8.scale(iy, ibl)), vcvtq_f32_s32(sumi)); |
| 2651 | + } |
| 2652 | + } |
| 2653 | + for (int iy = 0; iy < nrc_y; ++iy) { |
| 2654 | + info.store(ix, iy, 0.125f*vaddvq_f32(acc[iy])); |
| 2655 | + acc[iy] = vdupq_n_f32(0.f); |
| 2656 | + } |
| 2657 | + } |
| 2658 | +} |
| 2659 | + |
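The scale decode above follows the IQ1_M block layout: each of the four 16-bit scale words keeps four 3-bit block scales in bits 0-11 and contributes its top nibble to a shared fp16 super-scale, and the masked multiply/shift/add sequence is simply a vectorised 2*s+1 over all sixteen 3-bit fields at once. A scalar sketch of the same decode (helper names are illustrative):

// Reassemble the fp16 super-scale from the top nibble of each scale word,
// exactly as block_scale.u16 is built above.
static inline float iq1m_super_scale(const uint16_t * sc) {
    iq1m_scale_t s;
    s.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000);
    return GGML_FP16_TO_FP32(s.f16);
}
// 3-bit scale for the ib-th group of 16 values (ib = 0..15); the NEON path
// produces the same 2*s+1 via the mask/multiply/shift/add on sc16.
static inline int iq1m_group_scale(const uint16_t * sc, int ib) {
    return 2*((sc[ib/4] >> (3*(ib%4))) & 0x7) + 1;
}
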
2586 | 2660 | inline float convert_to_q8_k_r8(float d0, const int8x16x2_t * qx, const int8_t * scales, uint32_t * block, uint32_t * q8_k) { |
2587 | 2661 | auto max_i16 = vdupq_n_u16(0); |
2588 | 2662 | int16x8x4_t q[8]; |
@@ -2774,6 +2848,12 @@ bool iqk_set_kernels_1bit(int ne00, int typeA, int typeB, std::array<mul_mat_t, |
2774 | 2848 | func16 = mul_mat_iq1_s_r4_q8_1<16>; |
2775 | 2849 | expected_Btype = GGML_TYPE_Q8_K128; |
2776 | 2850 | break; |
| 2851 | + case GGML_TYPE_IQ1_M: |
| 2852 | + if (ne00%QK_K != 0) return false; |
| 2853 | + IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_iq1_m_q8_K, funcs); |
| 2854 | + func16 = mul_mat_iq1_m_q8_K<16>; |
| 2855 | + expected_Btype = GGML_TYPE_Q8_K; |
| 2856 | + break; |
2777 | 2857 | case GGML_TYPE_IQ1_M_R4: |
2778 | 2858 | if (ne00%128 != 0) return false; |
2779 | 2859 | IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_iq1_m_r4_q8_0, funcs); |
|