@@ -145,35 +145,6 @@ struct SignHelper {
145145 const __m256i mone = _mm256_set1_epi8(1 );
146146};
147147
148- // for (int i = 0; i < nb; ++i) {
149- //
150- // __m256i sumi[nrc_y], all_scales;
151- // //for (int iy = 0; iy < nrc_y; ++iy) sumi[iy] = _mm256_setzero_si256();
152- // __m256i mins;
153- // float dmin = deq.new_block(i, &all_scales, mins);
154- // for (int iy = 0; iy < nrc_y; ++iy) {
155- // auto bsums = q8.load_bsums(iy, i);
156- // auto prod = _mm256_madd_epi16(mins, bsums);
157- // accd[iy] = _mm256_fmadd_ps(_mm256_set1_ps(dmin*q8.scale(iy, i)), _mm256_cvtepi32_ps(prod), accd[iy]);
158- // }
159- //
160- // for (int j = 0; j < QK_K/128; ++j) {
161- // deq.prepare(i, j);
162- // set_scales_8(&all_scales, j, scales);
163- // //multiply_add_iq(deq.bits, scales, j, i, q8, sumi);
164- // multiply_add(deq.bits, scales, j, i, q8, sumi);
165- // }
166- // for (int iy = 0; iy < nrc_y; ++iy) {
167- // const __m256 vd = _mm256_set1_ps(deq.d*q8.scale(iy, i));
168- // accd[iy] = _mm256_fmadd_ps(vd, _mm256_cvtepi32_ps(sumi[iy]), accd[iy]);
169- // }
170- // }
171- //
172- // for (int iy = 0; iy < nrc_y; ++iy) {
173- // info.store(ix, iy, hsum_float_8(accd[iy]));
174- // }
175- // }
176-
177148struct DequantizerIQ2XXS final : public BaseDequantizer<block_iq2_xxs> {
178149 DequantizerIQ2XXS (const void * vx, size_t bx) : BaseDequantizer(vx, bx) {}
179150
@@ -221,7 +192,7 @@ struct DequantizerIQ2XXS final : public BaseDequantizer<block_iq2_xxs> {
221192 }
222193
223194 IQK_ALWAYS_INLINE void sign_values (const uint32_t * aux32, __m256i * values) const {
224- #if defined HAVE_FANCY_SIMD && defined __AVX512VPOPCNTDQ__
195+ #if defined z_HAVE_FANCY_SIMD && defined __AVX512VPOPCNTDQ__
225196 esh.sign_2_values (MM256_SET_M128I (_mm_set1_epi32 (aux32[3 ]), _mm_set1_epi32 (aux32[1 ])), values+0 );
226197 esh.sign_2_values (MM256_SET_M128I (_mm_set1_epi32 (aux32[7 ]), _mm_set1_epi32 (aux32[5 ])), values+2 );
227198#else
@@ -246,7 +217,11 @@ struct DequantizerIQ2XXS final : public BaseDequantizer<block_iq2_xxs> {
246217 }
247218 inline void prepare (int i, int j, const Q8<1 >& q8, __m256i * q8_quants) {
248219 for (int k = 0 ; k < 4 ; ++k) q8_quants[k] = q8.load_quants (0 , i, 4 *j+k);
249- Data data; data.vec = _mm256_loadu_si256 ((const __m256i *)x[i].qs + j);
220+ Data data; data.vec = _mm256_loadu_si256 ((const __m256i *)x[i].qs + j);
221+ make4 (data.val , bits.values , q8_quants);
222+ }
223+ inline void prepare (int i, int j, __m256i * q8_quants) {
224+ Data data; data.vec = _mm256_loadu_si256 ((const __m256i *)x[i].qs + j);
250225 make4 (data.val , bits.values , q8_quants);
251226 }
252227
@@ -526,6 +501,13 @@ struct DequantizerIQ3XXS final : public BaseDequantizer<block_iq3_xxs> {
526501 sign_2_values (signs+0 , q8_quants+0 );
527502 sign_2_values (signs+4 , q8_quants+2 );
528503 }
504+ inline void prepare (int i, int j, __m256i * q8_quants) {
505+ auto qs = x[i].qs + 32 *j;
506+ const uint16_t * signs = (const uint16_t *)(x[i].qs + QK_K/4 ) + 8 *j;
507+ make4_unsigned (qs, bits.values );
508+ sign_2_values (signs+0 , q8_quants+0 );
509+ sign_2_values (signs+4 , q8_quants+2 );
510+ }
529511
530512 constexpr static int minv = 64 ;
531513
@@ -625,6 +607,10 @@ struct DequantizerIQ3S final : public BaseDequantizer<block_iq3_s> {
625607 for (int k = 0 ; k < 4 ; ++k) q8_quants[k] = q8.load_quants (0 , i, 4 *j+k);
626608 sh.sign_4_values ((const uint16_t *)x[i].signs + 8 *j, q8_quants);
627609 }
610+ inline void prepare (int i, int j, __m256i * q8_quants) {
611+ prepare_unsigned (i, j);
612+ sh.sign_4_values ((const uint16_t *)x[i].signs + 8 *j, q8_quants);
613+ }
628614
629615 inline void prepare_unsigned (int i, int j) {
630616 auto qs = x[i].qs + 32 *j;
@@ -787,15 +773,69 @@ static void mul_mat_qX_K_q8_K_IQ_N(int n, const void * vx, size_t bx, const Data
787773 }
788774}
789775
790- template <typename Dequantizer, int nrc_y>
776+ template <int n_sum>
777+ inline __m256i compute_dot_4 (const __m256i * x, const __m256i * y) {
778+ #ifdef HAVE_FANCY_SIMD
779+ auto sumi0 = _mm256_dpbusd_epi32 (_mm256_setzero_si256 (), x[0 ], y[0 ]);
780+ auto sumi1 = _mm256_dpbusd_epi32 (_mm256_setzero_si256 (), x[1 ], y[1 ]);
781+ auto sumi2 = _mm256_dpbusd_epi32 (_mm256_setzero_si256 (), x[2 ], y[2 ]);
782+ auto sumi3 = _mm256_dpbusd_epi32 (_mm256_setzero_si256 (), x[3 ], y[3 ]);
783+ sumi0 = _mm256_add_epi32 (_mm256_unpacklo_epi32 (sumi0, sumi1), _mm256_unpackhi_epi32 (sumi0, sumi1));
784+ sumi2 = _mm256_add_epi32 (_mm256_unpacklo_epi32 (sumi2, sumi3), _mm256_unpackhi_epi32 (sumi2, sumi3));
785+ return _mm256_add_epi32 (_mm256_unpacklo_epi64 (sumi0, sumi2), _mm256_unpackhi_epi64 (sumi0, sumi2));
786+ #else
787+ auto m1 = _mm256_set1_epi16 (1 );
788+ if constexpr (n_sum == 2 ) {
789+ auto sumi0 = _mm256_madd_epi16 (m1, _mm256_maddubs_epi16 (x[0 ], y[0 ]));
790+ auto sumi1 = _mm256_madd_epi16 (m1, _mm256_maddubs_epi16 (x[1 ], y[1 ]));
791+ auto sumi2 = _mm256_madd_epi16 (m1, _mm256_maddubs_epi16 (x[2 ], y[2 ]));
792+ auto sumi3 = _mm256_madd_epi16 (m1, _mm256_maddubs_epi16 (x[3 ], y[3 ]));
793+ sumi0 = _mm256_add_epi32 (_mm256_unpacklo_epi32 (sumi0, sumi1), _mm256_unpackhi_epi32 (sumi0, sumi1));
794+ sumi2 = _mm256_add_epi32 (_mm256_unpacklo_epi32 (sumi2, sumi3), _mm256_unpackhi_epi32 (sumi2, sumi3));
795+ return _mm256_add_epi32 (_mm256_unpacklo_epi64 (sumi0, sumi2), _mm256_unpackhi_epi64 (sumi0, sumi2));
796+ }
797+ else {
798+ auto sumi0 = _mm256_maddubs_epi16 (x[0 ], y[0 ]);
799+ auto sumi1 = _mm256_maddubs_epi16 (x[1 ], y[1 ]);
800+ auto sumi2 = _mm256_maddubs_epi16 (x[2 ], y[2 ]);
801+ auto sumi3 = _mm256_maddubs_epi16 (x[3 ], y[3 ]);
802+ if constexpr (n_sum == 4 ) {
803+ sumi0 = _mm256_add_epi16 (_mm256_unpacklo_epi32 (sumi0, sumi1), _mm256_unpackhi_epi32 (sumi0, sumi1));
804+ sumi2 = _mm256_add_epi16 (_mm256_unpacklo_epi32 (sumi2, sumi3), _mm256_unpackhi_epi32 (sumi2, sumi3));
805+ sumi0 = _mm256_madd_epi16 (m1, sumi0);
806+ sumi2 = _mm256_madd_epi16 (m1, sumi2);
807+ return _mm256_add_epi32 (_mm256_unpacklo_epi64 (sumi0, sumi2), _mm256_unpackhi_epi64 (sumi0, sumi2));
808+ }
809+ else {
810+ auto sumi0 = _mm256_maddubs_epi16 (x[0 ], y[0 ]);
811+ auto sumi1 = _mm256_maddubs_epi16 (x[1 ], y[1 ]);
812+ auto sumi2 = _mm256_maddubs_epi16 (x[2 ], y[2 ]);
813+ auto sumi3 = _mm256_maddubs_epi16 (x[3 ], y[3 ]);
814+ sumi0 = _mm256_add_epi16 (_mm256_unpacklo_epi32 (sumi0, sumi1), _mm256_unpackhi_epi32 (sumi0, sumi1));
815+ sumi2 = _mm256_add_epi16 (_mm256_unpacklo_epi32 (sumi2, sumi3), _mm256_unpackhi_epi32 (sumi2, sumi3));
816+ sumi0 = _mm256_add_epi16 (_mm256_unpacklo_epi64 (sumi0, sumi2), _mm256_unpackhi_epi64 (sumi0, sumi2));
817+ return _mm256_madd_epi16 (m1, sumi0);
818+ }
819+ }
820+ #endif
821+ }
822+
823+ template <typename Dequantizer, int nrc_y, int n_sum = 2 >
791824static void mul_mat_qX_K_q8_2_IQ_N (int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
792825 static_assert (Dequantizer::num_blocks == 8 );
826+ static_assert (n_sum == 2 || n_sum == 4 || n_sum == 8 );
827+ #ifdef HAVE_FANCY_SIMD
828+ constexpr bool use_1_row = nrc_y == 1 ;
829+ #else
830+ constexpr bool use_1_row = nrc_y == 1 && !std::is_same_v<Dequantizer, DequantizerIQ2XXS>;
831+ #endif
832+
793833 const int nb = n / QK_K;
794834 Q8<nrc_y, block_q8_2_x4> q8 (info);
795835 Dequantizer deq (vx, bx);
796836 __m256 scales[3 ];
797837 __m256 accd[nrc_y];
798- __m256i sumi [4 ];
838+ __m256i vy [4 ];
799839
800840 for (int ix = 0 ; ix < nrc_x; ++ix) {
801841
@@ -806,35 +846,33 @@ static void mul_mat_qX_K_q8_2_IQ_N(int n, const void * vx, size_t bx, const Data
806846 for (int i = 0 ; i < nb; ++i) {
807847
808848 deq.new_block_f (i, scales);
809- for (int iy = 0 ; iy < nrc_y; ++iy) {
810- auto my1 = _mm_cvtepu16_epi32 (_mm_loadl_epi64 ((const __m128i *)(q8.y [iy][2 *i+0 ].d + 4 )));
811- auto my2 = _mm_cvtepu16_epi32 (_mm_loadl_epi64 ((const __m128i *)(q8.y [iy][2 *i+1 ].d + 4 )));
812- auto my = _mm256_castsi256_ps (_mm256_slli_epi32 (MM256_SET_M128I (my2, my1), 16 ));
813- accd[iy] = _mm256_fmadd_ps (scales[2 ], my, accd[iy]);
849+ if constexpr (!use_1_row) {
850+ for (int iy = 0 ; iy < nrc_y; ++iy) {
851+ auto my1 = _mm_cvtepu16_epi32 (_mm_loadl_epi64 ((const __m128i *)(q8.y [iy][2 *i+0 ].d + 4 )));
852+ auto my2 = _mm_cvtepu16_epi32 (_mm_loadl_epi64 ((const __m128i *)(q8.y [iy][2 *i+1 ].d + 4 )));
853+ auto my = _mm256_castsi256_ps (_mm256_slli_epi32 (MM256_SET_M128I (my2, my1), 16 ));
854+ accd[iy] = _mm256_fmadd_ps (scales[2 ], my, accd[iy]);
855+ }
814856 }
815857
816858 for (int j = 0 ; j < QK_K/128 ; ++j) {
817- deq.prepare (i, j);
818- auto & values = deq.bits .values ;
819- for (int iy = 0 ; iy < nrc_y; ++iy) {
820- auto qs = q8.y [iy][2 *i+j].qs ;
821- #ifdef HAVE_FANCY_SIMD
822- sumi[0 ] = _mm256_dpbusd_epi32 (_mm256_setzero_si256 (), values[0 ], _mm256_loadu_si256 ((const __m256i*)qs+0 ));
823- sumi[1 ] = _mm256_dpbusd_epi32 (_mm256_setzero_si256 (), values[1 ], _mm256_loadu_si256 ((const __m256i*)qs+1 ));
824- sumi[2 ] = _mm256_dpbusd_epi32 (_mm256_setzero_si256 (), values[2 ], _mm256_loadu_si256 ((const __m256i*)qs+2 ));
825- sumi[3 ] = _mm256_dpbusd_epi32 (_mm256_setzero_si256 (), values[3 ], _mm256_loadu_si256 ((const __m256i*)qs+3 ));
826- #else
827- sumi[0 ] = _mm256_madd_epi16 (_mm256_set1_epi16 (1 ), _mm256_maddubs_epi16 (values[0 ], _mm256_loadu_si256 ((const __m256i*)qs+0 )));
828- sumi[1 ] = _mm256_madd_epi16 (_mm256_set1_epi16 (1 ), _mm256_maddubs_epi16 (values[1 ], _mm256_loadu_si256 ((const __m256i*)qs+1 )));
829- sumi[2 ] = _mm256_madd_epi16 (_mm256_set1_epi16 (1 ), _mm256_maddubs_epi16 (values[2 ], _mm256_loadu_si256 ((const __m256i*)qs+2 )));
830- sumi[3 ] = _mm256_madd_epi16 (_mm256_set1_epi16 (1 ), _mm256_maddubs_epi16 (values[3 ], _mm256_loadu_si256 ((const __m256i*)qs+3 )));
831- #endif
832- sumi[0 ] = _mm256_add_epi32 (_mm256_unpacklo_epi32 (sumi[0 ], sumi[1 ]), _mm256_unpackhi_epi32 (sumi[0 ], sumi[1 ]));
833- sumi[2 ] = _mm256_add_epi32 (_mm256_unpacklo_epi32 (sumi[2 ], sumi[3 ]), _mm256_unpackhi_epi32 (sumi[2 ], sumi[3 ]));
834- sumi[0 ] = _mm256_add_epi32 (_mm256_unpacklo_epi64 (sumi[0 ], sumi[2 ]), _mm256_unpackhi_epi64 (sumi[0 ], sumi[2 ]));
835- auto d4 = _mm_castsi128_ps (_mm_slli_epi32 (_mm_cvtepu16_epi32 (_mm_loadl_epi64 ((const __m128i *)q8.y [iy][2 *i+j].d )), 16 ));
859+ if constexpr (use_1_row) {
860+ for (int k = 0 ; k < 4 ; ++k) vy[k] = _mm256_loadu_si256 ((const __m256i*)q8.y [0 ][2 *i+j].qs +k);
861+ deq.prepare (i, j, vy);
862+ auto sumi = compute_dot_4<2 *n_sum>(deq.bits .values , vy);
863+ auto d4 = _mm_castsi128_ps (_mm_slli_epi32 (_mm_cvtepu16_epi32 (_mm_loadl_epi64 ((const __m128i *)q8.y [0 ][2 *i+j].d )), 16 ));
836864 auto dy = _mm256_set_m128 (d4, d4);
837- accd[iy] = _mm256_fmadd_ps (_mm256_mul_ps (scales[j], dy), _mm256_cvtepi32_ps (sumi[0 ]), accd[iy]);
865+ accd[0 ] = _mm256_fmadd_ps (_mm256_mul_ps (scales[j], dy), _mm256_cvtepi32_ps (sumi), accd[0 ]);
866+ } else {
867+ deq.prepare (i, j);
868+ for (int iy = 0 ; iy < nrc_y; ++iy) {
869+ auto qs = q8.y [iy][2 *i+j].qs ;
870+ for (int k = 0 ; k < 4 ; ++k) vy[k] = _mm256_loadu_si256 ((const __m256i*)qs+k);
871+ auto sumi = compute_dot_4<n_sum>(deq.bits .values , vy);
872+ auto d4 = _mm_castsi128_ps (_mm_slli_epi32 (_mm_cvtepu16_epi32 (_mm_loadl_epi64 ((const __m128i *)q8.y [iy][2 *i+j].d )), 16 ));
873+ auto dy = _mm256_set_m128 (d4, d4);
874+ accd[iy] = _mm256_fmadd_ps (_mm256_mul_ps (scales[j], dy), _mm256_cvtepi32_ps (sumi), accd[iy]);
875+ }
838876 }
839877 }
840878 }
@@ -1934,7 +1972,15 @@ bool iqk_set_kernels_iquants(int ne00, int typeA, int typeB, std::array<mul_mat_
19341972
19351973 if (ggml_type (typeA) == GGML_TYPE_IQ3_S) {
19361974 if (ggml_type (typeB) == GGML_TYPE_Q8_2_X4) {
1937- IQK_SET_MUL_MAT_FUNCTIONS_T (mul_mat_qX_K_q8_2_IQ_N, DequantizerIQ3S, kernels);
1975+ // IQK_SET_MUL_MAT_FUNCTIONS_T(mul_mat_qX_K_q8_2_IQ_N, DequantizerIQ3S, kernels);
1976+ kernels[0 ] = mul_mat_qX_K_q8_2_IQ_N<DequantizerIQ3S, 1 , 8 >;
1977+ kernels[1 ] = mul_mat_qX_K_q8_2_IQ_N<DequantizerIQ3S, 2 , 8 >;
1978+ kernels[2 ] = mul_mat_qX_K_q8_2_IQ_N<DequantizerIQ3S, 3 , 8 >;
1979+ kernels[3 ] = mul_mat_qX_K_q8_2_IQ_N<DequantizerIQ3S, 4 , 8 >;
1980+ kernels[4 ] = mul_mat_qX_K_q8_2_IQ_N<DequantizerIQ3S, 5 , 8 >;
1981+ kernels[5 ] = mul_mat_qX_K_q8_2_IQ_N<DequantizerIQ3S, 6 , 8 >;
1982+ kernels[6 ] = mul_mat_qX_K_q8_2_IQ_N<DequantizerIQ3S, 7 , 8 >;
1983+ kernels[7 ] = mul_mat_qX_K_q8_2_IQ_N<DequantizerIQ3S, 8 , 8 >;
19381984 func16 = nullptr ;
19391985 return true ;
19401986 }
0 commit comments