@@ -865,6 +865,68 @@ void mul_mat_iq1_s_q8_K(int n, const void * vx, size_t bx, const DataInfo& info,
865865 }
866866}
867867
868+ template <int nrc_y>
869+ void mul_mat_iq1_m_q8_K (int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
870+ GGML_ASSERT (n%QK_K == 0 );
871+ Q8<nrc_y, block_q8_K> q8 (info);
872+ __m256i qx[8 ];
873+ __m256 acc[nrc_y] = {};
874+ auto scale_shuffle = _mm256_set_epi64x (0x0706070607060706 , 0x0504050405040504 , 0x0302030203020302 , 0x0100010001000100 );
875+ auto delta_mask = _mm256_set_epi64x (0x8000 , 0x0800 , 0x0080 , 0x0008 );
876+ iq1m_scale_t scale;
877+ union { __m256i vec; int16_t val[16 ]; } helper;
878+ for (int ix = 0 ; ix < nrc_x; ++ix) {
879+ auto iq1m = (const block_iq1_m *)((const char *)vx + ix*bx);
880+ for (int ibl = 0 ; ibl < n/QK_K; ++ibl) {
881+ const uint16_t * sc = (const uint16_t *)iq1m[ibl].scales ; // 4 x uint16_t, each containing 4 scales
882+ scale.u16 = (sc[0 ] >> 12 ) | ((sc[1 ] >> 8 ) & 0x00f0 ) | ((sc[2 ] >> 4 ) & 0x0f00 ) | (sc[3 ] & 0xf000 );
883+ float d = GGML_FP16_TO_FP32 (scale.f16 );
884+ auto qs = iq1m[ibl].qs ;
885+ auto qh = iq1m[ibl].qh ;
886+ auto aux = _mm_loadl_epi64 ((const __m128i *)iq1m[ibl].scales );
887+ auto sc16 = _mm256_shuffle_epi8 (MM256_SET_M128I (aux, aux), scale_shuffle);
888+ sc16 = _mm256_and_si256 (sc16, _mm256_set1_epi64x (0x0e0001c000380007 ));
889+ sc16 = _mm256_mullo_epi16 (sc16, _mm256_set1_epi64x (0x0001000800400200 ));
890+ helper.vec = _mm256_add_epi8 (_mm256_srli_epi16 (sc16, 8 ), _mm256_set1_epi16 (1 ));
891+ for (int ib64 = 0 ; ib64 < QK_K/64 ; ++ib64) {
892+ qx[2 *ib64+0 ] = _mm256_set_epi64x (iq1s_grid_us[qs[3 ] | (((uint16_t )qh[1 ] << 4 ) & 0x700 )], iq1s_grid_us[qs[2 ] | (((uint16_t )qh[1 ] << 8 ) & 0x700 )],
893+ iq1s_grid_us[qs[1 ] | (((uint16_t )qh[0 ] << 4 ) & 0x700 )], iq1s_grid_us[qs[0 ] | (((uint16_t )qh[0 ] << 8 ) & 0x700 )]);
894+ qx[2 *ib64+1 ] = _mm256_set_epi64x (iq1s_grid_us[qs[7 ] | (((uint16_t )qh[3 ] << 4 ) & 0x700 )], iq1s_grid_us[qs[6 ] | (((uint16_t )qh[3 ] << 8 ) & 0x700 )],
895+ iq1s_grid_us[qs[5 ] | (((uint16_t )qh[2 ] << 4 ) & 0x700 )], iq1s_grid_us[qs[4 ] | (((uint16_t )qh[2 ] << 8 ) & 0x700 )]);
896+ // auto delta1 = _mm256_set_epi64x(qh[1] & 0x80 ? 0x0909090909090909 : 0x0707070707070707,
897+ // qh[1] & 0x08 ? 0x0909090909090909 : 0x0707070707070707,
898+ // qh[0] & 0x80 ? 0x0909090909090909 : 0x0707070707070707,
899+ // qh[0] & 0x08 ? 0x0909090909090909 : 0x0707070707070707);
900+ // auto delta2 = _mm256_set_epi64x(qh[3] & 0x80 ? 0x0909090909090909 : 0x0707070707070707,
901+ // qh[3] & 0x08 ? 0x0909090909090909 : 0x0707070707070707,
902+ // qh[2] & 0x80 ? 0x0909090909090909 : 0x0707070707070707,
903+ // qh[2] & 0x08 ? 0x0909090909090909 : 0x0707070707070707);
904+ auto qh16 = (const uint16_t *)qh;
905+ auto delta1 = _mm256_cmpeq_epi64 (_mm256_and_si256 (_mm256_set1_epi64x (qh16[0 ]), delta_mask), delta_mask);
906+ auto delta2 = _mm256_cmpeq_epi64 (_mm256_and_si256 (_mm256_set1_epi64x (qh16[1 ]), delta_mask), delta_mask);
907+ delta1 = _mm256_sub_epi8 (_mm256_set1_epi8 (8 ), _mm256_or_si256 (delta1, _mm256_set1_epi8 (1 )));
908+ delta2 = _mm256_sub_epi8 (_mm256_set1_epi8 (8 ), _mm256_or_si256 (delta2, _mm256_set1_epi8 (1 )));
909+ qx[2 *ib64+0 ] = _mm256_sub_epi8 (_mm256_slli_epi16 (qx[2 *ib64+0 ], 3 ), delta1);
910+ qx[2 *ib64+1 ] = _mm256_sub_epi8 (_mm256_slli_epi16 (qx[2 *ib64+1 ], 3 ), delta2);
911+ qs += 8 ;
912+ qh += 4 ;
913+ }
914+ for (int iy = 0 ; iy < nrc_y; ++iy) {
915+ auto sumi = _mm256_setzero_si256 ();
916+ for (int j = 0 ; j < 8 ; ++j) {
917+ auto p = _mm256_maddubs_epi16 (_mm256_sign_epi8 (qx[j], qx[j]), _mm256_sign_epi8 (q8.load_quants (iy, ibl, j), qx[j]));
918+ sumi = _mm256_add_epi32 (sumi, _mm256_madd_epi16 (p, MM256_SET_M128I (_mm_set1_epi16 (helper.val [2 *j+1 ]), _mm_set1_epi16 (helper.val [2 *j+0 ]))));
919+ }
920+ acc[iy] = _mm256_fmadd_ps (_mm256_set1_ps (d*q8.scale (iy, ibl)), _mm256_cvtepi32_ps (sumi), acc[iy]);
921+ }
922+ }
923+ for (int iy = 0 ; iy < nrc_y; ++iy) {
924+ info.store (ix, iy, 0 .125f *hsum_float_8 (acc[iy]));
925+ acc[iy] = _mm256_setzero_ps ();
926+ }
927+ }
928+ }
929+
868930template <int nrc_y>
869931void mul_mat_iq1_s_q8_2_x4 (int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
870932 GGML_ASSERT (n%QK_K == 0 );
@@ -1844,6 +1906,11 @@ bool iqk_set_kernels_1bit(int ne00, int typeA, int typeB, std::array<mul_mat_t,
18441906 func16 = mul_mat_iq1_s_r4_q8_1<16 >;
18451907#endif
18461908 break ;
1909+ case GGML_TYPE_IQ1_M:
1910+ if (ne00%QK_K != 0 ) return false ;
1911+ IQK_SET_MUL_MAT_FUNCTIONS (mul_mat_iq1_m_q8_K, funcs);
1912+ expected_typeB = GGML_TYPE_Q8_K;
1913+ break ;
18471914 case GGML_TYPE_IQ1_M_R4:
18481915 if (ne00%128 != 0 ) return false ;
18491916 IQK_SET_MUL_MAT_FUNCTIONS (mul_mat_iq1_m_r4_q8_0, funcs);
0 commit comments