@@ -865,6 +865,80 @@ void mul_mat_iq1_s_q8_K(int n, const void * vx, size_t bx, const DataInfo& info,
     }
 }
 
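+// GEMM kernel for IQ1_S weights against Q8_2_X4 activations, producing nrc_y output
+// rows per pass. Per 32-weight block, IQ1_S packs 8-bit grid indices in qs and a
+// 16-bit word in qh: bits 0-11 hold four 3-bit high parts of the grid indices,
+// bits 12-14 the block scale, and bit 15 the sign of the +/-1/8 shift.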
+template <int nrc_y>
+void mul_mat_iq1_s_q8_2_x4(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+    GGML_ASSERT(n%QK_K == 0);
+    Q8<nrc_y, block_q8_2_x4> q8(info);
+    __m256i qx[4];
+    __m256  scales[2];
+    __m256  acc[nrc_y] = {};
+    auto delta_mask = _mm_set1_epi16(-32768); // to avoid stupid overflow warnings when using 0x8000
+    for (int ix = 0; ix < nrc_x; ++ix) {
+        auto iq1s = (const block_iq1_s *)((const char *)vx + ix*bx);
+        for (int ibl = 0; ibl < n/QK_K; ++ibl) {
+            float d = GGML_FP16_TO_FP32(iq1s[ibl].d);
+            auto qhb = _mm_loadu_si128((const __m128i *)iq1s[ibl].qh);
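+            // Decode the eight block scales from qh bits 12-14: scale = 2*s + 1 (odd values 1..15)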
+            auto scales128 = _mm_and_si128(_mm_srli_epi16(qhb, 12), _mm_set1_epi16(7));
+            scales128 = _mm_add_epi16(_mm_slli_epi16(scales128, 1), _mm_set1_epi16(1));
+            auto all_scales = _mm256_mul_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(scales128)));
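+            // iq1s_grid_us stores the ternary grid values g in {-1, 0, 1} biased to
+            // unsigned (g+1), so a weight is reconstructed as (8*(g+1) - 7)/8 = g + 1/8,
+            // or, when qh bit 15 is set, (8*(g+1) - 9)/8 = g - 1/8. The /8 is applied
+            // once at the end (the 0.125f in info.store), and the factor 8 on the grid
+            // part is folded into all_scales further down.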
+#ifdef HAVE_FANCY_SIMD
+            auto mask = _mm_cmpeq_epi16_mask(_mm_and_si128(qhb, delta_mask), delta_mask);
+            auto deltas128 = _mm_mask_blend_epi16(mask, _mm_set1_epi16(-7), _mm_set1_epi16(-9));
+#else
+            auto mask = _mm_cmpeq_epi16(_mm_and_si128(qhb, delta_mask), delta_mask);
+            auto deltas128 = _mm_or_si128(_mm_and_si128(mask, _mm_set1_epi16(-9)), _mm_andnot_si128(mask, _mm_set1_epi16(-7)));
+#endif
+            auto deltas = _mm256_mul_ps(all_scales, _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(deltas128)));
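+            // d[4..7] of each q8_2_x4 block hold the four activation block sums
+            // (q8_1-style, i.e. already premultiplied by the block scale), stored as
+            // the upper 16 bits of their fp32 value; shifting left by 16 reconstructs
+            // the floats. deltas * my accounts for the constant -7/-9 part of every
+            // weight in a block.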
+            for (int iy = 0; iy < nrc_y; ++iy) {
+                auto my1 = _mm_cvtepu16_epi32(_mm_loadl_epi64((const __m128i *)(q8.y[iy][2*ibl+0].d + 4)));
+                auto my2 = _mm_cvtepu16_epi32(_mm_loadl_epi64((const __m128i *)(q8.y[iy][2*ibl+1].d + 4)));
+                auto my  = _mm256_castsi256_ps(_mm256_slli_epi32(MM256_SET_M128I(my2, my1), 16));
+                acc[iy] = _mm256_fmadd_ps(deltas, my, acc[iy]);
+            }
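+            // Multiply the grid part back by 8 (see note above) and broadcast each
+            // 128-bit half of the scales to both lanes: scales[0]/scales[1] cover the
+            // first/second group of four 32-weight blocks.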
+            all_scales = _mm256_mul_ps(_mm256_set1_ps(8.f), all_scales);
+            auto scales_l = _mm256_castps256_ps128(all_scales);
+            auto scales_h = _mm256_extractf128_ps(all_scales, 1);
+            scales[0] = _mm256_set_m128(scales_l, scales_l);
+            scales[1] = _mm256_set_m128(scales_h, scales_h);
+            const uint8_t  * qs = iq1s[ibl].qs;
+            const uint16_t * qh = iq1s[ibl].qh;
+            for (int i128 = 0; i128 < QK_K/128; ++i128) {
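+                // Each weight byte qs[j] is extended with 3 bits from qh to form an
+                // 11-bit index into the 2048-entry iq1s_grid_us table; every entry
+                // packs 8 unsigned ternary values into a uint64.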
+                qx[0] = _mm256_set_epi64x(iq1s_grid_us[qs[3] | ((qh[0] >> 1) & 0x700)], iq1s_grid_us[qs[2] | ((qh[0] << 2) & 0x700)],
+                                          iq1s_grid_us[qs[1] | ((qh[0] << 5) & 0x700)], iq1s_grid_us[qs[0] | ((qh[0] << 8) & 0x700)]);
+                qx[1] = _mm256_set_epi64x(iq1s_grid_us[qs[7] | ((qh[1] >> 1) & 0x700)], iq1s_grid_us[qs[6] | ((qh[1] << 2) & 0x700)],
+                                          iq1s_grid_us[qs[5] | ((qh[1] << 5) & 0x700)], iq1s_grid_us[qs[4] | ((qh[1] << 8) & 0x700)]);
+                qs += 8;
+                qx[2] = _mm256_set_epi64x(iq1s_grid_us[qs[3] | ((qh[2] >> 1) & 0x700)], iq1s_grid_us[qs[2] | ((qh[2] << 2) & 0x700)],
+                                          iq1s_grid_us[qs[1] | ((qh[2] << 5) & 0x700)], iq1s_grid_us[qs[0] | ((qh[2] << 8) & 0x700)]);
+                qx[3] = _mm256_set_epi64x(iq1s_grid_us[qs[7] | ((qh[3] >> 1) & 0x700)], iq1s_grid_us[qs[6] | ((qh[3] << 2) & 0x700)],
+                                          iq1s_grid_us[qs[5] | ((qh[3] << 5) & 0x700)], iq1s_grid_us[qs[4] | ((qh[3] << 8) & 0x700)]);
+                qs += 8; qh += 4;
+                for (int iy = 0; iy < nrc_y; ++iy) {
+                    auto& ybl = q8.y[iy][2*ibl+i128];
+                    auto sumi1 = _mm256_maddubs_epi16(qx[0], _mm256_loadu_si256((const __m256i *)ybl.qs+0));
+                    auto sumi2 = _mm256_maddubs_epi16(qx[1], _mm256_loadu_si256((const __m256i *)ybl.qs+1));
+                    auto sumi3 = _mm256_maddubs_epi16(qx[2], _mm256_loadu_si256((const __m256i *)ybl.qs+2));
+                    auto sumi4 = _mm256_maddubs_epi16(qx[3], _mm256_loadu_si256((const __m256i *)ybl.qs+3));
+                    // 0,0,1,1, 0,0,1,1, 0,0,1,1, 0,0,1,1 as int16_t
+                    sumi1 = _mm256_add_epi16(_mm256_unpacklo_epi32(sumi1, sumi2), _mm256_unpackhi_epi32(sumi1, sumi2));
+                    // 2,2,3,3, 2,2,3,3, 2,2,3,3, 2,2,3,3 as int16_t
+                    sumi3 = _mm256_add_epi16(_mm256_unpacklo_epi32(sumi3, sumi4), _mm256_unpackhi_epi32(sumi3, sumi4));
+                    sumi1 = _mm256_add_epi16(_mm256_unpacklo_epi64(sumi1, sumi3), _mm256_unpackhi_epi64(sumi1, sumi3));
+                    // 0, 1, 2, 3, 0, 1, 2, 3 as int32_t
+                    sumi1 = _mm256_madd_epi16(_mm256_set1_epi16(1), sumi1);
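+                    // ybl.d[0..3]: the four block scales of the q8_2_x4 block, again
+                    // stored as the upper 16 bits of the fp32 value.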
+                    auto d4 = _mm_castsi128_ps(_mm_slli_epi32(_mm_cvtepu16_epi32(_mm_loadl_epi64((const __m128i *)ybl.d)), 16));
+                    auto dy = _mm256_set_m128(d4, d4);
+                    acc[iy] = _mm256_fmadd_ps(_mm256_mul_ps(scales[i128], dy), _mm256_cvtepi32_ps(sumi1), acc[iy]);
+                }
+            }
+        }
+        for (int iy = 0; iy < nrc_y; ++iy) {
+            info.store(ix, iy, 0.125f*hsum_float_8(acc[iy]));
+            acc[iy] = _mm256_setzero_ps();
+        }
+    }
+}
+
 template <int nrc_y>
 static void mul_mat_iq1_s_r4_q8_1(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
     GGML_ASSERT(nrc_x%4 == 0);
@@ -1533,23 +1607,79 @@ static void mul_mat_iq2_bn_r4_q8_k16(int n, const void * vx, size_t bx, const Da
 }
 #endif
 
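+// Repack IQ1_S rows into q8_0_r8: dequantize eight rows at a time to 8-bit values
+// with one fp16 scale per 32-weight block, interleaved in the row-of-8 layout the
+// q8_0_r8 GEMM kernels expect. Weights become 8*g +/- 1 with the missing 1/8
+// folded into the stored scale (the 0.125f factor below).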
+void iqk_convert_iq1_s_q8_0_r8(int n, const void * vx, size_t bx, void * vy, int nrc_x) {
+    GGML_ASSERT(n%QK_K == 0);
+    GGML_ASSERT(nrc_x%8 == 0);
+
+    int nb = n/QK_K;
+
+    const block_iq1_s * x8[8];
+
+    block_q8_0_r8 * y = (block_q8_0_r8 *)vy;
+
+    ggml_half dh[8];
+    uint16_t  all_ls[64];
+
+    uint32_t block[8];
+
+    for (int ix = 0; ix < nrc_x; ix += 8) {
+        for (int k = 0; k < 8; ++k) x8[k] = (const block_iq1_s *)((const char *)vx + (ix + k)*bx);
+        for (int i = 0; i < nb; ++i) {
+            for (int k = 0; k < 8; ++k) {
+                dh[k] = x8[k][i].d;
+                auto qs = x8[k][i].qs;
+                auto qh = x8[k][i].qh;
+                __m256i value;
+                for (int ib32 = 0; ib32 < 8; ++ib32) {
+                    all_ls[8*ib32 + k] = 2*((qh[ib32] >> 12) & 7) + 1;
+                    value = _mm256_set_epi64x(iq1s_grid[qs[3] | ((qh[ib32] >> 1) & 0x700)], iq1s_grid[qs[2] | ((qh[ib32] << 2) & 0x700)],
+                                              iq1s_grid[qs[1] | ((qh[ib32] << 5) & 0x700)], iq1s_grid[qs[0] | ((qh[ib32] << 8) & 0x700)]);
+                    value = _mm256_slli_epi16(_mm256_add_epi8(value, _mm256_set1_epi8(1)), 3);
+                    int8_t delta = qh[ib32] & 0x8000 ? -9 : -7;
+                    value = _mm256_add_epi8(value, _mm256_set1_epi8(delta));
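+                    // (g+1)<<3 followed by -7 or -9 yields 8*g + 1 or 8*g - 1, i.e.
+                    // 8*(g +/- 1/8); the 16-bit-lane shift is safe because
+                    // (g+1)<<3 <= 16 never carries across a byte boundary.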
+                    _mm256_storeu_si256((__m256i *)block, value);
+                    auto q8 = (uint32_t *)y[ib32].qs;
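+                    // Scatter row k into the interleaved layout: each group of eight
+                    // uint32 holds one 4-byte chunk from each of the 8 rows.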
+                    for (int l = 0; l < 4; ++l) {
+                        q8[8*l + k +  0] = block[l + 0];
+                        q8[8*l + k + 32] = block[l + 4];
+                    }
+                    qs += 4;
+                }
+            }
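+            // Store the per-block scales: d * (2*s + 1) / 8, converted to fp16.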
+            auto vd = _mm256_mul_ps(_mm256_set1_ps(0.125f), _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)dh)));
+            for (int ib32 = 0; ib32 < QK_K/32; ++ib32) {
+                auto iscales16 = _mm_loadu_si128((const __m128i *)all_ls + ib32);
+                auto iscales32 = _mm256_cvtepi16_epi32(iscales16);
+                auto scales    = _mm256_mul_ps(vd, _mm256_cvtepi32_ps(iscales32));
+                _mm_storeu_si128((__m128i *)y[ib32].d, _mm256_cvtps_ph(scales, _MM_FROUND_TO_NEAREST_INT));
+            }
+            y += QK_K/32;
+        }
+    }
+}
 
 } // namespace
 
 bool iqk_set_kernels_1bit(int ne00, int typeA, int typeB, std::array<mul_mat_t, IQK_MAX_NY>& funcs, mul_mat_t& func16) {
 
     auto expected_typeB = GGML_TYPE_Q8_K128;
+    auto actual_typeB = ggml_type(typeB);
 
     func16 = nullptr;
 
     switch (typeA) {
         case GGML_TYPE_IQ1_S:
             if (ne00%QK_K != 0) return false;
-            IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_iq1_s_q8_K, funcs);
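+            // Route to the new q8_2_x4 kernel when the activations come in that
+            // format; otherwise keep the existing q8_K path.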
+            if (actual_typeB == GGML_TYPE_Q8_2_X4) {
+                IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_iq1_s_q8_2_x4, funcs);
+                expected_typeB = GGML_TYPE_Q8_2_X4;
+            } else {
+                IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_iq1_s_q8_K, funcs);
 #ifdef HAVE_FANCY_SIMD
-            func16 = mul_mat_iq1_s_q8_K<16>;
+                func16 = mul_mat_iq1_s_q8_K<16>;
 #endif
-            expected_typeB = GGML_TYPE_Q8_K;
+                expected_typeB = GGML_TYPE_Q8_K;
+            }
             break;
         case GGML_TYPE_IQ1_S_R4:
             if (ne00%128 != 0) return false;
@@ -1585,8 +1715,17 @@ bool iqk_set_kernels_1bit(int ne00, int typeA, int typeB, std::array<mul_mat_t,
         return false;
     }
 
-    return ggml_type(typeB) == expected_typeB;
+    return actual_typeB == expected_typeB;
+
+}
 
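+// Repack a 1-bit quantized matrix to q8_0_r8; currently only IQ1_S is supported.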
+bool iqk_convert_1bit_q80_r8(int type, int n, const void * vx, size_t bx, void * vy, int nrc_x) {
+    if (n%QK_K != 0 || nrc_x%8 != 0) return false;
+    switch (ggml_type(type)) {
+        case GGML_TYPE_IQ1_S: iqk_convert_iq1_s_q8_0_r8(n, vx, bx, vy, nrc_x); break;
+        default: return false;
+    }
+    return true;
 }
 
 #else
@@ -2277,6 +2416,10 @@ bool iqk_set_kernels_1bit(int ne00, int typeA, int typeB, std::array<mul_mat_t,
 
 }
 
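+// Fallback when the SIMD implementations above are not compiled in: no repacking available.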
+bool iqk_convert_1bit_q80_r8([[maybe_unused]] int type, [[maybe_unused]] int n, [[maybe_unused]] const void * vx,
+                             [[maybe_unused]] size_t bx, [[maybe_unused]] void * vy, [[maybe_unused]] int nrc_x) {
+    return false;
+}
+
 #endif
 
 #endif