@@ -171,6 +171,8 @@ struct Trellis3 {
         }
     }
     IQK_ALWAYS_INLINE inline void next_128(const uint32_t * val, __m256i * result) const {
+#if defined(__AVX512F__) && defined(__AVX512VNNI__) && defined(__AVX512VL__)
+        // On AVX2 we don't have enough vector registers to do this
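+        // The batched path below keeps 16 intermediate vectors live in aux[] before
+        // packing down to int8; together with the constants this exceeds the 16 ymm
+        // registers AVX2 provides (AVX512VL gives 32).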
         __m256i aux[16];
         auto perm = _mm256_setr_epi32(0, 2, 4, 6, 1, 3, 5, 7);
         for (int k = 0; k < 4; ++k) {
@@ -203,8 +205,13 @@ struct Trellis3 {
                 result[k] = _mm256_sign_epi8(result[k], result[k]);
             }
         }
+#else
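+        // Fallback: produce the same 128 values 32 at a time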
+        for (int k = 0; k < 4; ++k) result[k] = next32(val + 8*k);
+#endif
     }
     IQK_ALWAYS_INLINE inline void next_128(const uint16_t * val, uint32_t v0, __m256i * result) const {
+#if defined(__AVX512F__) && defined(__AVX512VNNI__) && defined(__AVX512VL__)
+        // On AVX2 we don't have enough vector registers to do this
         __m256i aux[16];
         for (int k = 0; k < 4; ++k) {
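+            // widen four 16-bit indices to 32 bits and add the common offset v0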
             auto v128 = _mm_add_epi32(_mm_cvtepu16_epi32(_mm_loadl_epi64((const __m128i *)(val + 4*k))), _mm_set1_epi32(v0));
@@ -236,15 +243,9 @@ struct Trellis3 {
                 result[k] = _mm256_sign_epi8(result[k], result[k]);
             }
         }
-        // for (int k = 0; k < 4; ++k) {
-        //     for (int i = 0; i < 4; ++i) {
-        //         aux[i] = _mm256_and_si256(aux[4*k+i], _mm256_set1_epi32(0x3f3f3f3f));
-        //         aux[i] = _mm256_dpbusd_epi32(offset, _mm256_set1_epi32(0x01010101), aux[i]);
-        //     }
-        //     aux[0] = _mm256_packs_epi32(aux[0], aux[1]);
-        //     aux[2] = _mm256_packs_epi32(aux[2], aux[3]);
-        //     result[k] = _mm256_permutevar8x32_epi32(_mm256_packs_epi16(aux[0], aux[2]), shuffle);
-        // }
+#else
+        for (int k = 0; k < 4; ++k) result[k] = next32(val + 4*k, v0);
+#endif
     }
     inline __m256i next32(const uint16_t * val, uint32_t v0) const {
         const __m256i offset = _mm256_set1_epi32(-126);
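+        // each output byte is a sum of four masked 6-bit bytes (0..252); adding
+        // the -126 offset recenters that sum into the signed int8 range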
@@ -521,7 +522,6 @@ void mul_mat_iq2_kt_q8_2_x4_T(int n, const void * vx, size_t bx, const DataInfo&
         scales[1] = _mm256_set_m128(scales_h, scales_h);
         for (int i128 = 0; i128 < 2; ++i128) {
             trellis.next_128(ql + 16*i128, 4096, xv);
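+            // 4096 is the base value added to every packed 16-bit index before the
+            // trellis values are generated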
-            // for (int k = 0; k < 4; ++k) xv[k] = trellis.next32(ql + 16*i128 + 4*k, 4096);
             for (int iy = 0; iy < nrc_y; ++iy) {
                 const block_q8_2_x4& yb = y[iy][2*i+i128];
                 auto dy4 = _mm_castsi128_ps(_mm_slli_epi32(_mm_cvtepu16_epi32(_mm_loadl_epi64((const __m128i *)yb.d)), 16));
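+                // each stored 16-bit scale is the high half of an fp32 (bf16 layout);
+                // shifting left by 16 and reinterpreting recovers the float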
@@ -671,11 +671,6 @@ void mul_mat_iq3_kt_q8_2_x4_T(int n, const void * vx, size_t bx, const DataInfo&
                 sv[k] = _mm256_or_si256(_mm256_cmpeq_epi8(_mm256_and_si256(sign_bits, mask), mask), mask);
                 sign_bits = _mm256_srli_epi16(sign_bits, 1);
             }
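+            // sv[k] is 0xff where the sign bit is set and a small positive byte elsewhere,
+            // so a subsequent _mm256_sign_epi8 negates exactly the flagged values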
-            // for (int k = 0; k < 4; ++k) {
-            //     xv[k] = trellis.next32(ql + 16*i128 + 4*k, 4096);
-            //     sv[k] = _mm256_or_si256(_mm256_cmpeq_epi8(_mm256_and_si256(sign_bits, mask), mask), _mm256_set1_epi8(1));
-            //     mask = _mm256_slli_epi16(mask, 1);
-            // }
             for (int iy = 0; iy < nrc_y; ++iy) {
                 const block_q8_2_x4& yb = y[iy][2*i+i128];
                 auto dy4 = _mm_castsi128_ps(_mm_slli_epi32(_mm_cvtepu16_epi32(_mm_loadl_epi64((const __m128i *)yb.d)), 16));
@@ -952,7 +947,9 @@ void mul_mat_iq4_kt_q8_2_x4_T(int n, const void * vx, size_t bx, const DataInfo&
         }
     };

-    // auto m126 = _mm256_set1_ps(-126.f);
+    // auto shift1 = _mm256_setr_epi32(8, 8, 8, 8, 20, 20, 20, 20);
+    // auto shift2 = _mm256_setr_epi32(12, 9, 6, 3, 12, 9, 6, 3);
+    // __m256i values[8];

     for (int ix = 0; ix < nrc_x; ++ix) {
         const float * dptr = (const float *)((const char *)vx + ix*bx);
@@ -975,18 +972,42 @@ void mul_mat_iq4_kt_q8_2_x4_T(int n, const void * vx, size_t bx, const DataInfo&
         scales[1] = _mm256_set_m128(scales_h, scales_h);
         o_helper.vec = _mm256_add_epi32(_mm256_slli_epi32(_mm256_and_si256(vshb, _mm256_set1_epi32(1)), 15), _mm256_set1_epi32(4096));
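+        // per-block trellis offset: 4096 plus bit 0 of the shift word moved up to
+        // bit 15, i.e. either 4096 or 36864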
         for (int ib = 0; ib < 4; ++ib) {
-            for (int j = 0; j < 4; ++j) {
-                const uint32_t sh1 = shb[ib+0] >> (8 + 6*j);
-                const uint32_t sh2 = shb[ib+4] >> (8 + 6*j);
-                values[8*ib+2*j+ 0] = ql[8*ib+2*j+ 0] + ((qh[8*ib+2*j+0] << 8) & 0xf00) + ((sh1 & 7) << 12) + o_helper.val[ib+0];
-                values[8*ib+2*j+ 1] = ql[8*ib+2*j+ 1] + ((qh[8*ib+2*j+1] << 8) & 0xf00) + ((sh1 & 56) << 9) + o_helper.val[ib+0];
-                values[8*ib+2*j+32] = ql[8*ib+2*j+32] + ((qh[8*ib+2*j+0] << 4) & 0xf00) + ((sh2 & 7) << 12) + o_helper.val[ib+4];
-                values[8*ib+2*j+33] = ql[8*ib+2*j+33] + ((qh[8*ib+2*j+1] << 4) & 0xf00) + ((sh2 & 56) << 9) + o_helper.val[ib+4];
+            // Somehow this is slower than the scalar unpacking below.
+            // auto idxl = _mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i *)(ql + 8*ib)));
+            // auto idxh = _mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i *)(ql + 8*ib + 32)));
+            // auto vh   = _mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i *)(qh + 8*ib)));
+            // idxl = _mm256_or_si256(idxl, _mm256_and_si256(_mm256_slli_epi32(vh, 8), _mm256_set1_epi32(0xf00)));
+            // idxh = _mm256_or_si256(idxh, _mm256_and_si256(_mm256_slli_epi32(vh, 4), _mm256_set1_epi32(0xf00)));
+            // auto shl = _mm256_sllv_epi32(_mm256_srlv_epi32(_mm256_set1_epi32(shb[ib+0]), shift1), shift2);
+            // auto shh = _mm256_sllv_epi32(_mm256_srlv_epi32(_mm256_set1_epi32(shb[ib+4]), shift1), shift2);
+            // idxl = _mm256_or_si256(idxl, _mm256_and_si256(shl, _mm256_set1_epi32(0x7000)));
+            // idxh = _mm256_or_si256(idxh, _mm256_and_si256(shh, _mm256_set1_epi32(0x7000)));
+            // values[ib+0] = _mm256_add_epi32(idxl, _mm256_set1_epi32(o_helper.val[ib+0]));
+            // values[ib+4] = _mm256_add_epi32(idxh, _mm256_set1_epi32(o_helper.val[ib+4]));
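+            // each 15-bit trellis index is assembled from three fields: bits 0-7 from ql,
+            // bits 8-11 from a nibble of qh, bits 12-14 from a 3-bit field of shb, plus
+            // the per-block offset from o_helper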
+            for (int j = 0; j < 2; ++j) {
+                const uint32_t sh1 = shb[ib+0] >> (8 + 12*j);
+                const uint32_t sh2 = shb[ib+4] >> (8 + 12*j);
+                // values[8*ib+4*j+ 0] = ql[8*ib+4*j+ 0] + ((qh[8*ib+4*j+0] << 8) & 0xf00) + ((sh1 & 7) << 12) + o_helper.val[ib+0];
+                // values[8*ib+4*j+ 1] = ql[8*ib+4*j+ 1] + ((qh[8*ib+4*j+1] << 8) & 0xf00) + ((sh1 & 56) << 9) + o_helper.val[ib+0];
+                // values[8*ib+4*j+ 2] = ql[8*ib+4*j+ 2] + ((qh[8*ib+4*j+2] << 8) & 0xf00) + ((sh1 & 448) << 6) + o_helper.val[ib+0];
+                // values[8*ib+4*j+ 3] = ql[8*ib+4*j+ 3] + ((qh[8*ib+4*j+3] << 8) & 0xf00) + ((sh1 & 3584) << 3) + o_helper.val[ib+0];
+                // values[8*ib+4*j+32] = ql[8*ib+4*j+32] + ((qh[8*ib+4*j+0] << 4) & 0xf00) + ((sh2 & 7) << 12) + o_helper.val[ib+4];
+                // values[8*ib+4*j+33] = ql[8*ib+4*j+33] + ((qh[8*ib+4*j+1] << 4) & 0xf00) + ((sh2 & 56) << 9) + o_helper.val[ib+4];
+                // values[8*ib+4*j+34] = ql[8*ib+4*j+34] + ((qh[8*ib+4*j+2] << 4) & 0xf00) + ((sh2 & 448) << 6) + o_helper.val[ib+4];
+                // values[8*ib+4*j+35] = ql[8*ib+4*j+35] + ((qh[8*ib+4*j+3] << 4) & 0xf00) + ((sh2 & 3584) << 3) + o_helper.val[ib+4];
+                values[8*ib+4*j+ 0] = ql[8*ib+4*j+ 0] + ((qh[8*ib+4*j+0] << 8) & 0xf00) + ((sh1 << 12) & 0x7000) + o_helper.val[ib+0];
+                values[8*ib+4*j+ 1] = ql[8*ib+4*j+ 1] + ((qh[8*ib+4*j+1] << 8) & 0xf00) + ((sh1 <<  9) & 0x7000) + o_helper.val[ib+0];
+                values[8*ib+4*j+ 2] = ql[8*ib+4*j+ 2] + ((qh[8*ib+4*j+2] << 8) & 0xf00) + ((sh1 <<  6) & 0x7000) + o_helper.val[ib+0];
+                values[8*ib+4*j+ 3] = ql[8*ib+4*j+ 3] + ((qh[8*ib+4*j+3] << 8) & 0xf00) + ((sh1 <<  3) & 0x7000) + o_helper.val[ib+0];
+                values[8*ib+4*j+32] = ql[8*ib+4*j+32] + ((qh[8*ib+4*j+0] << 4) & 0xf00) + ((sh2 << 12) & 0x7000) + o_helper.val[ib+4];
+                values[8*ib+4*j+33] = ql[8*ib+4*j+33] + ((qh[8*ib+4*j+1] << 4) & 0xf00) + ((sh2 <<  9) & 0x7000) + o_helper.val[ib+4];
+                values[8*ib+4*j+34] = ql[8*ib+4*j+34] + ((qh[8*ib+4*j+2] << 4) & 0xf00) + ((sh2 <<  6) & 0x7000) + o_helper.val[ib+4];
+                values[8*ib+4*j+35] = ql[8*ib+4*j+35] + ((qh[8*ib+4*j+3] << 4) & 0xf00) + ((sh2 <<  3) & 0x7000) + o_helper.val[ib+4];
             }
         }
         for (int i128 = 0; i128 < 2; ++i128) {
-            // for (int k = 0; k < 4; ++k) xv[k] = trellis.next32(values + 32*i128 + 8*k);
             trellis.next_128(values + 32*i128, xv);
+            // trellis.next_128(values + 4*i128, xv);
             for (int iy = 0; iy < nrc_y; ++iy) {
                 const block_q8_2_x4& yb = y[iy][2*i+i128];
                 auto dy4 = _mm_castsi128_ps(_mm_slli_epi32(_mm_cvtepu16_epi32(_mm_loadl_epi64((const __m128i *)yb.d)), 16));