@@ -1668,82 +1668,34 @@ static void mul_mat_iq2_bn_r4_q8_k16(int n, const void * vx, size_t bx, const Da
 }
 #endif
 
-inline float convert_to_q8_k_r8(int k, int d0, const __m256i * qx, const int16_t * scales, uint32_t * block, int8_t * q8_k) {
-    auto max_i16 = _mm256_setzero_si256();
-    for (int ib32 = 0; ib32 < 8; ++ib32) {
-        auto q16_l = _mm256_cvtepi8_epi16(_mm256_castsi256_si128(qx[ib32]));
-        auto q16_h = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(qx[ib32], 1));
-        q16_l = _mm256_mullo_epi16(q16_l, _mm256_set1_epi16(scales[2*ib32+0]));
-        q16_h = _mm256_mullo_epi16(q16_h, _mm256_set1_epi16(scales[2*ib32+1]));
-        max_i16 = _mm256_max_epi16(max_i16, _mm256_sign_epi16(q16_l, q16_l));
-        max_i16 = _mm256_max_epi16(max_i16, _mm256_sign_epi16(q16_h, q16_h));
-    }
-    auto max_q32 = _mm256_cvtepi16_epi32(_mm_max_epi16(_mm256_castsi256_si128(max_i16), _mm256_extracti128_si256(max_i16, 1)));
-    auto imax4 = _mm_max_epi32(_mm256_castsi256_si128(max_q32), _mm256_extracti128_si256(max_q32, 1));
-    auto max4 = _mm_cvtepi32_ps(imax4);
-    max4 = _mm_max_ps(max4, _mm_movehl_ps(max4, max4));
-    max4 = _mm_max_ss(max4, _mm_movehdup_ps(max4));
-    bool needs_scaling = true;
-    float dnew = _mm_cvtss_f32(max4) / d0;
-    if (dnew < 1.f) {
-        dnew = 1.f; needs_scaling = false;
-    }
-    auto scale = _mm256_set1_ps(std::abs(dnew) > 1e-9f ? 1/dnew : 0.f);
-    for (int ib32 = 0; ib32 < 8; ++ib32) {
-        auto q16_l = _mm256_cvtepi8_epi16(_mm256_castsi256_si128(qx[ib32]));
-        auto q16_h = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(qx[ib32], 1));
-        q16_l = _mm256_mullo_epi16(q16_l, _mm256_set1_epi16(scales[2*ib32+0]));
-        q16_h = _mm256_mullo_epi16(q16_h, _mm256_set1_epi16(scales[2*ib32+1]));
-        if (needs_scaling) {
-            auto i0 = _mm256_cvtepi16_epi32(_mm256_castsi256_si128(q16_l));
-            auto i1 = _mm256_cvtepi16_epi32(_mm256_extracti128_si256(q16_l, 1));
-            auto i2 = _mm256_cvtepi16_epi32(_mm256_castsi256_si128(q16_h));
-            auto i3 = _mm256_cvtepi16_epi32(_mm256_extracti128_si256(q16_h, 1));
-            i0 = _mm256_cvtps_epi32(_mm256_round_ps(_mm256_mul_ps(scale, _mm256_cvtepi32_ps(i0)), _MM_ROUND_NEAREST));
-            i1 = _mm256_cvtps_epi32(_mm256_round_ps(_mm256_mul_ps(scale, _mm256_cvtepi32_ps(i1)), _MM_ROUND_NEAREST));
-            i2 = _mm256_cvtps_epi32(_mm256_round_ps(_mm256_mul_ps(scale, _mm256_cvtepi32_ps(i2)), _MM_ROUND_NEAREST));
-            i3 = _mm256_cvtps_epi32(_mm256_round_ps(_mm256_mul_ps(scale, _mm256_cvtepi32_ps(i3)), _MM_ROUND_NEAREST));
-            i0 = _mm256_packs_epi32(i0, i1);
-            i2 = _mm256_packs_epi32(i2, i3);
-            i0 = _mm256_packs_epi16(i0, i2);
-            i0 = _mm256_permutevar8x32_epi32(i0, _mm256_setr_epi32(0, 4, 1, 5, 2, 6, 3, 7));
-            _mm256_storeu_si256((__m256i *)block, i0);
-        } else {
-            // 0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23, 8, 9, 10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31
-            auto i0 = _mm256_packs_epi16(q16_l, q16_h);
-            auto i0_l = _mm256_castsi256_si128(i0);
-            auto i0_h = _mm256_extracti128_si256(i0, 1);
-            _mm_storeu_si128((__m128i *)block+0, _mm_unpacklo_epi64(i0_l, i0_h));
-            _mm_storeu_si128((__m128i *)block+1, _mm_unpackhi_epi64(i0_l, i0_h));
-        }
-        auto qs = (uint32_t *)q8_k + 64*ib32;
-        for (int l = 0; l < 8; ++l) {
-            qs[8*l + k] = block[l];
-        }
-    }
-    return dnew;
-}
-
 void iqk_convert_iq1_s_q8_k_r8(int n, const void * vx, size_t bx, void * vy, int nrc_x) {
+#ifdef HAVE_FANCY_SIMD
+    constexpr int k_nr = 16;
+    using block_q8_k_r = block_q8_k_r16;
+#else
+    constexpr int k_nr = 8;
+    using block_q8_k_r = block_q8_k_r8;
+#endif
+
     GGML_ASSERT(n%QK_K == 0);
-    GGML_ASSERT(nrc_x%8 == 0);
+    GGML_ASSERT(nrc_x%k_nr == 0);
 
     int nb = n/QK_K;
 
-    const block_iq1_s * x8[8];
+    const block_iq1_s * x8[k_nr];
 
-    block_q8_k_r8 * y = (block_q8_k_r8 *)vy;
+    block_q8_k_r * y = (block_q8_k_r *)vy;
 
     int16_t ls[16];
 
     uint32_t block[8];
 
     __m256i qx[8];
 
-    for (int ix = 0; ix < nrc_x; ix += 8) {
-        for (int k = 0; k < 8; ++k) x8[k] = (const block_iq1_s *)((const char *)vx + (ix + k)*bx);
+    for (int ix = 0; ix < nrc_x; ix += k_nr) {
+        for (int k = 0; k < k_nr; ++k) x8[k] = (const block_iq1_s *)((const char *)vx + (ix + k)*bx);
         for (int i = 0; i < nb; ++i) {
-            for (int k = 0; k < 8; ++k) {
+            for (int k = 0; k < k_nr; ++k) {
                 float d = 0.125f * GGML_FP16_TO_FP32(x8[k][i].d);
                 auto qs = x8[k][i].qs;
                 auto qh = x8[k][i].qh;
@@ -1759,23 +1711,36 @@ void iqk_convert_iq1_s_q8_k_r8(int n, const void * vx, size_t bx, void * vy, int
                     qx[ib32] = value;
                     qs += 4;
                 }
-                float dnew = convert_to_q8_k_r8(k, 126, qx, ls, block, y[i].qs);
+                float dnew = convert_to_q8_k_r8<k_nr>(k, 1.f/126, qx, ls, block, y[i].qs);
                 y[i].d[k] = GGML_FP32_TO_FP16(d*dnew);
             }
+#ifdef HAVE_FANCY_SIMD
+            for (int l = 0; l < 64; ++l) {
+                auto v = _mm512_xor_si512(_mm512_loadu_si512((const __m512i *)y[i].qs + l), _mm512_set1_epi8(-128));
+                _mm512_storeu_si512((__m512i *)y[i].qs + l, v);
+            }
+#endif
         }
         y += nb;
     }
 }
 
 void iqk_convert_iq1_m_q8_k_r8(int n, const void * vx, size_t bx, void * vy, int nrc_x) {
+#ifdef HAVE_FANCY_SIMD
+    constexpr int k_nr = 16;
+    using block_q8_k_r = block_q8_k_r16;
+#else
+    constexpr int k_nr = 8;
+    using block_q8_k_r = block_q8_k_r8;
+#endif
     GGML_ASSERT(n%QK_K == 0);
-    GGML_ASSERT(nrc_x%8 == 0);
+    GGML_ASSERT(nrc_x%k_nr == 0);
 
     int nb = n/QK_K;
 
-    const block_iq1_m * x8[8];
+    const block_iq1_m * x8[k_nr];
 
-    block_q8_k_r8 * y = (block_q8_k_r8 *)vy;
+    block_q8_k_r * y = (block_q8_k_r *)vy;
 
     int16_t ls[16];
 
@@ -1785,10 +1750,10 @@ void iqk_convert_iq1_m_q8_k_r8(int n, const void * vx, size_t bx, void * vy, int
 
     auto mask = _mm256_setr_epi32(0x00000008, 0x00000008, 0x00000080, 0x00000080, 0x00080000, 0x00080000, 0x00800000, 0x00800000);
 
-    for (int ix = 0; ix < nrc_x; ix += 8) {
-        for (int k = 0; k < 8; ++k) x8[k] = (const block_iq1_m *)((const char *)vx + (ix + k)*bx);
+    for (int ix = 0; ix < nrc_x; ix += k_nr) {
+        for (int k = 0; k < k_nr; ++k) x8[k] = (const block_iq1_m *)((const char *)vx + (ix + k)*bx);
         for (int i = 0; i < nb; ++i) {
-            for (int k = 0; k < 8; ++k) {
+            for (int k = 0; k < k_nr; ++k) {
                 const uint16_t * sc = (const uint16_t *)x8[k][i].scales;
                 iq1m_scale_t scale;
                 scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000);
@@ -1816,9 +1781,15 @@ void iqk_convert_iq1_m_q8_k_r8(int n, const void * vx, size_t bx, void * vy, int
                     qs += 4;
                     qh += 2;
                 }
-                float dnew = convert_to_q8_k_r8(k, 126, qx, ls, block, y[i].qs);
+                float dnew = convert_to_q8_k_r8<k_nr>(k, 1.f/126, qx, ls, block, y[i].qs);
                 y[i].d[k] = GGML_FP32_TO_FP16(d*dnew);
             }
+#ifdef HAVE_FANCY_SIMD
+            for (int l = 0; l < 64; ++l) {
+                auto v = _mm512_xor_si512(_mm512_loadu_si512((const __m512i *)y[i].qs + l), _mm512_set1_epi8(-128));
+                _mm512_storeu_si512((__m512i *)y[i].qs + l, v);
+            }
+#endif
         }
         y += nb;
     }
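
For context: the templated convert_to_q8_k_r8<k_nr> called above is defined outside this hunk, so its exact signature and the k_nr == 16 interleaving are not visible here. The scalar sketch below is an assumption-based reference, reconstructed from the deleted AVX2 helper and the new call sites: the second argument is now the inverse of the target range (1.f/126 instead of d0 = 126), and the repacked layout interleaves 4-byte groups of row k with the same groups of the other k_nr - 1 rows. The helper name, the int16_t input convention (values already multiplied by their block scales), and the k_nr == 16 layout are assumptions for illustration, not code from this commit.

#include <algorithm>
#include <cmath>
#include <cstdint>

// Scalar reference (assumed semantics) for the templated conversion helper.
// q holds the 256 values of one QK_K super-block, already multiplied by their
// per-block scales; k is the row index within the repacked group of nr rows;
// id0 is the inverse of the target range (the call sites above pass 1.f/126).
template <int nr>
static float convert_to_q8_k_r_ref(int k, float id0, const int16_t * q, int8_t * q8_k) {
    float amax = 0.f;
    for (int j = 0; j < 256; ++j) amax = std::max(amax, std::fabs((float)q[j]));
    float dnew = amax * id0;          // per-row super-block scale so quants fit into int8
    bool needs_scaling = true;
    if (dnew < 1.f) { dnew = 1.f; needs_scaling = false; }
    float inv = dnew > 0 ? 1/dnew : 0.f;
    for (int j = 0; j < 256; ++j) {
        int v = needs_scaling ? (int)std::lround(inv*q[j]) : q[j];
        // Interleave in 4-byte groups: group g of row k lands next to group g of
        // the other nr-1 rows, matching the uint32_t-granular stores of the
        // deleted AVX2 helper (verified only for nr == 8; nr == 16 is assumed).
        int group = j/4, within = j%4;
        q8_k[4*(nr*group + k) + within] = (int8_t)v;
    }
    return dnew;                      // caller folds this into y[i].d[k]
}

The extra HAVE_FANCY_SIMD loop added after the k loop XORs every stored byte with 0x80, i.e. shifts the quants from signed int8 to unsigned with an offset of 128; presumably the q8_k_r16 GEMM path consumes them with unsigned-times-signed dot products, which would explain why only the 16-row layout needs this pass.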