@@ -810,10 +810,11 @@ static void mul_mat_qX_K_q8_2_X4_T(int n, const void * vx, size_t bx, const Data
                 auto d4_2 = _mm_cvtepu16_epi32(_mm_loadl_epi64((const __m128i *)(q8.y[iy][2*i+1].d)));
                 auto dy = _mm256_castsi256_ps(_mm256_slli_epi32(MM256_SET_M128I(d4_2, d4_1), 16));
                 _mm256_storeu_ps(d8 + 8*iy, dy);
-                auto m4_1 = _mm_cvtepu16_epi32(_mm_loadl_epi64((const __m128i *)(q8.y[iy][2*i+0].d+4)));
-                auto m4_2 = _mm_cvtepu16_epi32(_mm_loadl_epi64((const __m128i *)(q8.y[iy][2*i+1].d+4)));
-                auto my = _mm256_castsi256_ps(_mm256_slli_epi32(MM256_SET_M128I(m4_2, m4_1), 16));
-                accd[iy] = _mm256_fmadd_ps(my, mins, accd[iy]);
+                auto m4_1 = _mm_cvtepi16_epi32(_mm_loadl_epi64((const __m128i *)(q8.y[iy][2*i+0].d+4)));
+                auto m4_2 = _mm_cvtepi16_epi32(_mm_loadl_epi64((const __m128i *)(q8.y[iy][2*i+1].d+4)));
+                auto myi = MM256_SET_M128I(m4_2, m4_1);
+                auto my = _mm256_mul_ps(dy, _mm256_cvtepi32_ps(myi));
+                accd[iy] = _mm256_fmadd_ps(my, mins, accd[iy]);
             }

             auto all_scales = _mm256_mul_ps(_mm256_set1_ps(deq.d), _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i *)utmp))));
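Judging from the loads in this hunk, entries 4..7 of the Q8_2_X4 block's `.d` array now hold signed 16-bit quant sums rather than bf16 `d*sum` products, so they are sign-extended and multiplied by the row scales `dy` before being folded into the accumulator against `mins`. A minimal scalar sketch of that corrected accumulation follows; the layout and helper names are assumptions for illustration, not taken from the source.

#include <cstdint>
#include <cstring>

// bf16 value stored in a 16-bit pattern -> float (shift into the high half).
static inline float bf16_to_f32(uint16_t h) {
    uint32_t u = (uint32_t)h << 16;
    float f;
    std::memcpy(&f, &u, sizeof f);
    return f;
}

// Per 32-sample sub-block j of a pair of (assumed) q8_2_x4 blocks:
//   dy[j]  : block scale, stored as bf16 in d16[j]
//   sum[j] : signed int16 sum of the 32 quants, stored in d16[j+4]
//   acc   += dy[j]*sum[j] * mins[j]
static void accumulate_min_terms(const uint16_t * d16_1, const uint16_t * d16_2,
                                 const float * mins, float & acc) {
    for (int j = 0; j < 8; ++j) {
        const uint16_t * d16 = j < 4 ? d16_1 : d16_2;
        int jj = j & 3;
        float dy = bf16_to_f32(d16[jj]);                // block scale
        float my = dy * (float)(int16_t)d16[4 + jj];    // scale * signed quant sum
        acc += my * mins[j];
    }
}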
@@ -2017,6 +2018,91 @@ typedef struct {
     int8_t qs[8*QK8_1];
 } block_q8_1_r8;

+void iqk_convert_q2_k_q8_k_r8(int n, const void * vx, size_t bx, void * vy, int nrc_x) {
+    GGML_ASSERT(n%QK_K == 0);
+    GGML_ASSERT(nrc_x%8 == 0);
+
+    int nb = n/QK_K;
+
+    const block_q2_K * x8[8];
+
+    block_q8_k_r8 * y = (block_q8_k_r8 *)vy;
+
+    float f_values[QK_K];
+    uint32_t block[8];
+
+    __m256i xv[4];
+
+    auto ml = _mm256_set1_epi8(0x03);
+    auto sign_bit = _mm256_set1_ps(-0.0f);
+    auto perm = _mm256_setr_epi32(0, 4, 1, 5, 2, 6, 3, 7);
+
+    for (int ix = 0; ix < nrc_x; ix += 8) {
+        for (int k = 0; k < 8; ++k) x8[k] = (const block_q2_K *)((const char *)vx + (ix + k)*bx);
+        for (int i = 0; i < nb; ++i) {
+            for (int k = 0; k < 8; ++k) {
+                auto vd = _mm256_set1_ps(GGML_FP16_TO_FP32(x8[k][i].d));
+                auto vm = _mm256_mul_ps(_mm256_set1_ps(GGML_FP16_TO_FP32(x8[k][i].dmin)), _mm256_set1_ps(-1.f));
+                auto block_max = _mm256_setzero_ps();
+                for (int i128 = 0; i128 < 2; ++i128) {
+                    auto bits = _mm256_loadu_si256((const __m256i *)x8[k][i].qs+i128);
+                    xv[0] = _mm256_and_si256(bits, ml);
+                    xv[1] = _mm256_and_si256(_mm256_srli_epi16(bits, 2), ml);
+                    xv[2] = _mm256_and_si256(_mm256_srli_epi16(bits, 4), ml);
+                    xv[3] = _mm256_and_si256(_mm256_srli_epi16(bits, 6), ml);
+                    for (int l = 0; l < 4; ++l) {
+                        auto q1 = _mm256_cvtepi8_epi16(_mm256_castsi256_si128(xv[l]));
+                        auto q2 = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(xv[l], 1));
+                        q1 = _mm256_mullo_epi16(q1, _mm256_set1_epi16(x8[k][i].scales[8*i128 + 2*l + 0] & 0xf));
+                        q2 = _mm256_mullo_epi16(q2, _mm256_set1_epi16(x8[k][i].scales[8*i128 + 2*l + 1] & 0xf));
+                        auto m1 = _mm256_mul_ps(vm, _mm256_set1_ps(x8[k][i].scales[8*i128 + 2*l + 0] >> 4));
+                        auto m2 = _mm256_mul_ps(vm, _mm256_set1_ps(x8[k][i].scales[8*i128 + 2*l + 1] >> 4));
+                        auto v0 = _mm256_fmadd_ps(_mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_castsi256_si128(q1))), vd, m1);
+                        auto v1 = _mm256_fmadd_ps(_mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_extracti128_si256(q1, 1))), vd, m1);
+                        auto v2 = _mm256_fmadd_ps(_mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_castsi256_si128(q2))), vd, m2);
+                        auto v3 = _mm256_fmadd_ps(_mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_extracti128_si256(q2, 1))), vd, m2);
+                        auto max = _mm256_max_ps(_mm256_max_ps(_mm256_andnot_ps(sign_bit, v0), _mm256_andnot_ps(sign_bit, v1)),
+                                                 _mm256_max_ps(_mm256_andnot_ps(sign_bit, v2), _mm256_andnot_ps(sign_bit, v3)));
+                        block_max = _mm256_max_ps(block_max, max);
+                        _mm256_storeu_ps(f_values + 128*i128 + 32*l +  0, v0);
+                        _mm256_storeu_ps(f_values + 128*i128 + 32*l +  8, v1);
+                        _mm256_storeu_ps(f_values + 128*i128 + 32*l + 16, v2);
+                        _mm256_storeu_ps(f_values + 128*i128 + 32*l + 24, v3);
+                    }
+                }
+                auto max4 = _mm_max_ps(_mm256_extractf128_ps(block_max, 1), _mm256_castps256_ps128(block_max));
+                max4 = _mm_max_ps(max4, _mm_movehl_ps(max4, max4));
+                max4 = _mm_max_ss(max4, _mm_movehdup_ps(max4));
+                float d = _mm_cvtss_f32(max4/127.f);
+                auto id = _mm256_set1_ps(d != 0.0f ? 1/d : 0.0f);
+                y[i].d[k] = GGML_FP32_TO_FP16(d);
+                for (int ib32 = 0; ib32 < 8; ++ib32) {
+                    auto v0 = _mm256_loadu_ps(f_values + 32*ib32 +  0);
+                    auto v1 = _mm256_loadu_ps(f_values + 32*ib32 +  8);
+                    auto v2 = _mm256_loadu_ps(f_values + 32*ib32 + 16);
+                    auto v3 = _mm256_loadu_ps(f_values + 32*ib32 + 24);
+                    auto i0 = _mm256_cvtps_epi32(_mm256_round_ps(_mm256_mul_ps(v0, id), _MM_ROUND_NEAREST));
+                    auto i1 = _mm256_cvtps_epi32(_mm256_round_ps(_mm256_mul_ps(v1, id), _MM_ROUND_NEAREST));
+                    auto i2 = _mm256_cvtps_epi32(_mm256_round_ps(_mm256_mul_ps(v2, id), _MM_ROUND_NEAREST));
+                    auto i3 = _mm256_cvtps_epi32(_mm256_round_ps(_mm256_mul_ps(v3, id), _MM_ROUND_NEAREST));
+                    i0 = _mm256_packs_epi32(i0, i1);
+                    i2 = _mm256_packs_epi32(i2, i3);
+                    i0 = _mm256_packs_epi16(i0, i2);
+                    i0 = _mm256_permutevar8x32_epi32(i0, perm);
+
+                    _mm256_storeu_si256((__m256i *)block, i0);
+                    auto q8 = (uint32_t *)y[i].qs + 64*ib32;
+                    for (int l = 0; l < 4; ++l) {
+                        q8[8*l + k +  0] = block[l + 0];
+                        q8[8*l + k + 32] = block[l + 4];
+                    }
+                }
+            }
+        }
+        y += nb;
+    }
+}
+
 void iqk_convert_q4_k_q8_1_r8(int n, const void * vx, size_t bx, void * vy, int nrc_x) {
     GGML_ASSERT(n%QK_K == 0);
     GGML_ASSERT(nrc_x%8 == 0);
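For reference, a scalar sketch of what the new iqk_convert_q2_k_q8_k_r8 computes for a single row and a single super-block, before the 8-row interleaving: dequantize each group of 16 values as d*(scale & 0xF)*q - dmin*(scale >> 4), track the absolute maximum, and requantize the 256 floats to int8 with the new per-block scale max/127. The sketch assumes the standard ggml block_q2_K layout (16 scale/min nibbles, 64 bytes of 2-bit quants, fp16 d and dmin); it is illustrative, not the code that was committed.

#include <algorithm>
#include <cmath>
#include <cstdint>

// One Q2_K super-block (256 values) of one row -> int8 quants plus new scale.
static void q2_k_block_to_q8(const uint8_t scales[16], const uint8_t qs[64],
                             float d, float dmin, int8_t out[256], float & d_new) {
    float f[256];
    float amax = 0.0f;
    int is = 0, idx = 0;
    for (int half = 0; half < 2; ++half) {             // two 128-value halves
        const uint8_t * q = qs + 32*half;
        for (int shift = 0; shift < 8; shift += 2) {   // four 2-bit planes per half
            for (int group = 0; group < 2; ++group) {  // two groups of 16 values per plane
                uint8_t sc = scales[is++];
                float dl = d    * (sc & 0xF);          // sub-block scale
                float ml = dmin * (sc >>  4);          // sub-block minimum
                for (int l = 0; l < 16; ++l) {
                    float v = dl * ((q[16*group + l] >> shift) & 3) - ml;
                    f[idx++] = v;
                    amax = std::max(amax, std::fabs(v));
                }
            }
        }
    }
    d_new = amax/127.f;
    float id = d_new != 0.0f ? 1.0f/d_new : 0.0f;
    for (int j = 0; j < 256; ++j) out[j] = (int8_t)std::lround(f[j]*id);
}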
@@ -2429,6 +2515,97 @@ void iqk_convert_q3_k_q8_k_r8(int n, const void * vx, size_t bx, void * vy, int
     }
 }

+inline float convert_to_q8_k_r8(int k, float d0, const __m256i * qx, const int16_t * scales, uint32_t * block, int8_t * q8_k) {
+    auto max_i16 = _mm256_setzero_si256();
+    __m256i qs[16];
+    for (int ib32 = 0; ib32 < 8; ++ib32) {
+        qs[2*ib32+0] = _mm256_cvtepi8_epi16(_mm256_castsi256_si128(qx[ib32]));
+        qs[2*ib32+1] = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(qx[ib32], 1));
+        qs[2*ib32+0] = _mm256_mullo_epi16(qs[2*ib32+0], _mm256_set1_epi16(scales[2*ib32+0]));
+        qs[2*ib32+1] = _mm256_mullo_epi16(qs[2*ib32+1], _mm256_set1_epi16(scales[2*ib32+1]));
+        max_i16 = _mm256_max_epi16(max_i16, _mm256_sign_epi16(qs[2*ib32+0], qs[2*ib32+0]));
+        max_i16 = _mm256_max_epi16(max_i16, _mm256_sign_epi16(qs[2*ib32+1], qs[2*ib32+1]));
+    }
+    auto max_q32 = _mm256_cvtepi16_epi32(_mm_max_epi16(_mm256_castsi256_si128(max_i16), _mm256_extracti128_si256(max_i16, 1)));
+    auto imax4 = _mm_max_epi32(_mm256_castsi256_si128(max_q32), _mm256_extracti128_si256(max_q32, 1));
+    auto max4 = _mm_cvtepi32_ps(imax4);
+    max4 = _mm_max_ps(max4, _mm_movehl_ps(max4, max4));
+    max4 = _mm_max_ss(max4, _mm_movehdup_ps(max4));
+    bool needs_scaling = true;
+    float dnew = _mm_cvtss_f32(max4) * d0;
+    if (dnew < 1.f) {
+        dnew = 1.f; needs_scaling = false;
+    }
+    auto scale = _mm256_set1_ps(std::abs(dnew) > 1e-9f ? 1/dnew : 0.f);
+    for (int ib32 = 0; ib32 < 8; ++ib32) {
+        if (needs_scaling) {
+            auto i0 = _mm256_cvtepi16_epi32(_mm256_castsi256_si128(qs[2*ib32+0]));
+            auto i1 = _mm256_cvtepi16_epi32(_mm256_extracti128_si256(qs[2*ib32+0], 1));
+            auto i2 = _mm256_cvtepi16_epi32(_mm256_castsi256_si128(qs[2*ib32+1]));
+            auto i3 = _mm256_cvtepi16_epi32(_mm256_extracti128_si256(qs[2*ib32+1], 1));
+            i0 = _mm256_cvtps_epi32(_mm256_round_ps(_mm256_mul_ps(scale, _mm256_cvtepi32_ps(i0)), _MM_ROUND_NEAREST));
+            i1 = _mm256_cvtps_epi32(_mm256_round_ps(_mm256_mul_ps(scale, _mm256_cvtepi32_ps(i1)), _MM_ROUND_NEAREST));
+            i2 = _mm256_cvtps_epi32(_mm256_round_ps(_mm256_mul_ps(scale, _mm256_cvtepi32_ps(i2)), _MM_ROUND_NEAREST));
+            i3 = _mm256_cvtps_epi32(_mm256_round_ps(_mm256_mul_ps(scale, _mm256_cvtepi32_ps(i3)), _MM_ROUND_NEAREST));
+            i0 = _mm256_packs_epi32(i0, i1);
+            i2 = _mm256_packs_epi32(i2, i3);
+            i0 = _mm256_packs_epi16(i0, i2);
+            i0 = _mm256_permutevar8x32_epi32(i0, _mm256_setr_epi32(0, 4, 1, 5, 2, 6, 3, 7));
+            _mm256_storeu_si256((__m256i *)block, i0);
+        } else {
+            // byte order after packs_epi16: 0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23, 8, 9, 10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31
+            auto i0 = _mm256_packs_epi16(qs[2*ib32+0], qs[2*ib32+1]);
+            auto i0_l = _mm256_castsi256_si128(i0);
+            auto i0_h = _mm256_extracti128_si256(i0, 1);
+            _mm_storeu_si128((__m128i *)block+0, _mm_unpacklo_epi64(i0_l, i0_h));
+            _mm_storeu_si128((__m128i *)block+1, _mm_unpackhi_epi64(i0_l, i0_h));
+        }
+        auto qs = (uint32_t *)q8_k + 64*ib32;
+        for (int l = 0; l < 8; ++l) {
+            qs[8*l + k] = block[l];
+        }
+    }
+    return dnew;
+}
+
+// TODO: move this to iqk_gemm_iquants
+void iqk_convert_iq4_xs_q8_k_r8(int n, const void * vx, size_t bx, void * vy, int nrc_x) {
+    GGML_ASSERT(n%QK_K == 0);
+    GGML_ASSERT(nrc_x%8 == 0);
+
+    int nb = n/QK_K;
+
+    const block_iq4_xs * x8[8];
+
+    block_q8_k_r8 * y = (block_q8_k_r8 *)vy;
+
+    auto values128 = _mm_loadu_si128((const __m128i *)iq4k_values);
+    auto values = MM256_SET_M128I(values128, values128);
+
+    int16_t ls[16];
+    float dnew[8];
+    uint32_t block[8];
+    __m256i xv[8];
+
+    for (int ix = 0; ix < nrc_x; ix += 8) {
+        for (int k = 0; k < 8; ++k) x8[k] = (const block_iq4_xs *)((const char *)vx + (ix + k)*bx);
+        for (int i = 0; i < nb; ++i) {
+            for (int k = 0; k < 8; ++k) {
+                float d = GGML_FP16_TO_FP32(x8[k][i].d);
+                for (int ib32 = 0; ib32 < 8; ++ib32) {
+                    ls[2*ib32+0] = ls[2*ib32+1] = (((x8[k][i].scales_l[ib32/2] >> 4*(ib32%2)) & 0xf) | (((x8[k][i].scales_h >> 2*ib32) & 3) << 4)) - 32;
+                    auto bits = _mm_loadu_si128((const __m128i *)x8[k][i].qs + ib32);
+                    xv[ib32] = _mm256_and_si256(MM256_SET_M128I(_mm_srli_epi16(bits, 4), bits), _mm256_set1_epi8(0xf));
+                    xv[ib32] = _mm256_shuffle_epi8(values, xv[ib32]);
+                }
+                dnew[k] = d * convert_to_q8_k_r8(k, 1.f/127, xv, ls, block, y[i].qs);
+            }
+            _mm_storeu_si128((__m128i *)y[i].d, _mm256_cvtps_ph(_mm256_loadu_ps(dnew), _MM_ROUND_NEAREST));
+        }
+        y += nb;
+    }
+}
+

 } // namespace

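The helper convert_to_q8_k_r8 above multiplies the unpacked 8-bit quants by their int16 block scales, takes the absolute maximum of the products, and rescales to the int8 range only when max*d0 exceeds 1 (with d0 = 1/127 that means the products would overflow int8); otherwise it packs the products directly and returns dnew = 1. The stores qs[8*l + k] interleave the 8 rows four bytes at a time. A small sketch of that destination indexing, assuming the block_q8_k_r8 layout implied by those stores (illustrative only):

#include <cstdint>

// Row k (0..7), 256 int8 quants q[] -> interleaved qs[] of one block_q8_k_r8.
// Value j of row k lands at byte 32*(j/4) + 4*k + (j%4), i.e. the 8 rows are
// interleaved in groups of 4 consecutive bytes.
static void scatter_row_q8_k_r8(int k, const int8_t q[256], int8_t * qs) {
    for (int j = 0; j < 256; ++j) {
        qs[32*(j/4) + 4*k + (j%4)] = q[j];
    }
}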
@@ -2516,10 +2693,12 @@ bool iqk_set_kernels_kquants(int ne00, int typeA, int typeB, std::array<mul_mat_

 bool iqk_convert_kquants_q8X_r8(int type, int n, const void * vx, size_t bx, void * vy, int nrc_x) {
     switch (ggml_type(type)) {
+        case GGML_TYPE_Q2_K: iqk_convert_q2_k_q8_k_r8(n, vx, bx, vy, nrc_x); break;
         case GGML_TYPE_Q3_K: iqk_convert_q3_k_q8_k_r8(n, vx, bx, vy, nrc_x); break;
         case GGML_TYPE_Q4_K: iqk_convert_q4_k_q8_1_r8(n, vx, bx, vy, nrc_x); break;
         case GGML_TYPE_Q5_K: iqk_convert_q5_k_q8_1_r8(n, vx, bx, vy, nrc_x); break;
         case GGML_TYPE_Q6_K: iqk_convert_q6_k_q8_0_r8(n, vx, bx, vy, nrc_x); break;
+        case GGML_TYPE_IQ4_XS: iqk_convert_iq4_xs_q8_k_r8(n, vx, bx, vy, nrc_x); break;
         default: return false;
     }
     return true;