@@ -472,6 +472,15 @@ struct DequantizerIQ3XXS final : public BaseDequantizer<block_iq3_xxs> {
472472 auto scales16 = prepare_scales (i);
473473 scales[0 ] = MM256_SET_M128I (scales16, scales16);
474474 }
475+ inline void new_block_f (int i, __m256 * scales) {
476+ auto sc16 = prepare_scales (i);
477+ auto scf = _mm256_mul_ps (_mm256_set1_ps (d), _mm256_cvtepi32_ps (_mm256_cvtepi16_epi32 (sc16)));
478+ auto scf_l = _mm256_castps256_ps128 (scf);
479+ auto scf_h = _mm256_extractf128_ps (scf, 1 );
480+ scales[0 ] = _mm256_set_m128 (scf_l, scf_l);
481+ scales[1 ] = _mm256_set_m128 (scf_h, scf_h);
482+ scales[2 ] = _mm256_mul_ps (scf, _mm256_set1_ps (-minv));
483+ }
475484 inline float new_block (int i, __m256i * scales, __m256i& mins) {
476485 auto scales16 = prepare_scales (i);
477486 mins = scb.shuffle (scales16);
@@ -1771,6 +1780,58 @@ void iqk_convert_iq2_xxs_q8_0_r8(int n, const void * vx, size_t bx, void * vy, i
17711780 }
17721781}
17731782
1783+ void iqk_convert_iq3_xxs_q8_0_r8 (int n, const void * vx, size_t bx, void * vy, int nrc_x) {
1784+ GGML_ASSERT (n%QK_K == 0 );
1785+ GGML_ASSERT (nrc_x%8 == 0 );
1786+
1787+ int nb = n/QK_K;
1788+
1789+ const block_iq3_xxs * x8[8 ];
1790+
1791+ block_q8_0_r8 * y = (block_q8_0_r8 *)vy;
1792+
1793+ ggml_half dh[8 ];
1794+ uint16_t all_ls[64 ];
1795+ EvenSignHelper esh;
1796+
1797+ uint32_t block[8 ];
1798+ uint32_t aux32;
1799+
1800+ for (int ix = 0 ; ix < nrc_x; ix += 8 ) {
1801+ for (int k = 0 ; k < 8 ; ++k) x8[k] = (const block_iq3_xxs *)((const char *)vx + (ix + k)*bx);
1802+ for (int i = 0 ; i < nb; ++i) {
1803+ // TODO: simdify
1804+ for (int k = 0 ; k < 8 ; ++k) {
1805+ dh[k] = x8[k][i].d ;
1806+ auto qs = x8[k][i].qs ;
1807+ auto sas = qs + QK_K/4 ;
1808+ for (int ib32 = 0 ; ib32 < 8 ; ++ib32) {
1809+ std::memcpy (&aux32, sas + 4 *ib32, sizeof (uint32_t ));
1810+ all_ls[8 *ib32 + k] = (2 *(aux32 >> 28 ) + 1 );
1811+ auto value = _mm256_set_epi32 (iq3xxs_grid[qs[7 ]], iq3xxs_grid[qs[6 ]], iq3xxs_grid[qs[5 ]], iq3xxs_grid[qs[4 ]],
1812+ iq3xxs_grid[qs[3 ]], iq3xxs_grid[qs[2 ]], iq3xxs_grid[qs[1 ]], iq3xxs_grid[qs[0 ]]);
1813+ esh.sign_value (aux32, value);
1814+ _mm256_storeu_si256 ((__m256i *)block, value);
1815+ auto q8 = (uint32_t *)y[ib32].qs ;
1816+ for (int l = 0 ; l < 4 ; ++l) {
1817+ q8[8 *l + k + 0 ] = block[l + 0 ];
1818+ q8[8 *l + k + 32 ] = block[l + 4 ];
1819+ }
1820+ qs += 8 ;
1821+ }
1822+ }
1823+ auto vd = _mm256_mul_ps (_mm256_set1_ps (0 .25f ), _mm256_cvtph_ps (_mm_loadu_si128 ((const __m128i *)dh)));
1824+ for (int ib32 = 0 ; ib32 < QK_K/32 ; ++ib32) {
1825+ auto iscales16 = _mm_loadu_si128 ((const __m128i *)all_ls + ib32);
1826+ auto iscales32 = _mm256_cvtepi16_epi32 (iscales16);
1827+ auto scales = _mm256_mul_ps (vd, _mm256_cvtepi32_ps (iscales32));
1828+ _mm_storeu_si128 ((__m128i *)y[ib32].d , _mm256_cvtps_ph (scales, _MM_FROUND_TO_NEAREST_INT));
1829+ }
1830+ y += QK_K/32 ;
1831+ }
1832+ }
1833+ }
1834+
17741835template <typename Dequantizer> void set_functions (std::array<mul_mat_t , IQK_MAX_NY>& funcs) {
17751836 funcs[0 ] = mul_mat_qX_K_q8_K_IQ<Dequantizer, 1 >;
17761837 funcs[1 ] = mul_mat_qX_K_q8_K_IQ<Dequantizer, 2 >;
@@ -1791,7 +1852,15 @@ bool iqk_set_kernels_iquants(int ne00, int typeA, int typeB, std::array<mul_mat_
17911852 if (ggml_type (typeA) == GGML_TYPE_IQ2_XXS) {
17921853 if (ggml_type (typeB) == GGML_TYPE_Q8_2_X4) {
17931854 IQK_SET_MUL_MAT_FUNCTIONS_T (mul_mat_qX_K_q8_2_IQ_N, DequantizerIQ2XXS, kernels);
1794- // IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_iq2_xxs_q8_2_IQ_N, kernels);
1855+ func16 = nullptr ;
1856+ return true ;
1857+ }
1858+ return false ;
1859+ }
1860+
1861+ if (ggml_type (typeA) == GGML_TYPE_IQ3_XXS) {
1862+ if (ggml_type (typeB) == GGML_TYPE_Q8_2_X4) {
1863+ IQK_SET_MUL_MAT_FUNCTIONS_T (mul_mat_qX_K_q8_2_IQ_N, DequantizerIQ3XXS, kernels);
17951864 func16 = nullptr ;
17961865 return true ;
17971866 }
@@ -1856,6 +1925,7 @@ bool iqk_convert_iquants_q80_r8(int type, int n, const void * vx, size_t bx, voi
18561925 if (n%QK_K != 0 || nrc_x%8 != 0 ) return false ;
18571926 switch (ggml_type (type)) {
18581927 case GGML_TYPE_IQ2_XXS: iqk_convert_iq2_xxs_q8_0_r8 (n, vx, bx, vy, nrc_x); break ;
1928+ case GGML_TYPE_IQ3_XXS: iqk_convert_iq3_xxs_q8_0_r8 (n, vx, bx, vy, nrc_x); break ;
18591929 default : return false ;
18601930 }
18611931 return true ;
0 commit comments