@@ -2782,29 +2782,247 @@ void mul_mat_q8_0_r8_q8_0(int n, const void * vx, size_t bx, const DataInfo& inf
     }
 }
 
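+// Interleaved 8-row Q8_1 block: d[0..7] hold the per-row scales, d[8..15] the
+// per-row offsets (m), and qs holds 8 x 32 int8 quants interleaved in 4-byte groups.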
+typedef struct {
+    ggml_half d[16];
+    int8_t    qs[256];
+} block_q8_1_r8;
+
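+// GEMM of 8 interleaved Q8_1 rows (block_q8_1_r8) against nrc_y rows of Q8_1_x4
+// activations: 4 blocks per iteration, with a per-block tail loop for the remainder.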
+template <int nrc_y>
+void mul_mat_q8_1_r8_q8_1(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+    GGML_ASSERT(nrc_x%8 == 0);
+    Q8<nrc_y, block_q8_1_x4> q8(info);
+    int nb = n / QK8_0;
+    float32x4_t acc[2*nrc_y] = {};
+    int8x16_t qx[16];
+    float d8[8*nrc_y];
+    for (int ix = 0; ix < nrc_x; ix += 8) {
+        const block_q8_1_r8 * iq8 = (const block_q8_1_r8 *)((const char *)vx + ix*bx);
+        for (int ib4 = 0; ib4 < nb/4; ++ib4) {
+            for (int iy = 0; iy < nrc_y; ++iy) {
+                vst1q_f32(d8+8*iy+0, vcvt_f32_f16(vld1_f16((const float16_t *)q8.y[iy][ib4].d+0)));
+                vst1q_f32(d8+8*iy+4, vcvt_f32_f16(vld1_f16((const float16_t *)q8.y[iy][ib4].d+4)));
+            }
+            for (int k = 0; k < 4; ++k) {
+                auto scales16 = vld1q_f16((const float16_t *)iq8[4*ib4+k].d);
+                auto scales1 = vcvt_f32_f16(vget_low_f16(scales16));
+                auto scales2 = vcvt_f32_f16(vget_high_f16(scales16));
+                auto m16 = vld1q_f16((const float16_t *)iq8[4*ib4+k].d+8);
+                auto m1 = vcvt_f32_f16(vget_low_f16(m16));
+                auto m2 = vcvt_f32_f16(vget_high_f16(m16));
+                for (int j = 0; j < 16; ++j) qx[j] = vld1q_s8(iq8[4*ib4+k].qs + 16*j);
+                int32x4_t sumi1, sumi2;
+                for (int iy = 0; iy < nrc_y; ++iy) {
+                    qx_0_q8_0_dot(qx, q8.y[iy][ib4].qs+32*k, sumi1, sumi2);
+                    auto dy = vdupq_n_f32(d8[8*iy+k]);
+                    acc[2*iy+0] = vfmaq_f32(acc[2*iy+0], vmulq_f32(scales1, dy), vcvtq_f32_s32(sumi1));
+                    acc[2*iy+1] = vfmaq_f32(acc[2*iy+1], vmulq_f32(scales2, dy), vcvtq_f32_s32(sumi2));
+                    auto my = vdupq_n_f32(d8[8*iy+k+4]);
+                    acc[2*iy+0] = vfmaq_f32(acc[2*iy+0], m1, my);
+                    acc[2*iy+1] = vfmaq_f32(acc[2*iy+1], m2, my);
+                }
+            }
+        }
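+        // tail: process the remaining blocks one at a time when nb is not a multiple of 4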
+        for (int ib = 4*(nb/4); ib < nb; ++ib) {
+            auto scales16 = vld1q_f16((const float16_t *)iq8[ib].d);
+            auto scales1 = vcvt_f32_f16(vget_low_f16(scales16));
+            auto scales2 = vcvt_f32_f16(vget_high_f16(scales16));
+            auto m16 = vld1q_f16((const float16_t *)iq8[ib].d+8);
+            auto m1 = vcvt_f32_f16(vget_low_f16(m16));
+            auto m2 = vcvt_f32_f16(vget_high_f16(m16));
+            for (int j = 0; j < 16; ++j) qx[j] = vld1q_s8(iq8[ib].qs + 16*j);
+            int32x4_t sumi1, sumi2;
+            for (int iy = 0; iy < nrc_y; ++iy) {
+                auto qy = (const block_q8_1 *)q8.y[iy];
+                qx_0_q8_0_dot(qx, qy[ib].qs, sumi1, sumi2);
+                auto dy = vdupq_n_f32(GGML_FP16_TO_FP32(qy[ib].d));
+                acc[2*iy+0] = vfmaq_f32(acc[2*iy+0], vmulq_f32(scales1, dy), vcvtq_f32_s32(sumi1));
+                acc[2*iy+1] = vfmaq_f32(acc[2*iy+1], vmulq_f32(scales2, dy), vcvtq_f32_s32(sumi2));
+                auto my = vdupq_n_f32(GGML_FP16_TO_FP32(qy[ib].s));
+                acc[2*iy+0] = vfmaq_f32(acc[2*iy+0], m1, my);
+                acc[2*iy+1] = vfmaq_f32(acc[2*iy+1], m2, my);
+            }
+        }
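+        // store the results for rows ix..ix+7 (two float32x4 halves per column) and reset the accumulators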
+        for (int iy = 0; iy < nrc_y; ++iy) {
+            info.store(ix+0, iy, acc[2*iy+0]);
+            info.store(ix+4, iy, acc[2*iy+1]);
+            acc[2*iy] = acc[2*iy+1] = vdupq_n_f32(0.f);
+        }
+    }
 }
 
-bool iqk_convert_legacy_quants_q8_r8([[maybe_unused]] int type, [[maybe_unused]] int n, [[maybe_unused]] const void * vx, [[maybe_unused]] size_t bx, [[maybe_unused]] void * vy, [[maybe_unused]] int nrc_x) {
-    return false;
-    // switch (type) {
-    //     case GGML_TYPE_Q4_0 : iqk_convert_qX_q80_r8<block_q4_0, Q4_0_Dequantizer>(n, vx, bx, vy, nrc_x); break;
-    //     case GGML_TYPE_Q4_1 : iqk_convert_qX_1_q8_1_r8<block_q4_1, Q4_1_Dequantizer>(n, vx, bx, vy, nrc_x); break;
-    //     case GGML_TYPE_Q5_0 : iqk_convert_qX_q80_r8<block_q5_0, Q5_0_Dequantizer>(n, vx, bx, vy, nrc_x); break;
-    //     case GGML_TYPE_Q5_1 : iqk_convert_qX_1_q8_1_r8<block_q5_1, Q5_1_Dequantizer<block_q5_1>>(n, vx, bx, vy, nrc_x); break;
-    //     case GGML_TYPE_Q6_0 : iqk_convert_qX_q80_r8<block_q6_0, Q6_0_Dequantizer>(n, vx, bx, vy, nrc_x); break;
-    //     case GGML_TYPE_IQ4_NL: iqk_convert_qX_q80_r8<block_iq4_nl, IQ4_NL0_Dequantizer>(n, vx, bx, vy, nrc_x); break;
-    //     case GGML_TYPE_Q8_0 : iqk_convert_q80_q80_r8(n, vx, bx, vy, nrc_x); break;
-    //     default: return false;
-    // }
-    // return true;
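+// Per-type dequantizers: each unpacks the 32 quants of one legacy block into two int8x16_t vectors.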
+struct DeqQ40 {
+    const int8x16_t  m8 = vdupq_n_s8(-8);
+    const uint8x16_t ml = vdupq_n_s8(0xf);
+    inline int8x16x2_t dequant(const block_q4_0& x) const {
+        auto bits = vld1q_u8(x.qs);
+        return { vaddq_s8(vreinterpretq_s8_u8(vandq_u8(bits, ml)), m8), vaddq_s8(vreinterpretq_s8_u8(vshrq_n_u8(bits, 4)), m8) };
+    }
+};
+
+struct DeqQ41 {
+    const uint8x16_t ml = vdupq_n_s8(0xf);
+    inline int8x16x2_t dequant(const block_q4_1& x) const {
+        auto bits = vld1q_u8(x.qs);
+        return { vreinterpretq_s8_u8(vandq_u8(bits, ml)), vreinterpretq_s8_u8(vshrq_n_u8(bits, 4)) };
+    }
+};
+
+struct DeqIQ4NL {
+    const int8x16_t  mt = load_values();
+    const uint8x16_t ml = vdupq_n_s8(0xf);
+    inline int8x16x2_t dequant(const block_iq4_nl& x) const {
+        auto bits = vld1q_u8(x.qs);
+        return { vqtbl1q_s8(mt, vandq_u8(bits, ml)), vqtbl1q_s8(mt, vshrq_n_u8(bits, 4)) };
+    }
+    static inline int8x16_t load_values() { return vld1q_s8(iq4k_values); }
+};
+
+struct DeqQ50 {
+
+    inline int8x16x2_t dequant(const block_q5_0& x) const {
+        int8x16x2_t r;
+        bits.prepare1(x.qs, r.val);
+        auto qh = x.qh;
+        r.val[0] = vreinterpretq_s8_u8(vorrq_u8(vreinterpretq_u8_s8(r.val[0]), vandq_u8(mh, hbits.to_negated_bytes(qh+0))));
+        r.val[1] = vreinterpretq_s8_u8(vorrq_u8(vreinterpretq_u8_s8(r.val[1]), vandq_u8(mh, hbits.to_negated_bytes(qh+2))));
+        return r;
+    }
+
+    Q4LegacyBits bits;
+    HighBit5Legacy hbits;
+    const uint8x16_t mh = vdupq_n_u8(0xf0);
+};
+
+struct DeqQ51 {
+
+    inline int8x16x2_t dequant(const block_q5_1& x) const {
+        int8x16x2_t r;
+        bits.prepare1(x.qs, r.val);
+        auto qh = x.qh;
+        r.val[0] = vreinterpretq_s8_u8(vorrq_u8(vreinterpretq_u8_s8(r.val[0]), vandq_u8(mh, hbits.to_bytes(qh+0))));
+        r.val[1] = vreinterpretq_s8_u8(vorrq_u8(vreinterpretq_u8_s8(r.val[1]), vandq_u8(mh, hbits.to_bytes(qh+2))));
+        return r;
+    }
+
+    Q4LegacyBits bits;
+    HighBit5Legacy hbits;
+    const uint8x16_t mh = vdupq_n_u8(0x10);
+};
+
+struct DeqQ60 {
+
+    inline int8x16x2_t dequant(const block_q6_0& x) const {
+        int8x16x2_t r;
+        bits.prepare1(x.qs, r.val);
+        auto qh8 = vld1_u8(x.qh);
+        auto qh  = vcombine_u8(vshl_n_u8(qh8, 4), qh8);
+        r.val[0] = vaddq_s8(vorrq_u8(r.val[0], vandq_u8(qh, hmask)), m32);
+        r.val[1] = vaddq_s8(vorrq_u8(r.val[1], vandq_u8(vshrq_n_u8(qh, 2), hmask)), m32);
+        return r;
+    }
+
+    Q4LegacyBits bits;
+    const int8x16_t  m32   = vdupq_n_s8(-32);
+    const uint8x16_t hmask = vdupq_n_u8(0x30);
+};
+
+struct DeqQ80 {
+    inline int8x16x2_t dequant(const block_q8_0& x) const {
+        return vld1q_s8_x2(x.qs);
+    }
+};
+
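+// Repack 8 consecutive rows of a legacy quant type into the interleaved block_q8_0_r8 layout
+// (8 row scales in d, quants interleaved in 4-byte groups).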
+template <typename Block, typename Dequantizer>
+void iqk_convert_qX_q80_r8(int n, const void * vx, size_t bx, void * vy, int nrc_x) {
+    GGML_ASSERT(n%QK4_0 == 0);
+    GGML_ASSERT(nrc_x%8 == 0);
+
+    const int nb = n/QK8_0;
+
+    block_q8_0_r8 * y = (block_q8_0_r8 *)vy;
+
+    const Block * x8[8];
+
+    uint32_t block[8];
+
+    Dequantizer deq;
+
+    for (int ix = 0; ix < nrc_x; ix += 8) {
+
+        for (int k = 0; k < 8; ++k) x8[k] = (const Block *)((const char *)vx + (ix + k)*bx);
+
+        for (int i = 0; i < nb; ++i) {
+            for (int k = 0; k < 8; ++k) {
+                y[i].d[k] = x8[k][i].d;
+                vst1q_s8_x2((int8_t *)block, deq.dequant(x8[k][i]));
+                auto qs = (uint32_t *)y[i].qs;
+                for (int l = 0; l < 4; ++l) {
+                    qs[8*l + k +  0] = block[l + 0];
+                    qs[8*l + k + 32] = block[l + 4];
+                }
+            }
+        }
+        y += nb;
+    }
+}
+
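+// Same repacking for the offset (_1) legacy types into block_q8_1_r8: per-row d goes to d[0..7], per-row m to d[8..15].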
+template <typename Block, typename Dequantizer>
+void iqk_convert_qX_1_q8_1_r8(int n, const void * vx, size_t bx, void * vy, int nrc_x) {
+    GGML_ASSERT(n%QK4_0 == 0);
+    GGML_ASSERT(nrc_x%8 == 0);
+
+    const int nb = n/QK8_0;
+
+    block_q8_1_r8 * y = (block_q8_1_r8 *)vy;
+
+    const Block * x8[8];
+
+    uint32_t block[8];
+
+    Dequantizer deq;
+
+    for (int ix = 0; ix < nrc_x; ix += 8) {
+
+        for (int k = 0; k < 8; ++k) x8[k] = (const Block *)((const char *)vx + (ix + k)*bx);
+
+        for (int i = 0; i < nb; ++i) {
+            for (int k = 0; k < 8; ++k) {
+                y[i].d[k+0] = x8[k][i].d;
+                y[i].d[k+8] = x8[k][i].m;
+                vst1q_s8_x2((int8_t *)block, deq.dequant(x8[k][i]));
+                auto qs = (uint32_t *)y[i].qs;
+                for (int l = 0; l < 4; ++l) {
+                    qs[8*l + k +  0] = block[l + 0];
+                    qs[8*l + k + 32] = block[l + 4];
+                }
+            }
+        }
+        y += nb;
+    }
+}
+
+}
+
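+// Dispatch: convert nrc_x rows of a legacy quant type to the interleaved Q8 format consumed by the R8 GEMM kernels.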
+bool iqk_convert_legacy_quants_q8_r8(int type, int n, const void * vx, size_t bx, void * vy, int nrc_x) {
+    switch (type) {
+        case GGML_TYPE_Q4_0  : iqk_convert_qX_q80_r8<block_q4_0, DeqQ40>(n, vx, bx, vy, nrc_x); break;
+        case GGML_TYPE_Q4_1  : iqk_convert_qX_1_q8_1_r8<block_q4_1, DeqQ41>(n, vx, bx, vy, nrc_x); break;
+        case GGML_TYPE_Q5_0  : iqk_convert_qX_q80_r8<block_q5_0, DeqQ50>(n, vx, bx, vy, nrc_x); break;
+        case GGML_TYPE_Q5_1  : iqk_convert_qX_1_q8_1_r8<block_q5_1, DeqQ51>(n, vx, bx, vy, nrc_x); break;
+        case GGML_TYPE_Q6_0  : iqk_convert_qX_q80_r8<block_q6_0, DeqQ60>(n, vx, bx, vy, nrc_x); break;
+        case GGML_TYPE_IQ4_NL: iqk_convert_qX_q80_r8<block_iq4_nl, DeqIQ4NL>(n, vx, bx, vy, nrc_x); break;
+        case GGML_TYPE_Q8_0  : iqk_convert_qX_q80_r8<block_q8_0, DeqQ80>(n, vx, bx, vy, nrc_x); break;
+        default: return false;
+    }
+    return true;
 }
 
 bool iqk_set_kernels_legacy_quants(int ne00, int typeA, int typeB, std::array<mul_mat_t, IQK_MAX_NY>& kernels, mul_mat_t& func16) {
 
     if (ne00%QK8_0 != 0) return false;
 
     auto etypeA = ggml_type(typeA);
-    auto expected_typeB = etypeA == GGML_TYPE_Q4_1 || etypeA == GGML_TYPE_Q5_1 ? GGML_TYPE_Q8_1_X4 : GGML_TYPE_Q8_0_X4;
+    auto expected_typeB = etypeA == GGML_TYPE_Q4_1 || etypeA == GGML_TYPE_Q5_1 || etypeA == GGML_TYPE_Q8_1 ? GGML_TYPE_Q8_1_X4 : GGML_TYPE_Q8_0_X4;
     if (ggml_type(typeB) != expected_typeB) return false;
 
     func16 = nullptr;
@@ -2843,6 +3061,9 @@ bool iqk_set_kernels_legacy_quants(int ne00, int typeA, int typeB, std::array<mu
28433061 case GGML_TYPE_Q8_0_R8:
28443062 IQK_SET_MUL_MAT_FUNCTIONS (mul_mat_q8_0_r8_q8_0, kernels);
28453063 break ;
3064+ case GGML_TYPE_Q8_1:
3065+ IQK_SET_MUL_MAT_FUNCTIONS (mul_mat_q8_1_r8_q8_1, kernels);
3066+ break ;
28463067 case GGML_TYPE_IQ4_NL_R4:
28473068 IQK_SET_MUL_MAT_FUNCTIONS_T (mul_mat_qx_r4_q8_0, IQ4_NL_R4_Dequantizer, kernels);
28483069 break ;