@@ -2515,6 +2515,154 @@ void mul_mat_iq1_s_q8_K(int n, const void * vx, size_t bx, const DataInfo& info,
         }
     }
 }
+
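+// Shared helper for the 1-bit repacking kernels below: applies the per-16-quant
+// scales to the sign-extended values in qx, finds the absolute maximum of the
+// 16-bit products, requantizes everything to int8 with a single new block scale,
+// and scatters the row into the interleaved q8_k_r8 layout (consecutive 4-byte
+// groups of one row are 8 uint32 apart). Returns the new block scale.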
+inline float convert_to_q8_k_r8(float d0, const int8x16x2_t * qx, const int8_t * scales, uint32_t * block, uint32_t * q8_k) {
+    auto max_i16 = vdupq_n_u16(0);
+    int16x8x4_t q[8];
+    for (int ib32 = 0; ib32 < 8; ++ib32) {
+        auto scale_l = vdup_n_s8(scales[2*ib32+0]);
+        auto scale_h = vdup_n_s8(scales[2*ib32+1]);
+        q[ib32].val[0] = vmull_s8(scale_l, vget_low_s8 (qx[ib32].val[0]));
+        q[ib32].val[1] = vmull_s8(scale_l, vget_high_s8(qx[ib32].val[0]));
+        q[ib32].val[2] = vmull_s8(scale_h, vget_low_s8 (qx[ib32].val[1]));
+        q[ib32].val[3] = vmull_s8(scale_h, vget_high_s8(qx[ib32].val[1]));
+        max_i16 = vmaxq_u16(max_i16, vmaxq_u16(vreinterpretq_u16_s16(vabsq_s16(q[ib32].val[0])), vreinterpretq_u16_s16(vabsq_s16(q[ib32].val[1]))));
+        max_i16 = vmaxq_u16(max_i16, vmaxq_u16(vreinterpretq_u16_s16(vabsq_s16(q[ib32].val[2])), vreinterpretq_u16_s16(vabsq_s16(q[ib32].val[3]))));
+    }
+    uint16_t imax = vmaxvq_u16(max_i16);
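+    // All products are zero: store zeros for this row and use a zero scale.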
+    if (!imax) {
+        for (int ib32 = 0; ib32 < 8; ++ib32) for (int l = 0; l < 8; ++l) q8_k[64*ib32 + 8*l] = 0;
+        return 0.f;
+    }
+    float dnew = float(imax) * d0;
+    // auto max_u32 = vmaxq_u32(vmovl_u16(vget_low_u16(max_i16)), vmovl_u16(vget_high_u16(max_i16)));
+    // auto max_f32 = vcvtq_f32_u32(max_u32);
+    // auto dnew = vmaxvq_f32(max_f32) * d0;
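+    // The callers pass d0 = 1/126, so dnew <= 1 means the 16-bit products
+    // already fit into int8 and can be stored without rescaling.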
+    bool needs_scaling = true;
+    if (dnew <= 1.f) {
+        dnew = 1.f; needs_scaling = false;
+    }
+    auto scale = vdupq_n_f32(1/dnew);
+    for (int ib32 = 0; ib32 < 8; ++ib32) {
+        if (needs_scaling) {
+            for (int l = 0; l < 4; ++l) {
+                auto i1 = vcvtnq_s32_f32(vmulq_f32(scale, vcvtq_f32_s32(vmovl_s16(vget_low_s16 (q[ib32].val[l])))));
+                auto i2 = vcvtnq_s32_f32(vmulq_f32(scale, vcvtq_f32_s32(vmovl_s16(vget_high_s16(q[ib32].val[l])))));
+                q[ib32].val[l] = vcombine_s16(vmovn_s32(i1), vmovn_s32(i2));
+            }
+        }
+        for (int l = 0; l < 2; ++l) {
+            auto s8 = vcombine_s8(vmovn_s16(q[ib32].val[2*l+0]), vmovn_s16(q[ib32].val[2*l+1]));
+            vst1q_s8((int8_t *)block + 16*l, s8);
+        }
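+        // Scatter the 8 4-byte groups of these 32 quants into the interleaved
+        // q8_k_r8 layout; q8_k already points at this row's slot.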
+        auto qb = q8_k + 64*ib32;
+        for (int l = 0; l < 8; ++l) {
+            qb[8*l] = block[l];
+        }
+    }
+    return dnew;
+}
+
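+// Repack 8 rows of IQ1_S into block_q8_k_r8. Grid values and the per-group
+// delta of +/- 1/8 are both scaled by 8 so they stay integral in int8; the
+// 0.125f factor on the block scale undoes this.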
+void iqk_convert_iq1_s_q8_k_r8(int n, const void * vx, size_t bx, void * vy, int nrc_x) {
+    GGML_ASSERT(n%QK_K == 0);
+    GGML_ASSERT(nrc_x%8 == 0);
+
+    int nb = n/QK_K;
+
+    const block_iq1_s * x8[8];
+
+    block_q8_k_r8 * y = (block_q8_k_r8 *)vy;
+
+    int8_t ls[16];
+
+    uint32_t block[8];
+
+    int8x16x2_t qx[8];
+
+    for (int ix = 0; ix < nrc_x; ix += 8) {
+        for (int k = 0; k < 8; ++k) x8[k] = (const block_iq1_s *)((const char *)vx + (ix + k)*bx);
+        for (int i = 0; i < nb; ++i) {
+            for (int k = 0; k < 8; ++k) {
+                float d = 0.125f * GGML_FP16_TO_FP32(x8[k][i].d);
+                auto qs = x8[k][i].qs;
+                auto qh = x8[k][i].qh;
+                int8x16x2_t value;
+                for (int ib32 = 0; ib32 < 8; ++ib32) {
+                    ls[2*ib32 + 0] = ls[2*ib32 + 1] = (2*((qh[ib32] >> 12) & 7) + 1);
+                    value.val[0] = vreinterpretq_s8_u64(uint64x2_t{iq1s_grid[qs[0] | ((qh[ib32] << 8) & 0x700)], iq1s_grid[qs[1] | ((qh[ib32] << 5) & 0x700)]});
+                    value.val[1] = vreinterpretq_s8_u64(uint64x2_t{iq1s_grid[qs[2] | ((qh[ib32] << 2) & 0x700)], iq1s_grid[qs[3] | ((qh[ib32] >> 1) & 0x700)]});
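+                    // Grid values are in {-1, 0, 1}: map them to 8*(value + 1), then
+                    // subtract 9 or 7 depending on the delta sign bit (bit 15 of qh),
+                    // giving 8*value -/+ 1, i.e. the IQ1_S delta of -/+ 1/8 at 8x scale.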
+                    value.val[0] = vshlq_n_s8(vaddq_s8(value.val[0], vdupq_n_s8(1)), 3);
+                    value.val[1] = vshlq_n_s8(vaddq_s8(value.val[1], vdupq_n_s8(1)), 3);
+                    auto delta = vdupq_n_s8(qh[ib32] & 0x8000 ? -9 : -7);
+                    qx[ib32].val[0] = vaddq_s8(value.val[0], delta);
+                    qx[ib32].val[1] = vaddq_s8(value.val[1], delta);
+                    qs += 4;
+                }
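+                // d0 = 1/126 guarantees the requantized values fit into int8.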
+                float dnew = convert_to_q8_k_r8(1.f/126, qx, ls, block, (uint32_t *)y[i].qs + k);
+                y[i].d[k] = GGML_FP32_TO_FP16(d*dnew);
+            }
+        }
+        y += nb;
+    }
+}
+
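+// Same repacking for IQ1_M: the fp16 block scale is reassembled from nibbles
+// spread across the scales field, and the per-8-value delta is selected by a
+// qh bit via the masks below instead of a single per-group sign bit.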
+void iqk_convert_iq1_m_q8_k_r8(int n, const void * vx, size_t bx, void * vy, int nrc_x) {
+    GGML_ASSERT(n%QK_K == 0);
+    GGML_ASSERT(nrc_x%8 == 0);
+
+    int nb = n/QK_K;
+
+    const block_iq1_m * x8[8];
+
+    block_q8_k_r8 * y = (block_q8_k_r8 *)vy;
+
+    int8_t ls[16];
+
+    uint32_t block[8];
+
+    int8x16x2_t qx[8];
+
+    uint32x4x2_t mask = {uint32x4_t{0x00000008, 0x00000008, 0x00000080, 0x00000080}, uint32x4_t{0x00080000, 0x00080000, 0x00800000, 0x00800000}};
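+    // Each mask lane isolates bit 3 of one of the four qh nibbles covering the
+    // 32 quants; a set bit selects the negative IQ1_M delta for that 8-value group.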
2626+
2627+ for (int ix = 0 ; ix < nrc_x; ix += 8 ) {
2628+ for (int k = 0 ; k < 8 ; ++k) x8[k] = (const block_iq1_m *)((const char *)vx + (ix + k)*bx);
2629+ for (int i = 0 ; i < nb; ++i) {
2630+ for (int k = 0 ; k < 8 ; ++k) {
2631+ const uint16_t * sc = (const uint16_t *)x8[k][i].scales ;
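+                // The fp16 block scale is stored as the top nibble of each of the
+                // four 16-bit scale words; reassemble it here.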
+                iq1m_scale_t scale;
+                scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000);
+                float d = 0.125f * GGML_FP16_TO_FP32(scale.f16);
+                auto qs = x8[k][i].qs;
+                auto qh = x8[k][i].qh;
+                int8x16x2_t value;
+                for (int ib32 = 0; ib32 < 8; ++ib32) {
+                    ls[2*ib32 + 0] = (2*((sc[ib32/2] >> (6*(ib32%2)+0)) & 0x7) + 1);
+                    ls[2*ib32 + 1] = (2*((sc[ib32/2] >> (6*(ib32%2)+3)) & 0x7) + 1);
+                    // value = _mm256_set_epi64x(iq1s_grid[qs[3] | ((qh[1] << 4) & 0x700)], iq1s_grid[qs[2] | ((qh[1] << 8) & 0x700)],
+                    //                           iq1s_grid[qs[1] | ((qh[0] << 4) & 0x700)], iq1s_grid[qs[0] | ((qh[0] << 8) & 0x700)]);
+                    value.val[0] = vreinterpretq_s8_u64(uint64x2_t{iq1s_grid[qs[0] | ((qh[0] << 8) & 0x700)], iq1s_grid[qs[1] | ((qh[0] << 4) & 0x700)]});
+                    value.val[1] = vreinterpretq_s8_u64(uint64x2_t{iq1s_grid[qs[2] | ((qh[1] << 8) & 0x700)], iq1s_grid[qs[3] | ((qh[1] << 4) & 0x700)]});
+                    value.val[0] = vshlq_n_s8(vaddq_s8(value.val[0], vdupq_n_s8(1)), 3);
+                    value.val[1] = vshlq_n_s8(vaddq_s8(value.val[1], vdupq_n_s8(1)), 3);
+
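+                    // Broadcast the two qh bytes and test them against the nibble-bit
+                    // masks: an all-ones compare result turns the 7 into a 9, so the
+                    // subtraction below yields 8*value -/+ 1, i.e. delta = -/+ 1/8 at 8x scale.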
+                    auto aux = vdupq_n_u32(qh[0] | qh[1] << 16);
+                    uint32x4x2_t delta_mask{ vceqq_u32(vandq_u32(aux, mask.val[0]), mask.val[0]), vceqq_u32(vandq_u32(aux, mask.val[1]), mask.val[1]) };
+                    int8x16x2_t delta{ vaddq_s8(vdupq_n_s8(7), vandq_s8(vdupq_n_s8(2), vreinterpretq_s8_u32(delta_mask.val[0]))),
+                                       vaddq_s8(vdupq_n_s8(7), vandq_s8(vdupq_n_s8(2), vreinterpretq_s8_u32(delta_mask.val[1]))) };
+                    qx[ib32].val[0] = vsubq_s8(value.val[0], delta.val[0]);
+                    qx[ib32].val[1] = vsubq_s8(value.val[1], delta.val[1]);
+
+                    qs += 4;
+                    qh += 2;
+                }
+                float dnew = convert_to_q8_k_r8(1.f/126, qx, ls, block, (uint32_t *)y[i].qs + k);
+                y[i].d[k] = GGML_FP32_TO_FP16(d*dnew);
+            }
+        }
+        y += nb;
+    }
+}
+
 }
 
 bool iqk_set_kernels_1bit(int ne00, int typeA, int typeB, std::array<mul_mat_t, IQK_MAX_NY>& funcs, mul_mat_t& func16) {
@@ -2573,8 +2721,14 @@ bool iqk_set_kernels_1bit(int ne00, int typeA, int typeB, std::array<mul_mat_t,
 
 }
 
-bool iqk_convert_1bit_q80_r8([[maybe_unused]] int type, [[maybe_unused]] int n, [[maybe_unused]] const void * vx, [[maybe_unused]] size_t bx, [[maybe_unused]] void * vy, [[maybe_unused]] int nrc_x) {
-    return false;
+bool iqk_convert_1bit_q80_r8(int type, int n, const void * vx, size_t bx, void * vy, int nrc_x) {
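+    // Dispatch to the NEON repacking kernels added above; anything else falls
+    // back to the generic path.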
+    if (n%QK_K != 0 || nrc_x%8 != 0) return false;
+    switch (ggml_type(type)) {
+        case GGML_TYPE_IQ1_S: iqk_convert_iq1_s_q8_k_r8(n, vx, bx, vy, nrc_x); break;
+        case GGML_TYPE_IQ1_M: iqk_convert_iq1_m_q8_k_r8(n, vx, bx, vy, nrc_x); break;
+        default: return false;
+    }
+    return true;
 }
 
 #endif