@@ -170,6 +170,48 @@ struct Trellis3 {
170170 return _mm256_permutevar8x32_epi32 (aux[0 ], shuffle);
171171 }
172172 }
// NOTE(review): the leading "NNN+" tokens and intra-token spaces are artifacts
// of the rendered commit-diff this text was scraped from, not part of the code.
//
// next_128: decode 128 trellis values at once into 4 x 32 int8 lanes.
// Batched replacement for four next32() calls (see the commented-out call site
// below): loads 16 u16 seeds from `val`, adds the caller-supplied offset `v0`,
// and broadcasts each 32-bit seed across a full __m256i (4 seeds per k-group).
173+ IQK_ALWAYS_INLINE inline void next_128 (const uint16_t * val, uint32_t v0, __m256i * result) const {
174+ __m256i aux[16 ];
175+ for (int k = 0 ; k < 4 ; ++k) {
176+ auto v128 = _mm_add_epi32 (_mm_cvtepu16_epi32 (_mm_loadl_epi64 ((const __m128i *)(val + 4 *k))), _mm_set1_epi32 (v0));
177+ auto v = MM256_SET_M128I (v128, v128);
// Splat seed j of the group into aux[4*k+j] (0x00/0x55/0xaa/0xff pick dword 0..3).
178+ aux[4 *k+0 ] = _mm256_shuffle_epi32 (v, 0x00 );
179+ aux[4 *k+1 ] = _mm256_shuffle_epi32 (v, 0x55 );
180+ aux[4 *k+2 ] = _mm256_shuffle_epi32 (v, 0xaa );
181+ aux[4 *k+3 ] = _mm256_shuffle_epi32 (v, 0xff );
182+ }
// Multiply by the per-lane trellis constants `mka` (struct member, declared
// outside this view -- presumably the 8 "ka" multipliers of the 3INST trellis;
// TODO confirm against the struct definition).
183+ for (int i = 0 ; i < 16 ; ++i) {
184+ aux[i] = _mm256_mullo_epi32 (aux[i], mka);
185+ }
// Keep only the low 6 bits of each byte, so every byte is in [0, 63].
186+ auto mask = _mm256_set1_epi32 (0x3f3f3f3f );
187+ for (int i = 0 ; i < 16 ; ++i) {
188+ aux[i] = _mm256_and_si256 (aux[i], mask);
189+ }
// dpbusd(offset, aux, 0x01010101) = sum of the four masked bytes - 126 per
// 32-bit lane. Byte sum is in [0, 252], so each result lies in [-126, 126]
// and survives the saturating packs below without clipping.
190+ auto offset = _mm256_set1_epi32 (-126 );
191+ auto m1 = _mm256_set1_epi32 (0x01010101 );
192+ for (int i = 0 ; i < 16 ; ++i) {
193+ aux[i] = _mm256_dpbusd_epi32 (offset, aux[i], m1);
194+ }
// Narrow 4 x epi32 vectors to one int8 vector per k. `shuffle` is a struct
// member (not visible here); presumably it undoes the 128-bit-lane
// interleaving of the two-stage pack -- same pattern as next32() above.
195+ for (int k = 0 ; k < 4 ; ++k) {
196+ auto v1 = _mm256_packs_epi32 (aux[4 *k+0 ], aux[4 *k+1 ]);
197+ auto v2 = _mm256_packs_epi32 (aux[4 *k+2 ], aux[4 *k+3 ]);
198+ result[k] = _mm256_permutevar8x32_epi32 (_mm256_packs_epi16 (v1, v2), shuffle);
199+ }
// Template variant (second Trellis3 parameter, per the call sites below):
// _mm256_sign_epi8(x, x) yields |x| (values here never hit -128, so no
// overflow corner case).
200+ if constexpr (is_abs) {
201+ for (int k = 0 ; k < 4 ; ++k) {
202+ result[k] = _mm256_sign_epi8 (result[k], result[k]);
203+ }
204+ }
// Earlier interleaved variant kept by the author for reference (dead code).
205+ // for (int k = 0; k < 4; ++k) {
206+ // for (int i = 0; i < 4; ++i) {
207+ // aux[i] = _mm256_and_si256(aux[4*k+i], _mm256_set1_epi32(0x3f3f3f3f));
208+ // aux[i] = _mm256_dpbusd_epi32(offset, _mm256_set1_epi32(0x01010101), aux[i]);
209+ // }
210+ // aux[0] = _mm256_packs_epi32(aux[0], aux[1]);
211+ // aux[2] = _mm256_packs_epi32(aux[2], aux[3]);
212+ // result[k] = _mm256_permutevar8x32_epi32(_mm256_packs_epi16(aux[0], aux[2]), shuffle);
213+ // }
214+ }
173215 inline __m256i next32 (const uint16_t * val, uint32_t v0) const {
174216 const __m256i offset = _mm256_set1_epi32 (-126 );
175217 __m256i aux[4 ];
@@ -385,7 +427,7 @@ void mul_mat_iq2_kt_q8_2_x4_T(int n, const void * vx, size_t bx, const DataInfo&
385427 assert (n%QK_K == 0 );
386428 const int nb = n/QK_K;
387429
388- Trellis3<true > trellis;
430+ Trellis3<true , false > trellis;
389431
390432 auto shifts = _mm_set_epi32 (0 , 0 , 4 , 0 );
391433 auto values = _mm_loadu_si128 ((const __m128i *)iq4k_values);
@@ -425,8 +467,6 @@ void mul_mat_iq2_kt_q8_2_x4_T(int n, const void * vx, size_t bx, const DataInfo&
425467 }
426468 };
427469
428- // auto m126 = _mm256_set1_ps(-126.f);
429-
430470 for (int ix = 0 ; ix < nrc_x; ++ix) {
431471 const float * dptr = (const float *)((const char *)vx + ix*bx);
432472 auto d = _mm256_set1_ps (dptr[0 ] * 1 .05f );
@@ -446,17 +486,14 @@ void mul_mat_iq2_kt_q8_2_x4_T(int n, const void * vx, size_t bx, const DataInfo&
446486 scales[0 ] = _mm256_set_m128 (scales_l, scales_l);
447487 scales[1 ] = _mm256_set_m128 (scales_h, scales_h);
448488 for (int i128 = 0 ; i128 < 2 ; ++i128 ) {
449- // for (int k = 0; k < 4; ++k) xv[k] = trellis.next32<true>(values + 32 *i128 + 8*k );
450- for (int k = 0 ; k < 4 ; ++k) xv[k] = trellis.next32 (ql + 16 *i128 + 4 *k, 4096 );
489+ trellis.next_128 (ql + 16 *i128 , 4096 , xv );
490+ // for (int k = 0; k < 4; ++k) xv[k] = trellis.next32(ql + 16*i128 + 4*k, 4096);
451491 for (int iy = 0 ; iy < nrc_y; ++iy) {
452492 const block_q8_2_x4& yb = y[iy][2 *i+i128 ];
453- auto dy = _mm256_castsi256_ps (_mm256_slli_epi32 (_mm256_cvtepu16_epi32 (_mm_loadu_si128 ((const __m128i *)yb.d )), 16 ));
454- dy = _mm256_mul_ps (scales[i128 ], dy);
455- auto d8 = _mm256_set_m128 (_mm256_castps256_ps128 (dy), _mm256_castps256_ps128 (dy));
456- // auto m8 = _mm256_set_m128(_mm256_extractf128_ps(dy, 1), _mm256_extractf128_ps(dy, 1));
493+ auto dy4 = _mm_castsi128_ps (_mm_slli_epi32 (_mm_cvtepu16_epi32 (_mm_loadl_epi64 ((const __m128i *)yb.d )), 16 ));
494+ auto dy8 = _mm256_mul_ps (scales[i128 ], _mm256_set_m128 (dy4, dy4));
457495 compute_dot (yb.qs );
458- accd[iy] = _mm256_fmadd_ps (d8, sum_4 (), accd[iy]);
459- // accd[iy] = _mm256_fmadd_ps(m8, m126, accd[iy]);
496+ accd[iy] = _mm256_fmadd_ps (dy8, sum_4 (), accd[iy]);
460497 }
461498 }
462499 }
@@ -595,18 +632,22 @@ void mul_mat_iq3_kt_q8_2_x4_T(int n, const void * vx, size_t bx, const DataInfo&
595632 scales[1 ] = _mm256_set_m128 (scales_h, scales_h);
596633 auto mask = _mm256_set1_epi8 (1 );
597634 for (int i128 = 0 ; i128 < 2 ; ++i128 ) {
635+ trellis.next_128 (ql + 16 *i128 , 4096 , xv);
598636 for (int k = 0 ; k < 4 ; ++k) {
599- xv[k] = trellis.next32 (ql + 16 *i128 + 4 *k, 4096 );
600- sv[k] = _mm256_or_si256 (_mm256_cmpeq_epi8 (_mm256_and_si256 (sign_bits, mask), mask), _mm256_set1_epi8 (1 ));
601- mask = _mm256_slli_epi16 (mask, 1 );
637+ sv[k] = _mm256_or_si256 (_mm256_cmpeq_epi8 (_mm256_and_si256 (sign_bits, mask), mask), mask);
638+ sign_bits = _mm256_srli_epi16 (sign_bits, 1 );
602639 }
640+ // for (int k = 0; k < 4; ++k) {
641+ // xv[k] = trellis.next32(ql + 16*i128 + 4*k, 4096);
642+ // sv[k] = _mm256_or_si256(_mm256_cmpeq_epi8(_mm256_and_si256(sign_bits, mask), mask), _mm256_set1_epi8(1));
643+ // mask = _mm256_slli_epi16(mask, 1);
644+ // }
603645 for (int iy = 0 ; iy < nrc_y; ++iy) {
604646 const block_q8_2_x4& yb = y[iy][2 *i+i128 ];
605- auto dy = _mm256_castsi256_ps (_mm256_slli_epi32 (_mm256_cvtepu16_epi32 (_mm_loadu_si128 ((const __m128i *)yb.d )), 16 ));
606- dy = _mm256_mul_ps (scales[i128 ], dy);
607- auto d8 = _mm256_set_m128 (_mm256_castps256_ps128 (dy), _mm256_castps256_ps128 (dy));
647+ auto dy4 = _mm_castsi128_ps (_mm_slli_epi32 (_mm_cvtepu16_epi32 (_mm_loadl_epi64 ((const __m128i *)yb.d )), 16 ));
648+ auto dy8 = _mm256_mul_ps (scales[i128 ], _mm256_set_m128 (dy4, dy4));
608649 compute_dot (yb.qs );
609- accd[iy] = _mm256_fmadd_ps (d8 , sum_4 (), accd[iy]);
650+ accd[iy] = _mm256_fmadd_ps (dy8 , sum_4 (), accd[iy]);
610651 }
611652 }
612653 }
0 commit comments