@@ -3348,20 +3348,328 @@ static void mul_mat_iq3_s_r4_q8_k(int n, const void * vx, size_t bx, const DataI
33483348 }
33493349}
33503350
// Re-quantize one super-block (8 sub-blocks of 32) of already-dequantized
// int8 values into the interleaved Q8_K_R8 layout for a single row.
//   d0     - target scale step (e.g. 1/127); the returned row scale is imax*d0
//   qx     - 8 x 32 signed 8-bit quants (two int8x16 per 32-value sub-block)
//   scales - 16 signed block scales, two per sub-block (low/high 16 values)
//   block  - 32-byte scratch buffer
//   q8_k   - destination; row k of the r8 layout writes every 8-th uint32,
//            i.e. q8_k points at word k and we store to q8_k[8*l] (see callers)
// Returns the float factor the stored int8 quants must be multiplied by
// (0.f when the whole super-block is zero).
inline float convert_to_q8_k_r8(float d0, const int8x16x2_t * qx, const int8_t * scales, uint32_t * block, uint32_t * q8_k) {
    auto max_i16 = vdupq_n_u16(0);
    int16x8x4_t q[8];
    for (int ib32 = 0; ib32 < 8; ++ib32) {
        auto scale_l = vdup_n_s8(scales[2*ib32+0]);
        auto scale_h = vdup_n_s8(scales[2*ib32+1]);
        // Widen to int16 while applying the per-16 block scales
        q[ib32].val[0] = vmull_s8(scale_l, vget_low_s8 (qx[ib32].val[0]));
        q[ib32].val[1] = vmull_s8(scale_l, vget_high_s8(qx[ib32].val[0]));
        q[ib32].val[2] = vmull_s8(scale_h, vget_low_s8 (qx[ib32].val[1]));
        q[ib32].val[3] = vmull_s8(scale_h, vget_high_s8(qx[ib32].val[1]));
        // Track the maximum magnitude. |INT16_MIN| wraps to 0x8000, which is
        // still the correct maximum when viewed as unsigned, so the
        // reinterpret to u16 is safe (and required by strict NEON typing).
        max_i16 = vmaxq_u16(max_i16, vmaxq_u16(vreinterpretq_u16_s16(vabsq_s16(q[ib32].val[0])),
                                               vreinterpretq_u16_s16(vabsq_s16(q[ib32].val[1]))));
        max_i16 = vmaxq_u16(max_i16, vmaxq_u16(vreinterpretq_u16_s16(vabsq_s16(q[ib32].val[2])),
                                               vreinterpretq_u16_s16(vabsq_s16(q[ib32].val[3]))));
    }
    uint16_t imax = vmaxvq_u16(max_i16);
    if (!imax) {
        // All-zero super-block: zero out this row's words and signal d = 0
        for (int ib32 = 0; ib32 < 8; ++ib32) for (int l = 0; l < 8; ++l) q8_k[64*ib32 + 8*l] = 0;
        return 0.f;
    }
    float dnew = float(imax) * d0;
    // If everything already fits into int8 with scale 1, skip the rescaling pass
    bool needs_scaling = true;
    if (dnew <= 1.f) {
        dnew = 1.f; needs_scaling = false;
    }
    auto scale = vdupq_n_f32(1/dnew);
    for (int ib32 = 0; ib32 < 8; ++ib32) {
        if (needs_scaling) {
            for (int l = 0; l < 4; ++l) {
                // int16 -> float, scale, round-to-nearest back to int, narrow to int16
                auto i1 = vcvtnq_s32_f32(vmulq_f32(scale, vcvtq_f32_s32(vmovl_s16(vget_low_s16 (q[ib32].val[l])))));
                auto i2 = vcvtnq_s32_f32(vmulq_f32(scale, vcvtq_f32_s32(vmovl_s16(vget_high_s16(q[ib32].val[l])))));
                q[ib32].val[l] = vcombine_s16(vmovn_s32(i1), vmovn_s32(i2));
            }
        }
        // Narrow to int8 into the 32-byte scratch buffer
        for (int l = 0; l < 2; ++l) {
            auto s8 = vcombine_s8(vmovn_s16(q[ib32].val[2*l+0]), vmovn_s16(q[ib32].val[2*l+1]));
            vst1q_s8((int8_t *)block + 16*l, s8);
        }
        // Scatter the 8 uint32 words with stride 8 (rows are interleaved)
        auto qb = q8_k + 64*ib32;
        for (int l = 0; l < 8; ++l) {
            qb[8*l] = block[l];
        }
    }
    return dnew;
}
3397+
// Convert nrc_x rows of IQ2_XXS data to the interleaved Q8_K_R8 format.
// Rows are processed 8 at a time; for each QK_K super-block, row k's quants
// land at every 8-th 32-bit word starting at word k of y[i].qs, and row k's
// fp16 scale goes to y[i].d[k].
void iqk_convert_iq2_xxs_q8_k_r8(int n, const void * vx, size_t bx, void * vy, int nrc_x) {
    GGML_ASSERT(n%QK_K == 0);
    GGML_ASSERT(nrc_x%8 == 0);

    int nb = n/QK_K;

    const block_iq2_xxs * x8[8];

    block_q8_k_r8 * y = (block_q8_k_r8 *)vy;

    int8_t   ls[16];
    uint32_t block[8];
    uint32_t aux32[2];
    const uint8_t * aux8 = (const uint8_t *)aux32;

    int8x16x2_t xv[8];

    for (int ix = 0; ix < nrc_x; ix += 8) {
        for (int k = 0; k < 8; ++k) x8[k] = (const block_iq2_xxs *)((const char *)vx + (ix + k)*bx);
        for (int i = 0; i < nb; ++i) {
            for (int k = 0; k < 8; ++k) {
                // 0.125f folds the IQ2_XXS 1/8 grid normalization into the row scale
                float d = 0.125f * GGML_FP16_TO_FP32(x8[k][i].d);
                for (int ib32 = 0; ib32 < 8; ++ib32) {
                    // aux32[0]: 4 grid indices; aux32[1]: signs (28 bits) + scale (top 4 bits)
                    std::memcpy(aux32, x8[k][i].qs + 4*ib32, 2*sizeof(uint32_t));
                    ls[2*ib32+0] = ls[2*ib32+1] = (2*(aux32[1] >> 28) + 1);
                    xv[ib32].val[0] = vreinterpretq_s8_u64(uint64x2_t{iq2xxs_grid[aux8[0]], iq2xxs_grid[aux8[1]]});
                    xv[ib32].val[1] = vreinterpretq_s8_u64(uint64x2_t{iq2xxs_grid[aux8[2]], iq2xxs_grid[aux8[3]]});
                    apply_signs_2((uint8x16_t *)xv[ib32].val, keven_signs, aux32[1]);
                }
                float dnew = convert_to_q8_k_r8(1.f/124, xv, ls, block, (uint32_t *)y[i].qs + k);
                y[i].d[k] = GGML_FP32_TO_FP16(d*dnew);
            }
        }
        y += nb;
    }
}
3434+
// Convert nrc_x rows of IQ2_XS data to the interleaved Q8_K_R8 format.
// Same output layout as iqk_convert_iq2_xxs_q8_k_r8; IQ2_XS stores a 9-bit
// grid index plus a 7-bit sign index per 8 values, and packed 4-bit block
// scales in x[i].scales.
void iqk_convert_iq2_xs_q8_k_r8(int n, const void * vx, size_t bx, void * vy, int nrc_x) {
    GGML_ASSERT(n%QK_K == 0);
    GGML_ASSERT(nrc_x%8 == 0);

    int nb = n/QK_K;

    const block_iq2_xs * x8[8];

    block_q8_k_r8 * y = (block_q8_k_r8 *)vy;

    uint32_t block[8];

    int8x16x2_t xv[8];

    // Lets us compute the 16 unpacked scales with NEON but read them as bytes
    union { int8x16_t vec; int8_t val[16]; } helper;

    for (int ix = 0; ix < nrc_x; ix += 8) {
        for (int k = 0; k < 8; ++k) x8[k] = (const block_iq2_xs *)((const char *)vx + (ix + k)*bx);
        for (int i = 0; i < nb; ++i) {
            for (int k = 0; k < 8; ++k) {
                // 0.125f folds the IQ2_XS 1/8 grid normalization into the row scale
                float d = 0.125f * GGML_FP16_TO_FP32(x8[k][i].d);
                // Unpack the 8 scale bytes into 16 nibbles, interleaved low/high,
                // and map each nibble s to 2*s + 1
                auto aux = vld1_u8(x8[k][i].scales);
                auto scales_l = vand_u8(aux, vdup_n_u8(0xf));
                auto scales_h = vshr_n_u8(aux, 4);
                auto aux1 = vcombine_u8(vzip1_u8(scales_l, scales_h), vzip2_u8(scales_l, scales_h));
                helper.vec = vreinterpretq_s8_u8(vorrq_u8(vshlq_n_u8(aux1, 1), vdupq_n_u8(1)));
                for (int ib32 = 0; ib32 < 8; ++ib32) {
                    // Low 9 bits: grid index; high 7 bits: index into the even-signs table
                    xv[ib32].val[0] = vreinterpretq_s8_u64(uint64x2_t{iq2xs_grid[x8[k][i].qs[4*ib32+0] & 511], iq2xs_grid[x8[k][i].qs[4*ib32+1] & 511]});
                    xv[ib32].val[1] = vreinterpretq_s8_u64(uint64x2_t{iq2xs_grid[x8[k][i].qs[4*ib32+2] & 511], iq2xs_grid[x8[k][i].qs[4*ib32+3] & 511]});
                    auto s1 = vreinterpretq_s8_u64(uint64x2_t{keven_signs[x8[k][i].qs[4*ib32+0] >> 9], keven_signs[x8[k][i].qs[4*ib32+1] >> 9]});
                    auto s2 = vreinterpretq_s8_u64(uint64x2_t{keven_signs[x8[k][i].qs[4*ib32+2] >> 9], keven_signs[x8[k][i].qs[4*ib32+3] >> 9]});
                    xv[ib32].val[0] = vmulq_s8(xv[ib32].val[0], s1);
                    xv[ib32].val[1] = vmulq_s8(xv[ib32].val[1], s2);
                }
                float dnew = convert_to_q8_k_r8(1.f/124, xv, helper.val, block, (uint32_t *)y[i].qs + k);
                y[i].d[k] = GGML_FP32_TO_FP16(d*dnew);
            }
        }
        y += nb;
    }
}
3476+
// Convert nrc_x rows of IQ2_S data to the interleaved Q8_K_R8 format.
// Scale unpacking matches IQ2_XS (4-bit nibbles -> 2*s + 1); the quant values
// themselves are produced by DequantizerIQ2S::make4, which combines the grid
// lookup (qs/qh) with the explicit per-value signs stored after the indices.
void iqk_convert_iq2_s_q8_k_r8(int n, const void * vx, size_t bx, void * vy, int nrc_x) {
    GGML_ASSERT(n%QK_K == 0);
    GGML_ASSERT(nrc_x%8 == 0);

    int nb = n/QK_K;

    const block_iq2_s * x8[8];

    block_q8_k_r8 * y = (block_q8_k_r8 *)vy;

    uint32_t block[8];

    // Lets us compute the 16 unpacked scales with NEON but read them as bytes
    union { int8x16_t vec; int8_t val[16]; } helper;
    int8x16x2_t xv[8];

    SignHelper sh;

    for (int ix = 0; ix < nrc_x; ix += 8) {
        for (int k = 0; k < 8; ++k) x8[k] = (const block_iq2_s *)((const char *)vx + (ix + k)*bx);
        for (int i = 0; i < nb; ++i) {
            for (int k = 0; k < 8; ++k) {
                // 0.125f folds the IQ2_S 1/8 grid normalization into the row scale
                float d = 0.125f * GGML_FP16_TO_FP32(x8[k][i].d);
                auto aux = vld1_u8(x8[k][i].scales);
                auto scales_l = vand_u8(aux, vdup_n_u8(0xf));
                auto scales_h = vshr_n_u8(aux, 4);
                auto aux1 = vcombine_u8(vzip1_u8(scales_l, scales_h), vzip2_u8(scales_l, scales_h));
                helper.vec = vreinterpretq_s8_u8(vorrq_u8(vshlq_n_u8(aux1, 1), vdupq_n_u8(1)));
                // Two halves of the super-block; the sign bytes live QK_K/8 past qs
                for (int j = 0; j < 2; ++j) {
                    auto qs = x8[k][i].qs + 16*j;
                    auto qh = x8[k][i].qh + 4*j;
                    auto signs16 = vld1q_u8(qs + QK_K/8);
                    sh.init();
                    DequantizerIQ2S::make4(sh, signs16, qs+0, qh+0, (uint8x16_t *)&xv[4*j+0]);
                    DequantizerIQ2S::make4(sh, signs16, qs+8, qh+2, (uint8x16_t *)&xv[4*j+2]);
                }
                float dnew = convert_to_q8_k_r8(1.f/124, xv, helper.val, block, (uint32_t *)y[i].qs + k);
                y[i].d[k] = GGML_FP32_TO_FP16(d*dnew);
            }
        }
        y += nb;
    }
}
3519+
// Convert nrc_x rows of IQ3_XXS data to the interleaved Q8_K_R8 format.
// IQ3_XXS stores 8 one-byte grid indices per 32 values, followed (QK_K/4
// bytes in) by one uint32 per sub-block holding 28 sign bits and a 4-bit scale.
void iqk_convert_iq3_xxs_q8_k_r8(int n, const void * vx, size_t bx, void * vy, int nrc_x) {
    GGML_ASSERT(n%QK_K == 0);
    GGML_ASSERT(nrc_x%8 == 0);

    int nb = n/QK_K;

    const block_iq3_xxs * x8[8];

    block_q8_k_r8 * y = (block_q8_k_r8 *)vy;

    int8_t      ls[16];
    int8x16x2_t xv[8];
    uint32_t    block[8];
    uint32_t    aux32;

    for (int ix = 0; ix < nrc_x; ix += 8) {
        for (int k = 0; k < 8; ++k) x8[k] = (const block_iq3_xxs *)((const char *)vx + (ix + k)*bx);
        for (int i = 0; i < nb; ++i) {
            for (int k = 0; k < 8; ++k) {
                // 0.25f folds the IQ3_XXS 1/4 grid normalization into the row scale
                float d = 0.25f * GGML_FP16_TO_FP32(x8[k][i].d);
                auto qs  = x8[k][i].qs;
                auto sas = qs + QK_K/4;   // signs + scales live after the grid indices
                for (int ib32 = 0; ib32 < 8; ++ib32) {
                    std::memcpy(&aux32, sas + 4*ib32, sizeof(uint32_t));
                    ls[2*ib32 + 0] = ls[2*ib32 + 1] = (2*(aux32 >> 28) + 1);
                    xv[ib32].val[0] = vreinterpretq_s8_u32(uint32x4_t{iq3xxs_grid[qs[0]], iq3xxs_grid[qs[1]], iq3xxs_grid[qs[2]], iq3xxs_grid[qs[3]]});
                    xv[ib32].val[1] = vreinterpretq_s8_u32(uint32x4_t{iq3xxs_grid[qs[4]], iq3xxs_grid[qs[5]], iq3xxs_grid[qs[6]], iq3xxs_grid[qs[7]]});
                    apply_signs_2((uint8x16_t *)xv[ib32].val, keven_signs, aux32);
                    qs += 8;
                }
                float dnew = convert_to_q8_k_r8(1.f/124, xv, ls, block, (uint32_t *)y[i].qs + k);
                y[i].d[k] = GGML_FP32_TO_FP16(d*dnew);
            }
        }
        y += nb;
    }
}
3557+
3558+ // struct DequantizerIQ3S final : public BaseDequantizer<block_iq3_s> {
3559+ // DequantizerIQ3S(const void * vx, size_t bx, int nrc) : BaseDequantizer(vx, bx, nrc) {}
3560+ //
3561+ // constexpr static int num_blocks() { return 8; }
3562+ // constexpr static bool should_scale_quants() { return false; }
3563+ //
3564+ // template <typename Q8>
3565+ // inline int32x4x2_t new_block(int i, const Q8& /*q8*/, float32x4_t * /*acc*/) {
3566+ // d = GGML_FP16_TO_FP32(x[i].d);
3567+ // uint32_t scales32[2];
3568+ // std::memcpy(scales32, x[i].scales, 4);
3569+ // scales32[1] = (((scales32[0] >> 4) & 0x0f0f0f0f) << 1) | 0x01010101;
3570+ // scales32[0] = ((scales32[0] & 0x0f0f0f0f) << 1) | 0x01010101;
3571+ // auto scales8 = vld1_u8((const uint8_t *)scales32); // 0, 2, 4, 6, 1, 3, 5, 7
3572+ // scales8 = vtbl1_u8(scales8, vreinterpret_u8_u64(vdup_n_u64(0x0703060205010400)));
3573+ // auto scales16 = vreinterpretq_s16_u16(vmovl_u8(scales8));
3574+ // int32x4x2_t scales;
3575+ // scales.val[0] = vmovl_s16(vget_low_s16(scales16));
3576+ // scales.val[1] = vmovl_s16(vget_high_s16(scales16));
3577+ // return scales;
3578+ // }
3579+ //
3580+ // static inline void make2(SignHelper& sh, const uint8x16_t& signs16, const uint16x8_t& idx_l, uint8_t qh,
3581+ // const int8x16_t& hshift, uint8x16_t * b) {
3582+ // auto vindex = vorrq_u16(idx_l, vandq_u16(vshlq_u16(vdupq_n_u16(qh), hshift), vdupq_n_u16(256)));
3583+ // const uint16_t * idx = (const uint16_t *)&vindex;
3584+ // b[0] = vreinterpretq_u8_u32(uint32x4_t{iq3s_grid[idx[0]], iq3s_grid[idx[1]], iq3s_grid[idx[2]], iq3s_grid[idx[3]]});
3585+ // b[1] = vreinterpretq_u8_u32(uint32x4_t{iq3s_grid[idx[4]], iq3s_grid[idx[5]], iq3s_grid[idx[6]], iq3s_grid[idx[7]]});
3586+ // sh.apply_signs_1(b+0, signs16);
3587+ // sh.apply_signs_1(b+1, signs16);
3588+ // }
3589+ // static inline void make4(SignHelper& sh, const uint8x16_t& signs16, const uint8_t * qs, const uint8_t * qh,
3590+ // const int8x16_t& hshift, uint8x16_t * b) {
3591+ // auto idx_l = vld1q_u8(qs);
3592+ // make2(sh, signs16, vmovl_u8(vget_low_u8 (idx_l)), qh[0], hshift, b+0);
3593+ // make2(sh, signs16, vmovl_u8(vget_high_u8(idx_l)), qh[1], hshift, b+2);
3594+ // }
3595+ //
3596+ // inline void prepare(int i, int j) {
3597+ //
3598+ // static const int16_t k_shift[8] = {8, 7, 6, 5, 4, 3, 2, 1};
3599+ // const auto hshift = vld1q_s16(k_shift);
3600+ //
3601+ // const auto * qs = x[i].qs + 32*j;
3602+ // const auto * qh = x[i].qh + 4*j;
3603+ // const auto signs16 = vld1q_u8(x[i].signs + 16*j);
3604+ //
3605+ // sh.init();
3606+ // make4(sh, signs16, qs+ 0, qh+0, hshift, bits.b1.val);
3607+ // make4(sh, signs16, qs+16, qh+2, hshift, bits.b2.val);
3608+ // }
3609+ //
3610+ // SimpleBits bits;
3611+ // SignHelper sh;
3612+ // uint32x4x2_t gas;
3613+ //
3614+ // };
3615+
// Convert nrc_x rows of IQ3_S data to the interleaved Q8_K_R8 format.
// Uses DequantizerIQ3S::make4 for the grid + sign expansion; block scales are
// 4-bit nibbles in x[i].scales mapped to 2*s + 1. Note the 1/127 step here
// (vs 1/124 for the 2-bit types) — the IQ3_S grid needs no extra headroom.
void iqk_convert_iq3_s_q8_k_r8(int n, const void * vx, size_t bx, void * vy, int nrc_x) {
    GGML_ASSERT(n%QK_K == 0);
    GGML_ASSERT(nrc_x%8 == 0);

    int nb = n/QK_K;

    const block_iq3_s * x8[8];

    block_q8_k_r8 * y = (block_q8_k_r8 *)vy;

    int8_t     ls[16];
    SignHelper sh;

    uint32_t    block[8];
    int8x16x2_t xv[8];

    // Per-lane left-shift amounts used to align the qh high bits with each index
    static const int16_t k_shift[8] = {8, 7, 6, 5, 4, 3, 2, 1};
    const auto hshift = vld1q_s16(k_shift);

    for (int ix = 0; ix < nrc_x; ix += 8) {
        for (int k = 0; k < 8; ++k) x8[k] = (const block_iq3_s *)((const char *)vx + (ix + k)*bx);
        for (int i = 0; i < nb; ++i) {
            for (int k = 0; k < 8; ++k) {
                float d = GGML_FP16_TO_FP32(x8[k][i].d);
                for (int j = 0; j < 2; ++j) {
                    const auto * qs = x8[k][i].qs + 32*j;
                    const auto * qh = x8[k][i].qh + 4*j;
                    const auto signs16 = vld1q_u8(x8[k][i].signs + 16*j);
                    sh.init();
                    DequantizerIQ3S::make4(sh, signs16, qs+ 0, qh+0, hshift, (uint8x16_t *)&xv[4*j+0]);
                    DequantizerIQ3S::make4(sh, signs16, qs+16, qh+2, hshift, (uint8x16_t *)&xv[4*j+2]);
                }
                for (int ib32 = 0; ib32 < 8; ++ib32) {
                    ls[2*ib32 + 0] = ls[2*ib32 + 1] = (2*((x8[k][i].scales[ib32/2] >> 4*(ib32%2)) & 0xf) + 1);
                }
                float dnew = convert_to_q8_k_r8(1.f/127, xv, ls, block, (uint32_t *)y[i].qs + k);
                y[i].d[k] = GGML_FP32_TO_FP16(d*dnew);
            }
        }
        y += nb;
    }
}
3658+
3659+
33513660}
33523661
33533662bool iqk_convert_iquants_q80_r8 ([[maybe_unused]] int type, int n, [[maybe_unused]] const void * vx, [[maybe_unused]] size_t bx, [[maybe_unused]] void * vy, int nrc_x) {
33543663 if (n%QK_K != 0 || nrc_x%8 != 0 ) return false ;
3355- return false ;
3356- // switch (ggml_type(type)) {
3357- // case GGML_TYPE_IQ2_XXS: iqk_convert_iq2_xxs_q8_k_r8(n, vx, bx, vy, nrc_x); break;
3358- // case GGML_TYPE_IQ2_XS : iqk_convert_iq2_xs_q8_k_r8 (n, vx, bx, vy, nrc_x); break;
3359- // case GGML_TYPE_IQ2_S : iqk_convert_iq2_s_q8_k_r8 (n, vx, bx, vy, nrc_x); break;
3360- // case GGML_TYPE_IQ3_XXS: iqk_convert_iq3_xxs_q8_k_r8(n, vx, bx, vy, nrc_x); break;
3361- // case GGML_TYPE_IQ3_S : iqk_convert_iq3_s_q8_k_r8 (n, vx, bx, vy, nrc_x); break;
3362- // default: return false;
3363- // }
3364- // return true;
3664+ switch (ggml_type (type)) {
3665+ case GGML_TYPE_IQ2_XXS: iqk_convert_iq2_xxs_q8_k_r8 (n, vx, bx, vy, nrc_x); break ;
3666+ case GGML_TYPE_IQ2_XS : iqk_convert_iq2_xs_q8_k_r8 (n, vx, bx, vy, nrc_x); break ;
3667+ case GGML_TYPE_IQ2_S : iqk_convert_iq2_s_q8_k_r8 (n, vx, bx, vy, nrc_x); break ;
3668+ case GGML_TYPE_IQ3_XXS: iqk_convert_iq3_xxs_q8_k_r8 (n, vx, bx, vy, nrc_x); break ;
3669+ case GGML_TYPE_IQ3_S : iqk_convert_iq3_s_q8_k_r8 (n, vx, bx, vy, nrc_x); break ;
3670+ default : return false ;
3671+ }
3672+ return true ;
33653673}
33663674
33673675bool iqk_set_kernels_iquants (int ne00, int typeA, int typeB, std::array<mul_mat_t , IQK_MAX_NY>& kernels, mul_mat_t & func16) {
0 commit comments