@@ -170,6 +170,89 @@ struct Trellis3 {
             return _mm256_permutevar8x32_epi32(aux[0], shuffle);
         }
     }
+    IQK_ALWAYS_INLINE inline void next_128(const uint32_t * val, __m256i * result) const {
+        // Even though we only have 16 vector registers on AVX2, this is still faster
+        __m256i aux[16];
+        auto perm = _mm256_setr_epi32(0, 2, 4, 6, 1, 3, 5, 7);
+        for (int k = 0; k < 4; ++k) {
+            auto v = _mm256_loadu_si256((const __m256i *)val + k);
+            v = _mm256_permutevar8x32_epi32(v, perm);
+            aux[4*k+0] = _mm256_shuffle_epi32(v, 0x00);
+            aux[4*k+1] = _mm256_shuffle_epi32(v, 0x55);
+            aux[4*k+2] = _mm256_shuffle_epi32(v, 0xaa);
+            aux[4*k+3] = _mm256_shuffle_epi32(v, 0xff);
+        }
+        for (int i = 0; i < 16; ++i) {
+            aux[i] = _mm256_mullo_epi32(aux[i], mka);
+        }
+        auto mask = _mm256_set1_epi32(0x3f3f3f3f); // keep 6 bits in each byte
+        for (int i = 0; i < 16; ++i) {
+            aux[i] = _mm256_and_si256(aux[i], mask);
+        }
+        auto offset = _mm256_set1_epi32(-126);
+#if defined(__AVX512F__) && defined(__AVX512VNNI__) && defined(__AVX512VL__)
+        auto m1 = _mm256_set1_epi32(0x01010101);
+#endif
+        // horizontal sum of the 4 bytes in each 32-bit lane, plus the -126 offset
+        for (int i = 0; i < 16; ++i) {
+#if defined(__AVX512F__) && defined(__AVX512VNNI__) && defined(__AVX512VL__)
+            aux[i] = _mm256_dpbusd_epi32(offset, aux[i], m1);
+#else
+            auto dot = _mm256_maddubs_epi16(aux[i], _mm256_set1_epi32(0x01010101));
+            aux[i] = _mm256_add_epi32(offset, _mm256_madd_epi16(dot, _mm256_set1_epi16(1)));
+#endif
+        }
+        // pack 16 vectors of 8 x int32 into 4 vectors of 32 x int8
+        for (int k = 0; k < 4; ++k) {
+            auto v1 = _mm256_packs_epi32(aux[4*k+0], aux[4*k+1]);
+            auto v2 = _mm256_packs_epi32(aux[4*k+2], aux[4*k+3]);
+            result[k] = _mm256_permutevar8x32_epi32(_mm256_packs_epi16(v1, v2), shuffle);
+        }
+        if constexpr (is_abs) {
+            for (int k = 0; k < 4; ++k) {
+                result[k] = _mm256_sign_epi8(result[k], result[k]);
+            }
+        }
+    }
+    IQK_ALWAYS_INLINE inline void next_128(const uint16_t * val, uint32_t v0, __m256i * result) const {
+        // Even though we only have 16 vector registers on AVX2, this is still faster
+        __m256i aux[16];
+        for (int k = 0; k < 4; ++k) {
+            // seeds are stored as 16-bit values relative to the offset v0
+            auto v128 = _mm_add_epi32(_mm_cvtepu16_epi32(_mm_loadl_epi64((const __m128i *)(val + 4*k))), _mm_set1_epi32(v0));
+            auto v = MM256_SET_M128I(v128, v128);
+            aux[4*k+0] = _mm256_shuffle_epi32(v, 0x00);
+            aux[4*k+1] = _mm256_shuffle_epi32(v, 0x55);
+            aux[4*k+2] = _mm256_shuffle_epi32(v, 0xaa);
+            aux[4*k+3] = _mm256_shuffle_epi32(v, 0xff);
+        }
+        for (int i = 0; i < 16; ++i) {
+            aux[i] = _mm256_mullo_epi32(aux[i], mka);
+        }
+        auto mask = _mm256_set1_epi32(0x3f3f3f3f);
+        for (int i = 0; i < 16; ++i) {
+            aux[i] = _mm256_and_si256(aux[i], mask);
+        }
+        auto offset = _mm256_set1_epi32(-126);
+#if defined(__AVX512F__) && defined(__AVX512VNNI__) && defined(__AVX512VL__)
+        auto m1 = _mm256_set1_epi32(0x01010101);
+#endif
+        for (int i = 0; i < 16; ++i) {
+#if defined(__AVX512F__) && defined(__AVX512VNNI__) && defined(__AVX512VL__)
+            aux[i] = _mm256_dpbusd_epi32(offset, aux[i], m1);
+#else
+            auto dot = _mm256_maddubs_epi16(aux[i], _mm256_set1_epi32(0x01010101));
+            aux[i] = _mm256_add_epi32(offset, _mm256_madd_epi16(dot, _mm256_set1_epi16(1)));
+#endif
+        }
+        for (int k = 0; k < 4; ++k) {
+            auto v1 = _mm256_packs_epi32(aux[4*k+0], aux[4*k+1]);
+            auto v2 = _mm256_packs_epi32(aux[4*k+2], aux[4*k+3]);
+            result[k] = _mm256_permutevar8x32_epi32(_mm256_packs_epi16(v1, v2), shuffle);
+        }
+        if constexpr (is_abs) {
+            for (int k = 0; k < 4; ++k) {
+                result[k] = _mm256_sign_epi8(result[k], result[k]);
+            }
+        }
+    }
     inline __m256i next32(const uint16_t * val, uint32_t v0) const {
         const __m256i offset = _mm256_set1_epi32(-126);
         __m256i aux[4];
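
Note: each 32-bit seed above expands to four int8 trellis values: `next_128()` multiplies the broadcast seed by `mka`, keeps 6 bits per byte, and byte-sums with a -126 offset. A minimal scalar sketch, assuming `mka` holds successive powers of a multiplier `ka = 0xCBAC1FED` (an assumption about `Trellis3` internals, not visible in this diff):

```cpp
#include <cstdint>

// Scalar sketch of one trellis value as produced by next_128()
// (assumption: mka = {ka, ka^2, ka^3, ka^4, ...}).
inline int8_t trellis3_value(uint32_t seed, int j) {
    constexpr uint32_t ka = 0xCBAC1FEDu;
    uint32_t x = seed;
    for (int i = 0; i <= j; ++i) x *= ka;  // lane j of _mm256_mullo_epi32(aux, mka)
    x &= 0x3f3f3f3fu;                      // the 0x3f3f3f3f mask: 6 bits per byte
    int sum = (x & 0xff) + ((x >> 8) & 0xff) + ((x >> 16) & 0xff) + (x >> 24);
    return (int8_t)(sum - 126);            // byte sum in [0, 252] -> [-126, 126]
}
```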
@@ -385,7 +468,7 @@ void mul_mat_iq2_kt_q8_2_x4_T(int n, const void * vx, size_t bx, const DataInfo&
     assert(n%QK_K == 0);
     const int nb = n/QK_K;
 
-    Trellis3<true> trellis;
+    Trellis3<true, false> trellis;
 
     auto shifts = _mm_set_epi32(0, 0, 4, 0);
     auto values = _mm_loadu_si128((const __m128i *)iq4k_values);
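
Note: only the second template argument changes here. Its meaning is not shown in this diff, but given the `if constexpr (is_abs)` branch in `next_128()` above, a plausible reading (an assumption, not confirmed by this hunk) is:

```cpp
// Presumed template shape (hypothetical reconstruction):
// template <bool is_8 = false, bool is_abs = false> struct Trellis3;
// iq2_kt keeps signed trellis values -> Trellis3<true, false>;
// iq3_kt stores magnitudes plus separate sign bits and would pass is_abs = true.
```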
@@ -425,8 +508,6 @@ void mul_mat_iq2_kt_q8_2_x4_T(int n, const void * vx, size_t bx, const DataInfo&
         }
     };
 
-    // auto m126 = _mm256_set1_ps(-126.f);
-
 
     for (int ix = 0; ix < nrc_x; ++ix) {
         const float * dptr = (const float *)((const char *)vx + ix*bx);
         auto d = _mm256_set1_ps(dptr[0] * 1.05f);
@@ -446,17 +527,13 @@ void mul_mat_iq2_kt_q8_2_x4_T(int n, const void * vx, size_t bx, const DataInfo&
             scales[0] = _mm256_set_m128(scales_l, scales_l);
             scales[1] = _mm256_set_m128(scales_h, scales_h);
             for (int i128 = 0; i128 < 2; ++i128) {
-                // for (int k = 0; k < 4; ++k) xv[k] = trellis.next32<true>(values + 32*i128 + 8*k);
-                for (int k = 0; k < 4; ++k) xv[k] = trellis.next32(ql + 16*i128 + 4*k, 4096);
+                trellis.next_128(ql + 16*i128, 4096, xv);
                 for (int iy = 0; iy < nrc_y; ++iy) {
                     const block_q8_2_x4& yb = y[iy][2*i+i128];
-                    auto dy = _mm256_castsi256_ps(_mm256_slli_epi32(_mm256_cvtepu16_epi32(_mm_loadu_si128((const __m128i *)yb.d)), 16));
-                    dy = _mm256_mul_ps(scales[i128], dy);
-                    auto d8 = _mm256_set_m128(_mm256_castps256_ps128(dy), _mm256_castps256_ps128(dy));
-                    // auto m8 = _mm256_set_m128(_mm256_extractf128_ps(dy, 1), _mm256_extractf128_ps(dy, 1));
+                    auto dy4 = _mm_castsi128_ps(_mm_slli_epi32(_mm_cvtepu16_epi32(_mm_loadl_epi64((const __m128i *)yb.d)), 16));
+                    auto dy8 = _mm256_mul_ps(scales[i128], _mm256_set_m128(dy4, dy4));
                     compute_dot(yb.qs);
-                    accd[iy] = _mm256_fmadd_ps(d8, sum_4(), accd[iy]);
-                    // accd[iy] = _mm256_fmadd_ps(m8, m126, accd[iy]);
+                    accd[iy] = _mm256_fmadd_ps(dy8, sum_4(), accd[iy]);
                 }
             }
         }
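
Note: the `dy4`/`dy8` rewrite loads only the four 16-bit `d` scales of `block_q8_2_x4` (a 64-bit load instead of the previous 128-bit one) and converts each by shifting it into the high half of a 32-bit lane, i.e. a bf16-style bit pattern. Scalar equivalent of the per-scale conversion (standalone helper, not from the codebase):

```cpp
#include <cstdint>
#include <cstring>

// What _mm_slli_epi32(_mm_cvtepu16_epi32(...), 16) does per element:
// the stored 16 bits become the top half of an fp32.
inline float q8_2_scale_to_f32(uint16_t d16) {
    uint32_t bits = (uint32_t)d16 << 16;
    float f;
    std::memcpy(&f, &bits, sizeof f);  // bit-cast without strict-aliasing UB
    return f;
}
```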
@@ -595,18 +672,17 @@ void mul_mat_iq3_kt_q8_2_x4_T(int n, const void * vx, size_t bx, const DataInfo&
             scales[1] = _mm256_set_m128(scales_h, scales_h);
             auto mask = _mm256_set1_epi8(1);
             for (int i128 = 0; i128 < 2; ++i128) {
+                trellis.next_128(ql + 16*i128, 4096, xv);
                 for (int k = 0; k < 4; ++k) {
-                    xv[k] = trellis.next32(ql + 16*i128 + 4*k, 4096);
-                    sv[k] = _mm256_or_si256(_mm256_cmpeq_epi8(_mm256_and_si256(sign_bits, mask), mask), _mm256_set1_epi8(1));
-                    mask = _mm256_slli_epi16(mask, 1);
+                    sv[k] = _mm256_or_si256(_mm256_cmpeq_epi8(_mm256_and_si256(sign_bits, mask), mask), mask);
+                    sign_bits = _mm256_srli_epi16(sign_bits, 1);
                 }
                 for (int iy = 0; iy < nrc_y; ++iy) {
                     const block_q8_2_x4& yb = y[iy][2*i+i128];
-                    auto dy = _mm256_castsi256_ps(_mm256_slli_epi32(_mm256_cvtepu16_epi32(_mm_loadu_si128((const __m128i *)yb.d)), 16));
-                    dy = _mm256_mul_ps(scales[i128], dy);
-                    auto d8 = _mm256_set_m128(_mm256_castps256_ps128(dy), _mm256_castps256_ps128(dy));
+                    auto dy4 = _mm_castsi128_ps(_mm_slli_epi32(_mm_cvtepu16_epi32(_mm_loadl_epi64((const __m128i *)yb.d)), 16));
+                    auto dy8 = _mm256_mul_ps(scales[i128], _mm256_set_m128(dy4, dy4));
                     compute_dot(yb.qs);
-                    accd[iy] = _mm256_fmadd_ps(d8, sum_4(), accd[iy]);
+                    accd[iy] = _mm256_fmadd_ps(dy8, sum_4(), accd[iy]);
                 }
             }
         }
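
Note: the reworked sign handling builds ±1 byte masks: `cmpeq` yields 0xFF in every byte whose tested bit is set, and OR-ing with `mask` (0x01 per byte) turns that into -1/+1 bytes, ready for a later `_mm256_sign_epi8`. Shifting `sign_bits` right instead of shifting `mask` left keeps the compare pattern constant. Per-byte scalar sketch (hypothetical helper, for illustration):

```cpp
#include <cstdint>

// sv[k] per byte: sign bit set -> 0xFF (-1), clear -> 0x01 (+1).
inline int8_t apply_iq3_sign(int8_t magnitude, uint32_t sign_bits, int bit) {
    int8_t s = ((sign_bits >> bit) & 1) ? int8_t(-1) : int8_t(1);
    return (int8_t)(magnitude * s);
}
```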
@@ -877,8 +953,6 @@ void mul_mat_iq4_kt_q8_2_x4_T(int n, const void * vx, size_t bx, const DataInfo&
         }
     };
 
-    // auto m126 = _mm256_set1_ps(-126.f);
-
     for (int ix = 0; ix < nrc_x; ++ix) {
         const float * dptr = (const float *)((const char *)vx + ix*bx);
         auto d = _mm256_set1_ps(dptr[0]);
@@ -900,27 +974,27 @@ void mul_mat_iq4_kt_q8_2_x4_T(int n, const void * vx, size_t bx, const DataInfo&
             scales[1] = _mm256_set_m128(scales_h, scales_h);
             o_helper.vec = _mm256_add_epi32(_mm256_slli_epi32(_mm256_and_si256(vshb, _mm256_set1_epi32(1)), 15), _mm256_set1_epi32(4096));
             for (int ib = 0; ib < 4; ++ib) {
-                for (int j = 0; j < 4; ++j) {
-                    const uint32_t sh1 = shb[ib+0] >> (8 + 6*j);
-                    const uint32_t sh2 = shb[ib+4] >> (8 + 6*j);
-                    values[8*ib+2*j+ 0] = ql[8*ib+2*j+ 0] + ((qh[8*ib+2*j+0] << 8) & 0xf00) + ((sh1 &  7) << 12) + o_helper.val[ib+0];
-                    values[8*ib+2*j+ 1] = ql[8*ib+2*j+ 1] + ((qh[8*ib+2*j+1] << 8) & 0xf00) + ((sh1 & 56) <<  9) + o_helper.val[ib+0];
-                    values[8*ib+2*j+32] = ql[8*ib+2*j+32] + ((qh[8*ib+2*j+0] << 4) & 0xf00) + ((sh2 &  7) << 12) + o_helper.val[ib+4];
-                    values[8*ib+2*j+33] = ql[8*ib+2*j+33] + ((qh[8*ib+2*j+1] << 4) & 0xf00) + ((sh2 & 56) <<  9) + o_helper.val[ib+4];
+                for (int j = 0; j < 2; ++j) {
+                    const uint32_t sh1 = shb[ib+0] >> (8 + 12*j);
+                    const uint32_t sh2 = shb[ib+4] >> (8 + 12*j);
+                    values[8*ib+4*j+ 0] = ql[8*ib+4*j+ 0] + ((qh[8*ib+4*j+0] << 8) & 0xf00) + ((sh1 << 12) & 0x7000) + o_helper.val[ib+0];
+                    values[8*ib+4*j+ 1] = ql[8*ib+4*j+ 1] + ((qh[8*ib+4*j+1] << 8) & 0xf00) + ((sh1 <<  9) & 0x7000) + o_helper.val[ib+0];
+                    values[8*ib+4*j+ 2] = ql[8*ib+4*j+ 2] + ((qh[8*ib+4*j+2] << 8) & 0xf00) + ((sh1 <<  6) & 0x7000) + o_helper.val[ib+0];
+                    values[8*ib+4*j+ 3] = ql[8*ib+4*j+ 3] + ((qh[8*ib+4*j+3] << 8) & 0xf00) + ((sh1 <<  3) & 0x7000) + o_helper.val[ib+0];
+                    values[8*ib+4*j+32] = ql[8*ib+4*j+32] + ((qh[8*ib+4*j+0] << 4) & 0xf00) + ((sh2 << 12) & 0x7000) + o_helper.val[ib+4];
+                    values[8*ib+4*j+33] = ql[8*ib+4*j+33] + ((qh[8*ib+4*j+1] << 4) & 0xf00) + ((sh2 <<  9) & 0x7000) + o_helper.val[ib+4];
+                    values[8*ib+4*j+34] = ql[8*ib+4*j+34] + ((qh[8*ib+4*j+2] << 4) & 0xf00) + ((sh2 <<  6) & 0x7000) + o_helper.val[ib+4];
+                    values[8*ib+4*j+35] = ql[8*ib+4*j+35] + ((qh[8*ib+4*j+3] << 4) & 0xf00) + ((sh2 <<  3) & 0x7000) + o_helper.val[ib+4];
                 }
             }
             for (int i128 = 0; i128 < 2; ++i128) {
-                // for (int k = 0; k < 4; ++k) xv[k] = trellis.next32<true>(values + 32*i128 + 8*k);
-                for (int k = 0; k < 4; ++k) xv[k] = trellis.next32(values + 32*i128 + 8*k);
+                trellis.next_128(values + 32*i128, xv);
                 for (int iy = 0; iy < nrc_y; ++iy) {
                     const block_q8_2_x4& yb = y[iy][2*i+i128];
-                    auto dy = _mm256_castsi256_ps(_mm256_slli_epi32(_mm256_cvtepu16_epi32(_mm_loadu_si128((const __m128i *)yb.d)), 16));
-                    dy = _mm256_mul_ps(scales[i128], dy);
-                    auto d8 = _mm256_set_m128(_mm256_castps256_ps128(dy), _mm256_castps256_ps128(dy));
-                    // auto m8 = _mm256_set_m128(_mm256_extractf128_ps(dy, 1), _mm256_extractf128_ps(dy, 1));
+                    auto dy4 = _mm_castsi128_ps(_mm_slli_epi32(_mm_cvtepu16_epi32(_mm_loadl_epi64((const __m128i *)yb.d)), 16));
+                    auto dy8 = _mm256_mul_ps(scales[i128], _mm256_set_m128(dy4, dy4));
                     compute_dot(yb.qs);
-                    accd[iy] = _mm256_fmadd_ps(d8, sum_4(), accd[iy]);
-                    // accd[iy] = _mm256_fmadd_ps(m8, m126, accd[iy]);
+                    accd[iy] = _mm256_fmadd_ps(dy8, sum_4(), accd[iy]);
                 }
             }
         }
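
Note: the unrolled loop makes the iq4_kt seed layout explicit: 8 low bits from `ql`, 4 middle bits from `qh`, 3 high bits from the block's `shb` word, plus `o_helper.val` (the 4096 base offset and, per the `vshb` line above, an optional bit 15). Scalar sketch of one seed, with the parameter names as assumptions:

```cpp
#include <cstdint>

// One iq4_kt trellis seed: 15 packed bits plus the 4096 base offset.
inline uint32_t iq4_kt_seed(uint8_t lo, uint8_t hi4, uint8_t hi3, bool msb) {
    return (uint32_t)lo                    // bits 0..7   (ql)
         + (((uint32_t)hi4 & 0xf) <<  8)   // bits 8..11  ((qh << 8) & 0xf00)
         + (((uint32_t)hi3 & 0x7) << 12)   // bits 12..14 ((sh << ...) & 0x7000)
         + ((uint32_t)msb << 15) + 4096;   // o_helper.val
}
```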
@@ -1020,6 +1094,9 @@ bool iqk_set_kernels_ktquants(int ne00, int typeA, int typeB, std::array<mul_mat
     if (typeA == GGML_TYPE_IQ4_KT) {
         if (typeB == GGML_TYPE_Q8_2_X4) {
             IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_iq4_kt_q8_2_x4_T, kernels);
+#ifdef HAVE_FANCY_SIMD
+            func16 = mul_mat_iq4_kt_q8_2_x4_T<16>;
+#endif
             return true;
         }
         return false;
@@ -1028,6 +1105,9 @@ bool iqk_set_kernels_ktquants(int ne00, int typeA, int typeB, std::array<mul_mat
     if (typeA == GGML_TYPE_IQ2_KT) {
         if (typeB == GGML_TYPE_Q8_2_X4) {
             IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_iq2_kt_q8_2_x4_T, kernels);
+#ifdef HAVE_FANCY_SIMD
+            func16 = mul_mat_iq2_kt_q8_2_x4_T<16>;
+#endif
             return true;
         }
         return false;
@@ -1036,6 +1116,9 @@ bool iqk_set_kernels_ktquants(int ne00, int typeA, int typeB, std::array<mul_mat
     if (typeA == GGML_TYPE_IQ3_KT) {
         if (typeB == GGML_TYPE_Q8_2_X4) {
             IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_iq3_kt_q8_2_x4_T, kernels);
+#ifdef HAVE_FANCY_SIMD
+            func16 = mul_mat_iq3_kt_q8_2_x4_T<16>;
+#endif
             return true;
         }
         return false;
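
Note: all three `func16` registrations follow the same pattern: under `HAVE_FANCY_SIMD` (AVX512), a 16-column kernel instantiation is exposed in addition to the regular kernels, since the trellis decode in `next_128()` is independent of the activations and amortizes over more right-hand-side columns. A hedged sketch of the dispatch idea (names and driver shape assumed, not the actual iqk code):

```cpp
// Hypothetical driver: prefer the 16-column kernel when enough columns remain.
inline void sketch_dispatch(int ny) {
    int iy = 0;
#ifdef HAVE_FANCY_SIMD
    for (; iy + 16 <= ny; iy += 16) {
        // func16(...) handles columns [iy, iy+16) with one trellis decode per tile
    }
#endif
    if (iy < ny) {
        // kernels[...] handle the remaining columns
    }
}
```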