@@ -170,6 +170,40 @@ struct Trellis3 {
170170 return _mm256_permutevar8x32_epi32 (aux[0 ], shuffle);
171171 }
172172 }
173+ IQK_ALWAYS_INLINE inline void next_128 (const uint32_t * val, __m256i * result) const {
174+ __m256i aux[16 ];
175+ auto perm = _mm256_setr_epi32 (0 , 2 , 4 , 6 , 1 , 3 , 5 , 7 );
176+ for (int k = 0 ; k < 4 ; ++k) {
177+ auto v = _mm256_loadu_si256 ((const __m256i *)val + k);
178+ v = _mm256_permutevar8x32_epi32 (v, perm);
179+ aux[4 *k+0 ] = _mm256_shuffle_epi32 (v, 0x00 );
180+ aux[4 *k+1 ] = _mm256_shuffle_epi32 (v, 0x55 );
181+ aux[4 *k+2 ] = _mm256_shuffle_epi32 (v, 0xaa );
182+ aux[4 *k+3 ] = _mm256_shuffle_epi32 (v, 0xff );
183+ }
184+ for (int i = 0 ; i < 16 ; ++i) {
185+ aux[i] = _mm256_mullo_epi32 (aux[i], mka);
186+ }
187+ auto mask = _mm256_set1_epi32 (0x3f3f3f3f );
188+ for (int i = 0 ; i < 16 ; ++i) {
189+ aux[i] = _mm256_and_si256 (aux[i], mask);
190+ }
191+ auto offset = _mm256_set1_epi32 (-126 );
192+ auto m1 = _mm256_set1_epi32 (0x01010101 );
193+ for (int i = 0 ; i < 16 ; ++i) {
194+ aux[i] = _mm256_dpbusd_epi32 (offset, aux[i], m1);
195+ }
196+ for (int k = 0 ; k < 4 ; ++k) {
197+ auto v1 = _mm256_packs_epi32 (aux[4 *k+0 ], aux[4 *k+1 ]);
198+ auto v2 = _mm256_packs_epi32 (aux[4 *k+2 ], aux[4 *k+3 ]);
199+ result[k] = _mm256_permutevar8x32_epi32 (_mm256_packs_epi16 (v1, v2), shuffle);
200+ }
201+ if constexpr (is_abs) {
202+ for (int k = 0 ; k < 4 ; ++k) {
203+ result[k] = _mm256_sign_epi8 (result[k], result[k]);
204+ }
205+ }
206+ }
173207 IQK_ALWAYS_INLINE inline void next_128 (const uint16_t * val, uint32_t v0, __m256i * result) const {
174208 __m256i aux[16 ];
175209 for (int k = 0 ; k < 4 ; ++k) {
@@ -951,17 +985,14 @@ void mul_mat_iq4_kt_q8_2_x4_T(int n, const void * vx, size_t bx, const DataInfo&
951985 }
952986 }
953987 for (int i128 = 0 ; i128 < 2 ; ++i128 ) {
954- // for (int k = 0; k < 4; ++k) xv[k] = trellis.next32<true> (values + 32*i128 + 8*k);
955- for ( int k = 0 ; k < 4 ; ++k) xv[k] = trellis.next32 (values + 32 *i128 + 8 *k );
988+ // for (int k = 0; k < 4; ++k) xv[k] = trellis.next32(values + 32*i128 + 8*k);
989+ trellis.next_128 (values + 32 *i128 , xv );
956990 for (int iy = 0 ; iy < nrc_y; ++iy) {
957991 const block_q8_2_x4& yb = y[iy][2 *i+i128 ];
958- auto dy = _mm256_castsi256_ps (_mm256_slli_epi32 (_mm256_cvtepu16_epi32 (_mm_loadu_si128 ((const __m128i *)yb.d )), 16 ));
959- dy = _mm256_mul_ps (scales[i128 ], dy);
960- auto d8 = _mm256_set_m128 (_mm256_castps256_ps128 (dy), _mm256_castps256_ps128 (dy));
961- // auto m8 = _mm256_set_m128(_mm256_extractf128_ps(dy, 1), _mm256_extractf128_ps(dy, 1));
992+ auto dy4 = _mm_castsi128_ps (_mm_slli_epi32 (_mm_cvtepu16_epi32 (_mm_loadl_epi64 ((const __m128i *)yb.d )), 16 ));
993+ auto dy8 = _mm256_mul_ps (scales[i128 ], _mm256_set_m128 (dy4, dy4));
962994 compute_dot (yb.qs );
963- accd[iy] = _mm256_fmadd_ps (d8, sum_4 (), accd[iy]);
964- // accd[iy] = _mm256_fmadd_ps(m8, m126, accd[iy]);
995+ accd[iy] = _mm256_fmadd_ps (dy8, sum_4 (), accd[iy]);
965996 }
966997 }
967998 }
0 commit comments