@@ -171,8 +171,7 @@ struct Trellis3 {
171171 }
172172 }
173173 IQK_ALWAYS_INLINE inline void next_128 (const uint32_t * val, __m256i * result) const {
174- #if defined(__AVX512F__) && defined(__AVX512VNNI__) && defined(__AVX512VL__)
175- // On AVX2 we don't have enough vector registers to do this
174+ // Even though we only have 16 vector registers nn AVX2, this is still faster
176175 __m256i aux[16 ];
177176 auto perm = _mm256_setr_epi32 (0 , 2 , 4 , 6 , 1 , 3 , 5 , 7 );
178177 for (int k = 0 ; k < 4 ; ++k) {
@@ -191,9 +190,16 @@ struct Trellis3 {
191190 aux[i] = _mm256_and_si256 (aux[i], mask);
192191 }
193192 auto offset = _mm256_set1_epi32 (-126 );
193+ #if defined(__AVX512F__) && defined(__AVX512VNNI__) && defined(__AVX512VL__)
194194 auto m1 = _mm256_set1_epi32 (0x01010101 );
195+ #endif
195196 for (int i = 0 ; i < 16 ; ++i) {
197+ #if defined(__AVX512F__) && defined(__AVX512VNNI__) && defined(__AVX512VL__)
196198 aux[i] = _mm256_dpbusd_epi32 (offset, aux[i], m1);
199+ #else
200+ auto dot = _mm256_maddubs_epi16 (aux[i], _mm256_set1_epi32 (0x01010101 ));
201+ aux[i] = _mm256_add_epi32 (offset, _mm256_madd_epi16 (dot, _mm256_set1_epi16 (1 )));
202+ #endif
197203 }
198204 for (int k = 0 ; k < 4 ; ++k) {
199205 auto v1 = _mm256_packs_epi32 (aux[4 *k+0 ], aux[4 *k+1 ]);
@@ -205,13 +211,9 @@ struct Trellis3 {
205211 result[k] = _mm256_sign_epi8 (result[k], result[k]);
206212 }
207213 }
208- #else
209- for (int k = 0 ; k < 4 ; ++k) result[k] = next32 (val + 8 *k);
210- #endif
211214 }
212215 IQK_ALWAYS_INLINE inline void next_128 (const uint16_t * val, uint32_t v0, __m256i * result) const {
213- #if defined(__AVX512F__) && defined(__AVX512VNNI__) && defined(__AVX512VL__)
214- // On AVX2 we don't have enough vector registers to do this
216+ // Even though we only have 16 vector registers nn AVX2, this is still faster
215217 __m256i aux[16 ];
216218 for (int k = 0 ; k < 4 ; ++k) {
217219 auto v128 = _mm_add_epi32 (_mm_cvtepu16_epi32 (_mm_loadl_epi64 ((const __m128i *)(val + 4 *k))), _mm_set1_epi32 (v0));
@@ -229,9 +231,16 @@ struct Trellis3 {
229231 aux[i] = _mm256_and_si256 (aux[i], mask);
230232 }
231233 auto offset = _mm256_set1_epi32 (-126 );
234+ #if defined(__AVX512F__) && defined(__AVX512VNNI__) && defined(__AVX512VL__)
232235 auto m1 = _mm256_set1_epi32 (0x01010101 );
236+ #endif
233237 for (int i = 0 ; i < 16 ; ++i) {
238+ #if defined(__AVX512F__) && defined(__AVX512VNNI__) && defined(__AVX512VL__)
234239 aux[i] = _mm256_dpbusd_epi32 (offset, aux[i], m1);
240+ #else
241+ auto dot = _mm256_maddubs_epi16 (aux[i], _mm256_set1_epi32 (0x01010101 ));
242+ aux[i] = _mm256_add_epi32 (offset, _mm256_madd_epi16 (dot, _mm256_set1_epi16 (1 )));
243+ #endif
235244 }
236245 for (int k = 0 ; k < 4 ; ++k) {
237246 auto v1 = _mm256_packs_epi32 (aux[4 *k+0 ], aux[4 *k+1 ]);
@@ -243,9 +252,6 @@ struct Trellis3 {
243252 result[k] = _mm256_sign_epi8 (result[k], result[k]);
244253 }
245254 }
246- #else
247- for (int k = 0 ; k < 4 ; ++k) result[k] = next32 (val + 4 *k, v0);
248- #endif
249255 }
250256 inline __m256i next32 (const uint16_t * val, uint32_t v0) const {
251257 const __m256i offset = _mm256_set1_epi32 (-126 );
0 commit comments