@@ -79,6 +79,16 @@ template <typename Q8, typename Q8x4> struct Sum4q4 {
7979 inline __m256i compute (__m256i x, __m256i y) const { return _mm256_madd_epi16 (_mm256_set1_epi16 (1 ), _mm256_maddubs_epi16 (x, y)); }
8080};
8181
82+ inline __m256 convert_scales (const uint16_t * scales) {
83+ auto aux_d = _mm_castsi128_ps (_mm_slli_epi32 (_mm_cvtepu16_epi32 (_mm_loadl_epi64 ((const __m128i *)scales)), 16 ));
84+ auto aux_m = _mm_cvtepi32_ps (_mm_cvtepi16_epi32 (_mm_loadl_epi64 ((const __m128i *)(scales+4 ))));
85+ return _mm256_set_m128 (_mm_mul_ps (aux_d, aux_m), aux_d);
86+ }
87+
88+ inline __m128 convert_scales_s (const uint16_t * scales) {
89+ return _mm_castsi128_ps (_mm_slli_epi32 (_mm_cvtepu16_epi32 (_mm_loadl_epi64 ((const __m128i *)scales)), 16 ));
90+ }
91+
8292struct ScaleHelperQ8_0 {
8393 inline __m128 prepare4 (const block_q8_0 * y) {
8494 const block_q8_0_x4 * y4 = (const block_q8_0_x4 *)y;
@@ -106,6 +116,20 @@ struct ScaleHelperQ_0 {
106116 template <typename Q> inline float prepare1 (float d, const Q * y) const { return d*prepare1 (y); }
107117};
108118
119+ struct ScaleHelperQ8_2S {
120+ template <typename Q>
121+ inline __m128 prepare4 (const Q * y) {
122+ const block_q8_2_x4 * y4 = (const block_q8_2_x4 *)y;
123+ return convert_scales_s ((const uint16_t *)y4->d );
124+ }
125+ template <typename Q>
126+ inline __m128 prepare4 (__m128 other_scales, const Q * y) {
127+ return _mm_mul_ps (other_scales, prepare4<Q>(y));
128+ }
129+ template <typename Q> static inline float prepare1 (const Q * y) { return GGML_BF16_TO_FP32 (ggml_bf16_t {y->d }); }
130+ template <typename Q> static inline float prepare1 (float d, const Q * y) { return d*prepare1 (y); }
131+ };
132+
109133struct ScaleHelperQ_0_MXFP4 {
110134 float scales[4 ];
111135 template <typename Q>
@@ -188,12 +212,6 @@ struct ScaleHelperQ8_1 {
188212 }
189213};
190214
191- inline __m256 convert_scales (const uint16_t * scales) {
192- auto aux_d = _mm_castsi128_ps (_mm_slli_epi32 (_mm_cvtepu16_epi32 (_mm_loadl_epi64 ((const __m128i *)scales)), 16 ));
193- auto aux_m = _mm_cvtepi32_ps (_mm_cvtepi16_epi32 (_mm_loadl_epi64 ((const __m128i *)(scales+4 ))));
194- return _mm256_set_m128 (_mm_mul_ps (aux_d, aux_m), aux_d);
195- }
196-
197215struct ScaleHelperQ8_2 {
198216 template <typename Q>
199217 inline __m256 prepare4 (const Q * y) {
@@ -348,6 +366,7 @@ using AccumType1 = AccumT<MinusType1<nrc_y>, nrc_y, is_multiple_of_4>;
348366
349367using Sum4TypeQ80 = Sum4<block_q8_0, block_q8_0_x4, SignedDot, false >;
350368using Sum4TypeQ82 = Sum4<block_q8_2, block_q8_2_x4, UnsignedDot, false >;
369+ using Sum4TypeQ82S = Sum4<block_q8_2, block_q8_2_x4, SignedDot, false >;
351370
352371template <typename Unpacker, typename AccumType, typename Scales, typename Q8, int nrc_y>
353372void mul_mat_qX_q8_Helper (int nb, const void * vx, size_t bx, const DataInfo& info, const Q8 ** y, int nrc_x) {
@@ -374,10 +393,35 @@ void mul_mat_qX_q8_Helper_x2(int nb, const void * vx, size_t bx, const DataInfo&
374393 }
375394}
376395
377- template <typename Unpacker, int nrc_y>
396+ template <typename Unpacker, int nrc_y, typename Block = block_q8_0 >
378397void mul_mat_qX_0_q8_0_T (int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
379398 assert (n%Unpacker::block_size () == 0 );
380- Q8<nrc_y, block_q8_0> q8 (info);
399+ Q8<nrc_y, Block> q8 (info);
400+ int nb = n/Unpacker::block_size ();
401+ if constexpr (std::is_same_v<Block, block_q8_2>) {
402+ if (nb%4 == 0 ) {
403+ mul_mat_qX_q8_Helper<Unpacker, AccumType0<nrc_y, true >, ScaleHelperQ8_2S, Block, nrc_y>(
404+ nb, vx, bx, info, q8.y , nrc_x);
405+ } else {
406+ mul_mat_qX_q8_Helper<Unpacker, AccumType0<nrc_y, false >, ScaleHelperQ8_2S, Block, nrc_y>(
407+ nb, vx, bx, info, q8.y , nrc_x);
408+ }
409+ }
410+ else {
411+ if (nb%4 == 0 ) {
412+ mul_mat_qX_q8_Helper<Unpacker, AccumType0<nrc_y, true >, ScaleHelperQ8_0, Block, nrc_y>(
413+ nb, vx, bx, info, q8.y , nrc_x);
414+ } else {
415+ mul_mat_qX_q8_Helper<Unpacker, AccumType0<nrc_y, false >, ScaleHelperQ8_0, Block, nrc_y>(
416+ nb, vx, bx, info, q8.y , nrc_x);
417+ }
418+ }
419+ }
420+
421+ template <typename Unpacker, int nrc_y>
422+ void mul_mat_qX_0_q8_2_T (int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
423+ assert (n%Unpacker::block_size () == 0 );
424+ Q8<nrc_y, block_q8_2> q8 (info);
381425 int nb = n/Unpacker::block_size ();
382426 if (nb%4 == 0 ) {
383427 mul_mat_qX_q8_Helper<Unpacker, AccumType0<nrc_y, true >, ScaleHelperQ8_0, block_q8_0, nrc_y>(
@@ -393,11 +437,11 @@ void mul_mat_qX_0_q8_0_T(int n, const void * vx, size_t bx, const DataInfo& info
393437template <typename Unpacker, int nrc_y, int nrc_x>
394438void mul_mat_qX_0_q8_0_Tx (int n, const void * vx, size_t bx, const DataInfo& info, int ) {
395439 static_assert (8 %nrc_y == 0 );
396- Q8<nrc_y, block_q8_0 > q8 (info);
440+ Q8<nrc_y, block_q8_2 > q8 (info);
397441 int nb = n/Unpacker::block_size ();
398442 Unpacker unp (vx, bx);
399443 typename Unpacker::Sum4T sum4;
400- ScaleHelperQ8_0 scales;
444+ ScaleHelperQ8_2S scales;
401445 __m256 result[8 ];
402446 auto store = [&info, &result] (int ix0) {
403447 if constexpr (nrc_y == 1 ) {
@@ -549,19 +593,15 @@ struct Q4_0_1_Dequantizer {
549593 }
550594};
551595
552- struct IQ4_NL_Dequantizer {
596+ struct IQ4_NL_DequantizerU {
553597 Dequantizer4bit b4;
554- #ifdef HAVE_FANCY_SIMD
555598 const __m256i values = load_iq4nl_values_256();
556- #else
557- const __m256i values = load_iq4k_values_256();
558- #endif
559599 inline __m256i dequant (const block_iq4_nl * x) const {
560600 return _mm256_shuffle_epi8 (values, b4.dequant (x->qs ));
561601 }
562602};
563603
564- struct IQ4_NL0_Dequantizer {
604+ struct IQ4_NL_DequantizerS {
565605 Dequantizer4bit b4;
566606 const __m256i values = load_iq4k_values_256();
567607 inline __m256i dequant (const block_iq4_nl * x) const {
@@ -705,14 +745,19 @@ struct Q_Unpacker {
705745
706746struct Q8_0_Unpacker final : public Q_Unpacker<block_q8_0, ScaleHelperQ_0, Q8_0_Dequantizer> {
707747 Q8_0_Unpacker (const void * vx, size_t bx) : Q_Unpacker(vx, bx) {}
708- using Sum4T = Sum4TypeQ80 ;
748+ using Sum4T = Sum4TypeQ82S ;
709749 inline static int block_size () { return QK8_0; }
710750};
711751struct Q8_0_1_Unpacker final : public Q_Unpacker<block_q8_0, ScaleHelperQ_0_1<127 >, Q8_0_1_Dequantizer> {
712752 Q8_0_1_Unpacker (const void * vx, size_t bx) : Q_Unpacker(vx, bx) {}
713753 using Sum4T = Sum4TypeQ82;
714754 inline static int block_size () { return QK8_0; }
715755};
756+ struct Q8_0_2_Unpacker final : public Q_Unpacker<block_q8_0, ScaleHelperQ_0, Q8_0_Dequantizer> {
757+ Q8_0_2_Unpacker (const void * vx, size_t bx) : Q_Unpacker(vx, bx) {}
758+ using Sum4T = Sum4TypeQ82;
759+ inline static int block_size () { return QK8_0; }
760+ };
716761struct Q4_0_Unpacker final : public Q_Unpacker<block_q4_0, ScaleHelperQ_0, Q4_0_Dequantizer> {
717762 Q4_0_Unpacker (const void * vx, size_t bx) : Q_Unpacker(vx, bx) {}
718763 using Sum4T = Sum4TypeQ80;
@@ -729,19 +774,16 @@ struct MXFP4_Unpacker final : public Q_Unpacker<block_mxfp4, ScaleHelperQ_0_1_MX
729774 using Sum4T = Sum4TypeQ82;
730775 inline static int block_size () { return QK4_NL; }
731776};
732- #ifdef HAVE_FANCY_SIMD
733- struct IQ4_NL_Unpacker final : public Q_Unpacker<block_iq4_nl, ScaleHelperQ_0_1<128 >, IQ4_NL_Dequantizer> {
734- IQ4_NL_Unpacker (const void * vx, size_t bx) : Q_Unpacker(vx, bx) {}
777+ struct IQ4_NL_UnpackerU final : public Q_Unpacker<block_iq4_nl, ScaleHelperQ_0_1<128 >, IQ4_NL_DequantizerU> {
778+ IQ4_NL_UnpackerU (const void * vx, size_t bx) : Q_Unpacker(vx, bx) {}
735779 using Sum4T = Sum4TypeQ82;
736780 inline static int block_size () { return QK4_NL; }
737781};
738- #else
739- struct IQ4_NL_Unpacker final : public Q_Unpacker<block_iq4_nl, ScaleHelperQ_0, IQ4_NL0_Dequantizer> {
740- IQ4_NL_Unpacker (const void * vx, size_t bx) : Q_Unpacker(vx, bx) {}
741- using Sum4T = Sum4TypeQ80;
782+ struct IQ4_NL_UnpackerS final : public Q_Unpacker<block_iq4_nl, ScaleHelperQ_0, IQ4_NL_DequantizerS> {
783+ IQ4_NL_UnpackerS (const void * vx, size_t bx) : Q_Unpacker(vx, bx) {}
784+ using Sum4T = Sum4TypeQ82S;
742785 inline static int block_size () { return QK4_NL; }
743786};
744- #endif
745787struct Q5_0_Unpacker final : public Q_Unpacker<block_q5_0, ScaleHelperQ_0, Q5_0_Dequantizer> {
746788 Q5_0_Unpacker (const void * vx, size_t bx) : Q_Unpacker(vx, bx) {}
747789 using Sum4T = Sum4TypeQ80;
@@ -1872,19 +1914,20 @@ void iqk_convert_qX_1_q8_1_r8(int n, const void * vx, size_t bx, void * vy, int
18721914}
18731915
18741916template <typename Dequantizer> void set_functions (std::array<mul_mat_t , IQK_MAX_NY>& funcs) {
1875- if constexpr (std::is_same_v<Dequantizer, Q4_0_Unpacker> || std::is_same_v<Dequantizer, Q5_0_Unpacker> ||
1876- std::is_same_v<Dequantizer, Q8_0_Unpacker>) {
1917+ if constexpr (std::is_same_v<Dequantizer, Q4_0_Unpacker> || std::is_same_v<Dequantizer, Q5_0_Unpacker>) {
18771918 IQK_SET_MUL_MAT_FUNCTIONS_T (mul_mat_qX_0_q8_0_T, Dequantizer, funcs)
18781919 }
1920+ else if constexpr (std::is_same_v<Dequantizer, Q8_0_Unpacker>) {
1921+ IQK_SET_MUL_MAT_FUNCTIONS_T2 (mul_mat_qX_0_q8_0_T, Dequantizer, block_q8_2, funcs)
1922+ }
18791923 else if constexpr (std::is_same_v<Dequantizer, Q4_1_Unpacker> || std::is_same_v<Dequantizer, Q5_1_Unpacker>) {
18801924 IQK_SET_MUL_MAT_FUNCTIONS_T (mul_mat_qX_1_q8_2_T, Dequantizer, funcs)
18811925 }
1882- else if constexpr (std::is_same_v<Dequantizer, IQ4_NL_Unpacker>) {
1883- #ifdef HAVE_FANCY_SIMD
1926+ else if constexpr (std::is_same_v<Dequantizer, IQ4_NL_UnpackerU>) {
18841927 IQK_SET_MUL_MAT_FUNCTIONS_T (mul_mat_qX_1_q8_2_T, Dequantizer, funcs)
1885- # else
1886- IQK_SET_MUL_MAT_FUNCTIONS_T (mul_mat_qX_0_q8_0_T, Dequantizer, funcs)
1887- # endif
1928+ }
1929+ else if constexpr (std::is_same_v< Dequantizer, IQ4_NL_UnpackerS>) {
1930+ IQK_SET_MUL_MAT_FUNCTIONS_T2 (mul_mat_qX_0_q8_0_T, Dequantizer, block_q8_2, funcs)
18881931 }
18891932 else if constexpr (std::is_same_v<Dequantizer, Q8_0_1_Unpacker> || std::is_same_v<Dequantizer, Q4_0_1_Unpacker> ||
18901933 std::is_same_v<Dequantizer, Q5_0_1_Unpacker> || std::is_same_v<Dequantizer, Q6_0_1_Unpacker> ||
@@ -1902,7 +1945,7 @@ bool iqk_convert_legacy_quants_q8_r8(int type, int n, const void * vx, size_t bx
19021945 case GGML_TYPE_Q5_0 : iqk_convert_qX_q80_r8<block_q5_0, Q5_0_Dequantizer>(n, vx, bx, vy, nrc_x); break ;
19031946 case GGML_TYPE_Q5_1 : iqk_convert_qX_1_q8_1_r8<block_q5_1, Q5_1_Dequantizer<block_q5_1>>(n, vx, bx, vy, nrc_x); break ;
19041947 case GGML_TYPE_Q6_0 : iqk_convert_qX_q80_r8<block_q6_0, Q6_0_Dequantizer>(n, vx, bx, vy, nrc_x); break ;
1905- case GGML_TYPE_IQ4_NL: iqk_convert_qX_q80_r8<block_iq4_nl, IQ4_NL0_Dequantizer >(n, vx, bx, vy, nrc_x); break ;
1948+ case GGML_TYPE_IQ4_NL: iqk_convert_qX_q80_r8<block_iq4_nl, IQ4_NL_DequantizerS >(n, vx, bx, vy, nrc_x); break ;
19061949 case GGML_TYPE_Q8_0 : iqk_convert_q80_q80_r8 (n, vx, bx, vy, nrc_x); break ;
19071950 case GGML_TYPE_MXFP4 : iqk_convert_qX_q80_r8<block_mxfp4, MXFP40_Dequantizer>(n, vx, bx, vy, nrc_x); break ;
19081951 default : return false ;
@@ -1939,20 +1982,17 @@ bool iqk_set_kernels_legacy_quants(int ne00, int typeA, int typeB, std::array<mu
19391982 set_functions<Q8_0_1_Unpacker>(kernels);
19401983#else
19411984 set_functions<Q8_0_Unpacker>(kernels);
1942- expected_typeB = GGML_TYPE_Q8_0_X4;
19431985#endif
19441986 break ;
19451987 case GGML_TYPE_IQ4_NL:
1946- set_functions<IQ4_NL_Unpacker>(kernels);
1947- #ifndef HAVE_FANCY_SIMD
1948- expected_typeB = GGML_TYPE_Q8_0_X4;
1988+ #ifdef HAVE_FANCY_SIMD
1989+ set_functions<IQ4_NL_UnpackerU>(kernels);
1990+ #else
1991+ set_functions<IQ4_NL_UnpackernS>(kernels);
19491992#endif
19501993 break ;
19511994 case GGML_TYPE_MXFP4:
19521995 set_functions<MXFP4_Unpacker>(kernels);
1953- // #ifndef HAVE_FANCY_SIMD
1954- // expected_typeB = GGML_TYPE_Q8_0_X4;
1955- // #endif
19561996 break ;
19571997 case GGML_TYPE_Q4_0_R8:
19581998 IQK_SET_MUL_MAT_FUNCTIONS (mul_mat_q4_0_r8_q8_2, kernels)
@@ -3223,6 +3263,19 @@ inline std::pair<mul_mat_t, int> mul_mat_kernel(int int_typeA, int nq) {
32233263 case 7 : return std::make_pair (mul_mat, 7 >, 7 );\
32243264 }\
32253265 }
3266+ #define MAKE_FUNCS2 (mul_mat, block, n ) \
3267+ if (n >= kMaxQ ) return std::make_pair (mul_mat, kMaxQ , block>, kMaxQ );\
3268+ else {\
3269+ switch (n) {\
3270+ case 1 : return std::make_pair (mul_mat, 1 , block>, 1 );\
3271+ case 2 : return std::make_pair (mul_mat, 2 , block>, 2 );\
3272+ case 3 : return std::make_pair (mul_mat, 3 , block>, 3 );\
3273+ case 4 : return std::make_pair (mul_mat, 4 , block>, 4 );\
3274+ case 5 : return std::make_pair (mul_mat, 5 , block>, 5 );\
3275+ case 6 : return std::make_pair (mul_mat, 6 , block>, 6 );\
3276+ case 7 : return std::make_pair (mul_mat, 7 , block>, 7 );\
3277+ }\
3278+ }
32263279#define MAKE_FUNCS_ONLY_NRC (mul_mat, n ) \
32273280 if (n >= kMaxQ ) return std::make_pair (mul_mat<kMaxQ >, kMaxQ );\
32283281 else {\
@@ -3249,7 +3302,11 @@ inline std::pair<mul_mat_t, int> mul_mat_kernel(int int_typeA, int nq) {
32493302 if (nq == 1 ) return std::make_pair (mul_mat_qX_0_q8_0_Tx<Q8_0_Unpacker, 1 , k_step>, 1 );
32503303 if (nq == 2 ) return std::make_pair (mul_mat_qX_0_q8_0_Tx<Q8_0_Unpacker, 2 , k_step>, 2 );
32513304 if (nq == 4 ) return std::make_pair (mul_mat_qX_0_q8_0_Tx<Q8_0_Unpacker, 4 , k_step>, 4 );
3252- MAKE_FUNCS (mul_mat_qX_0_q8_0_T<Q8_0_Unpacker, nq);
3305+ if (nq == 3 ) return std::make_pair (mul_mat_qX_0_q8_0_T<Q8_0_Unpacker, 3 , block_q8_2>, 3 );
3306+ if (nq == 5 ) return std::make_pair (mul_mat_qX_0_q8_0_T<Q8_0_Unpacker, 5 , block_q8_2>, 5 );
3307+ if (nq == 6 ) return std::make_pair (mul_mat_qX_0_q8_0_T<Q8_0_Unpacker, 6 , block_q8_2>, 6 );
3308+ if (nq == 7 ) return std::make_pair (mul_mat_qX_0_q8_0_T<Q8_0_Unpacker, 7 , block_q8_2>, 7 );
3309+ return std::make_pair (mul_mat_qX_0_q8_0_T<Q8_0_Unpacker, kMaxQ , block_q8_2>, kMaxQ );
32533310#endif
32543311#endif
32553312 }
@@ -3293,9 +3350,9 @@ inline std::pair<mul_mat_t, int> mul_mat_kernel(int int_typeA, int nq) {
32933350 MAKE_FUNCS (mul_mat_qX_0_q8_0<DequantizerIQ4NL, nq);
32943351#else
32953352#ifdef HAVE_FANCY_SIMD
3296- MAKE_FUNCS (mul_mat_qX_1_q8_2_T<IQ4_NL_Unpacker , nq);
3353+ MAKE_FUNCS (mul_mat_qX_1_q8_2_T<IQ4_NL_UnpackerU , nq);
32973354#else
3298- MAKE_FUNCS (mul_mat_qX_0_q8_0_T<IQ4_NL_Unpacker , nq);
3355+ MAKE_FUNCS2 (mul_mat_qX_0_q8_0_T<IQ4_NL_UnpackerS, block_q8_2 , nq);
32993356#endif
33003357#endif
33013358 }
0 commit comments