@@ -1839,37 +1839,34 @@ static void mul_mat_iq3_s_r4_q8_k(int n, const void * vx, size_t bx, const DataI
18391839 }
18401840}
18411841
1842- inline float convert_to_q8_k_r8 (int k, int d0, const __m256i * qx, const int16_t * scales, uint32_t * block, int8_t * q8_k) {
1842+ inline float convert_to_q8_k_r8 (int k, float d0, const __m256i * qx, const int16_t * scales, uint32_t * block, int8_t * q8_k) {
18431843 auto max_i16 = _mm256_setzero_si256 ();
1844+ __m256i qs[16 ];
18441845 for (int ib32 = 0 ; ib32 < 8 ; ++ib32) {
1845- auto q16_l = _mm256_cvtepi8_epi16 (_mm256_castsi256_si128 (qx[ib32]));
1846- auto q16_h = _mm256_cvtepi8_epi16 (_mm256_extracti128_si256 (qx[ib32], 1 ));
1847- q16_l = _mm256_mullo_epi16 (q16_l , _mm256_set1_epi16 (scales[2 *ib32+0 ]));
1848- q16_h = _mm256_mullo_epi16 (q16_h , _mm256_set1_epi16 (scales[2 *ib32+1 ]));
1849- max_i16 = _mm256_max_epi16 (max_i16, _mm256_sign_epi16 (q16_l, q16_l ));
1850- max_i16 = _mm256_max_epi16 (max_i16, _mm256_sign_epi16 (q16_h, q16_h ));
1846+ qs[ 2 *ib32+ 0 ] = _mm256_cvtepi8_epi16 (_mm256_castsi256_si128 (qx[ib32]));
1847+ qs[ 2 *ib32+ 1 ] = _mm256_cvtepi8_epi16 (_mm256_extracti128_si256 (qx[ib32], 1 ));
1848+ qs[ 2 *ib32+ 0 ] = _mm256_mullo_epi16 (qs[ 2 *ib32+ 0 ] , _mm256_set1_epi16 (scales[2 *ib32+0 ]));
1849+ qs[ 2 *ib32+ 1 ] = _mm256_mullo_epi16 (qs[ 2 *ib32+ 1 ] , _mm256_set1_epi16 (scales[2 *ib32+1 ]));
1850+ max_i16 = _mm256_max_epi16 (max_i16, _mm256_sign_epi16 (qs[ 2 *ib32+ 0 ], qs[ 2 *ib32+ 0 ] ));
1851+ max_i16 = _mm256_max_epi16 (max_i16, _mm256_sign_epi16 (qs[ 2 *ib32+ 1 ], qs[ 2 *ib32+ 1 ] ));
18511852 }
18521853 auto max_q32 = _mm256_cvtepi16_epi32 (_mm_max_epi16 (_mm256_castsi256_si128 (max_i16), _mm256_extracti128_si256 (max_i16, 1 )));
18531854 auto imax4 = _mm_max_epi32 (_mm256_castsi256_si128 (max_q32), _mm256_extracti128_si256 (max_q32, 1 ));
18541855 auto max4 = _mm_cvtepi32_ps (imax4);
18551856 max4 = _mm_max_ps (max4, _mm_movehl_ps (max4, max4));
18561857 max4 = _mm_max_ss (max4, _mm_movehdup_ps (max4));
18571858 bool needs_scaling = true ;
1858- float dnew = _mm_cvtss_f32 (max4) / d0;
1859+ float dnew = _mm_cvtss_f32 (max4) * d0;
18591860 if (dnew < 1 .f ) {
18601861 dnew = 1 .f ; needs_scaling = false ;
18611862 }
18621863 auto scale = _mm256_set1_ps (std::abs (dnew) > 1e-9f ? 1 /dnew : 0 .f );
18631864 for (int ib32 = 0 ; ib32 < 8 ; ++ib32) {
1864- auto q16_l = _mm256_cvtepi8_epi16 (_mm256_castsi256_si128 (qx[ib32]));
1865- auto q16_h = _mm256_cvtepi8_epi16 (_mm256_extracti128_si256 (qx[ib32], 1 ));
1866- q16_l = _mm256_mullo_epi16 (q16_l, _mm256_set1_epi16 (scales[2 *ib32+0 ]));
1867- q16_h = _mm256_mullo_epi16 (q16_h, _mm256_set1_epi16 (scales[2 *ib32+1 ]));
18681865 if (needs_scaling) {
1869- auto i0 = _mm256_cvtepi16_epi32 (_mm256_castsi256_si128 (q16_l ));
1870- auto i1 = _mm256_cvtepi16_epi32 (_mm256_extracti128_si256 (q16_l , 1 ));
1871- auto i2 = _mm256_cvtepi16_epi32 (_mm256_castsi256_si128 (q16_h ));
1872- auto i3 = _mm256_cvtepi16_epi32 (_mm256_extracti128_si256 (q16_h , 1 ));
1866+ auto i0 = _mm256_cvtepi16_epi32 (_mm256_castsi256_si128 (qs[ 2 *ib32+ 0 ] ));
1867+ auto i1 = _mm256_cvtepi16_epi32 (_mm256_extracti128_si256 (qs[ 2 *ib32+ 0 ] , 1 ));
1868+ auto i2 = _mm256_cvtepi16_epi32 (_mm256_castsi256_si128 (qs[ 2 *ib32+ 1 ] ));
1869+ auto i3 = _mm256_cvtepi16_epi32 (_mm256_extracti128_si256 (qs[ 2 *ib32+ 1 ] , 1 ));
18731870 i0 = _mm256_cvtps_epi32 (_mm256_round_ps (_mm256_mul_ps (scale, _mm256_cvtepi32_ps (i0)), _MM_ROUND_NEAREST));
18741871 i1 = _mm256_cvtps_epi32 (_mm256_round_ps (_mm256_mul_ps (scale, _mm256_cvtepi32_ps (i1)), _MM_ROUND_NEAREST));
18751872 i2 = _mm256_cvtps_epi32 (_mm256_round_ps (_mm256_mul_ps (scale, _mm256_cvtepi32_ps (i2)), _MM_ROUND_NEAREST));
@@ -1881,7 +1878,7 @@ inline float convert_to_q8_k_r8(int k, int d0, const __m256i * qx, const int16_t
18811878 _mm256_storeu_si256 ((__m256i *)block, i0);
18821879 } else {
18831880 // 0, 1, 2, 3, 4, 5, 6, 7, 8, 16, 17, 18, 19, 20, 21, 22, 23, 9, 10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31
1884- auto i0 = _mm256_packs_epi16 (q16_l, q16_h );
1881+ auto i0 = _mm256_packs_epi16 (qs[ 2 *ib32+ 0 ], qs[ 2 *ib32+ 1 ] );
18851882 auto i0_l = _mm256_castsi256_si128 (i0);
18861883 auto i0_h = _mm256_extracti128_si256 (i0, 1 );
18871884 _mm_storeu_si128 ((__m128i *)block+0 , _mm_unpacklo_epi64 (i0_l, i0_h));
@@ -1976,7 +1973,7 @@ void iqk_convert_iq2_xxs_q8_k_r8(int n, const void * vx, size_t bx, void * vy, i
19761973 values[ib32] = _mm256_set_epi64x (iq2xxs_grid[aux8[3 ]], iq2xxs_grid[aux8[2 ]], iq2xxs_grid[aux8[1 ]], iq2xxs_grid[aux8[0 ]]);
19771974 esh.sign_value (aux32[1 ], values[ib32]);
19781975 }
1979- float dnew = convert_to_q8_k_r8 (k, 124 , values, ls, block, y[i].qs );
1976+ float dnew = convert_to_q8_k_r8 (k, 1 . f / 124 , values, ls, block, y[i].qs );
19801977 y[i].d [k] = GGML_FP32_TO_FP16 (d*dnew);
19811978 }
19821979 }
@@ -2020,7 +2017,7 @@ void iqk_convert_iq2_xs_q8_k_r8(int n, const void * vx, size_t bx, void * vy, in
20202017 DequantizerIQ2XS::sign_values_helper (q2l, sign_helper, qx+0 );
20212018 DequantizerIQ2XS::sign_values_helper (q2h, sign_helper, qx+4 );
20222019#endif
2023- float dnew = convert_to_q8_k_r8 (k, 124 , qx, helper.val , block, y[i].qs );
2020+ float dnew = convert_to_q8_k_r8 (k, 1 . f / 124 , qx, helper.val , block, y[i].qs );
20242021 y[i].d [k] = GGML_FP32_TO_FP16 (d*dnew);
20252022 }
20262023 }
@@ -2333,7 +2330,7 @@ void iqk_convert_iq2_s_q8_k_r8(int n, const void * vx, size_t bx, void * vy, int
23332330 helper.vec = DequantizerIQ2S::make_scales (x8[k][i].scales );
23342331 DequantizerIQ2S::prepare (x8[k][i].qs + 0 , x8[k][i].qh +0 , (const uint16_t *)(x8[k][i].qs + QK_K/8 ) + 0 , sh, qx+0 );
23352332 DequantizerIQ2S::prepare (x8[k][i].qs +16 , x8[k][i].qh +4 , (const uint16_t *)(x8[k][i].qs + QK_K/8 ) + 8 , sh, qx+4 );
2336- float dnew = convert_to_q8_k_r8 (k, 124 , qx, helper.val , block, y[i].qs );
2333+ float dnew = convert_to_q8_k_r8 (k, 1 . f / 124 , qx, helper.val , block, y[i].qs );
23372334 y[i].d [k] = GGML_FP32_TO_FP16 (d*dnew);
23382335 }
23392336 }
@@ -2480,7 +2477,6 @@ void iqk_convert_iq3_xxs_q8_k_r8(int n, const void * vx, size_t bx, void * vy, i
24802477 for (int ix = 0 ; ix < nrc_x; ix += 8 ) {
24812478 for (int k = 0 ; k < 8 ; ++k) x8[k] = (const block_iq3_xxs *)((const char *)vx + (ix + k)*bx);
24822479 for (int i = 0 ; i < nb; ++i) {
2483- // TODO: simdify
24842480 for (int k = 0 ; k < 8 ; ++k) {
24852481 float d = 0 .25f * GGML_FP16_TO_FP32 (x8[k][i].d );
24862482 auto qs = x8[k][i].qs ;
@@ -2494,7 +2490,7 @@ void iqk_convert_iq3_xxs_q8_k_r8(int n, const void * vx, size_t bx, void * vy, i
24942490 esh.sign_value (aux32, values[ib32]);
24952491 qs += 8 ;
24962492 }
2497- float dnew = convert_to_q8_k_r8 (k, 124 , values, ls, block, y[i].qs );
2493+ float dnew = convert_to_q8_k_r8 (k, 1 . f / 124 , values, ls, block, y[i].qs );
24982494 y[i].d [k] = GGML_FP32_TO_FP16 (d*dnew);
24992495 }
25002496 }
@@ -2589,7 +2585,7 @@ void iqk_convert_iq3_s_q8_k_r8(int n, const void * vx, size_t bx, void * vy, int
25892585 ls[2 *ib32 + 0 ] = (2 *((x8[k][i].scales [ib32/2 ] >> 4 *(ib32%2 )) & 0xf ) + 1 );
25902586 ls[2 *ib32 + 1 ] = ls[2 *ib32 + 0 ];
25912587 }
2592- float dnew = convert_to_q8_k_r8 (k, 127 , values, ls, block, y[i].qs );
2588+ float dnew = convert_to_q8_k_r8 (k, 1 . f / 127 , values, ls, block, y[i].qs );
25932589 y[i].d [k] = GGML_FP32_TO_FP16 (d*dnew);
25942590 }
25952591 }
@@ -2666,7 +2662,6 @@ template <typename Dequantizer> void set_functions(std::array<mul_mat_t, IQK_MAX
26662662
26672663bool iqk_set_kernels_iquants (int ne00, int typeA, int typeB, std::array<mul_mat_t , IQK_MAX_NY>& kernels, mul_mat_t & func16) {
26682664
2669- // if (ne00%QK_K != 0 || ggml_type(typeB) != GGML_TYPE_Q8_K) {
26702665 if (ne00%QK_K != 0 ) return false ;
26712666
26722667 // if (ggml_type(typeA) == GGML_TYPE_IQ2_XXS) {
@@ -3355,6 +3350,20 @@ static void mul_mat_iq3_s_r4_q8_k(int n, const void * vx, size_t bx, const DataI
33553350
33563351}
33573352
3353+ bool iqk_convert_iquants_q80_r8 ([[maybe_unused]] int type, int n, [[maybe_unused]] const void * vx, [[maybe_unused]] size_t bx, [[maybe_unused]] void * vy, int nrc_x) { // Conversion entry point for i-quant types. Currently a stub: every path returns false because the per-type dispatch below is commented out. type/vx/bx/vy are [[maybe_unused]] for that reason.
3354+ if (n%QK_K != 0 || nrc_x%8 != 0 ) return false ; // Reject when n is not a multiple of QK_K or nrc_x is not a multiple of 8. Redundant while the next line returns unconditionally; it becomes meaningful once the dispatch is re-enabled.
3355+ return false ; // Unconditional "not handled" result; the dispatch below is disabled.
3356+ // Disabled dispatch, kept for reference: selects the converter matching `type` and returns true on success.
3356+ // NOTE(review): this function is named *_q80_r8 but the converters below are *_q8_k_r8 - confirm the intended target format before re-enabling.
3356+ // switch (ggml_type(type)) {
3357+ // case GGML_TYPE_IQ2_XXS: iqk_convert_iq2_xxs_q8_k_r8(n, vx, bx, vy, nrc_x); break;
3358+ // case GGML_TYPE_IQ2_XS : iqk_convert_iq2_xs_q8_k_r8 (n, vx, bx, vy, nrc_x); break;
3359+ // case GGML_TYPE_IQ2_S : iqk_convert_iq2_s_q8_k_r8 (n, vx, bx, vy, nrc_x); break;
3360+ // case GGML_TYPE_IQ3_XXS: iqk_convert_iq3_xxs_q8_k_r8(n, vx, bx, vy, nrc_x); break;
3361+ // case GGML_TYPE_IQ3_S : iqk_convert_iq3_s_q8_k_r8 (n, vx, bx, vy, nrc_x); break;
3362+ // default: return false;
3363+ // }
3364+ // return true;
3365+ }
3366+
33583367bool iqk_set_kernels_iquants (int ne00, int typeA, int typeB, std::array<mul_mat_t , IQK_MAX_NY>& kernels, mul_mat_t & func16) {
33593368
33603369 if (ne00%QK_K != 0 || ggml_type (typeB) != GGML_TYPE_Q8_K) {
0 commit comments