
Commit 7a882f0

ikawrakow and Iwan Kawrakow authored
Perhaps a slightly better version for IQ2_XXS, IQ3_XXS, IQ3_S GEMV (#524)
Co-authored-by: Iwan Kawrakow <[email protected]>
1 parent: b57bd86 · commit: 7a882f0

File tree: 1 file changed (+105, −59 lines)


ggml/src/iqk/iqk_gemm_iquants.cpp

Lines changed: 105 additions & 59 deletions
@@ -145,35 +145,6 @@ struct SignHelper {
     const __m256i mone = _mm256_set1_epi8(1);
 };
 
-// for (int i = 0; i < nb; ++i) {
-//
-//     __m256i sumi[nrc_y], all_scales;
-//     //for (int iy = 0; iy < nrc_y; ++iy) sumi[iy] = _mm256_setzero_si256();
-//     __m256i mins;
-//     float dmin = deq.new_block(i, &all_scales, mins);
-//     for (int iy = 0; iy < nrc_y; ++iy) {
-//         auto bsums = q8.load_bsums(iy, i);
-//         auto prod = _mm256_madd_epi16(mins, bsums);
-//         accd[iy] = _mm256_fmadd_ps(_mm256_set1_ps(dmin*q8.scale(iy, i)), _mm256_cvtepi32_ps(prod), accd[iy]);
-//     }
-//
-//     for (int j = 0; j < QK_K/128; ++j) {
-//         deq.prepare(i, j);
-//         set_scales_8(&all_scales, j, scales);
-//         //multiply_add_iq(deq.bits, scales, j, i, q8, sumi);
-//         multiply_add(deq.bits, scales, j, i, q8, sumi);
-//     }
-//     for (int iy = 0; iy < nrc_y; ++iy) {
-//         const __m256 vd = _mm256_set1_ps(deq.d*q8.scale(iy, i));
-//         accd[iy] = _mm256_fmadd_ps(vd, _mm256_cvtepi32_ps(sumi[iy]), accd[iy]);
-//     }
-// }
-//
-// for (int iy = 0; iy < nrc_y; ++iy) {
-//     info.store(ix, iy, hsum_float_8(accd[iy]));
-//     }
-// }
-
 struct DequantizerIQ2XXS final : public BaseDequantizer<block_iq2_xxs> {
     DequantizerIQ2XXS(const void * vx, size_t bx) : BaseDequantizer(vx, bx) {}
 
@@ -221,7 +192,7 @@ struct DequantizerIQ2XXS final : public BaseDequantizer<block_iq2_xxs> {
     }
 
     IQK_ALWAYS_INLINE void sign_values(const uint32_t * aux32, __m256i * values) const {
-#if defined HAVE_FANCY_SIMD && defined __AVX512VPOPCNTDQ__
+#if defined z_HAVE_FANCY_SIMD && defined __AVX512VPOPCNTDQ__
         esh.sign_2_values(MM256_SET_M128I(_mm_set1_epi32(aux32[3]), _mm_set1_epi32(aux32[1])), values+0);
         esh.sign_2_values(MM256_SET_M128I(_mm_set1_epi32(aux32[7]), _mm_set1_epi32(aux32[5])), values+2);
 #else
@@ -246,7 +217,11 @@ struct DequantizerIQ2XXS final : public BaseDequantizer<block_iq2_xxs> {
     }
     inline void prepare(int i, int j, const Q8<1>& q8, __m256i * q8_quants) {
         for (int k = 0; k < 4; ++k) q8_quants[k] = q8.load_quants(0, i, 4*j+k);
-        Data data; data.vec = _mm256_loadu_si256((const __m256i *)x[i].qs + j);
+        Data data; data.vec = _mm256_loadu_si256((const __m256i *)x[i].qs + j);
+        make4(data.val, bits.values, q8_quants);
+    }
+    inline void prepare(int i, int j, __m256i * q8_quants) {
+        Data data; data.vec = _mm256_loadu_si256((const __m256i *)x[i].qs + j);
         make4(data.val, bits.values, q8_quants);
     }
 
@@ -526,6 +501,13 @@ struct DequantizerIQ3XXS final : public BaseDequantizer<block_iq3_xxs> {
         sign_2_values(signs+0, q8_quants+0);
         sign_2_values(signs+4, q8_quants+2);
     }
+    inline void prepare(int i, int j, __m256i * q8_quants) {
+        auto qs = x[i].qs + 32*j;
+        const uint16_t * signs = (const uint16_t *)(x[i].qs + QK_K/4) + 8*j;
+        make4_unsigned(qs, bits.values);
+        sign_2_values(signs+0, q8_quants+0);
+        sign_2_values(signs+4, q8_quants+2);
+    }
 
     constexpr static int minv = 64;
 
@@ -625,6 +607,10 @@ struct DequantizerIQ3S final : public BaseDequantizer<block_iq3_s> {
         for (int k = 0; k < 4; ++k) q8_quants[k] = q8.load_quants(0, i, 4*j+k);
         sh.sign_4_values((const uint16_t *)x[i].signs + 8*j, q8_quants);
     }
+    inline void prepare(int i, int j, __m256i * q8_quants) {
+        prepare_unsigned(i, j);
+        sh.sign_4_values((const uint16_t *)x[i].signs + 8*j, q8_quants);
+    }
 
     inline void prepare_unsigned(int i, int j) {
         auto qs = x[i].qs + 32*j;
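
The new single-argument prepare(int i, int j, __m256i * q8_quants) overloads added above for IQ2_XXS, IQ3_XXS and IQ3_S apply the stored block signs to the q8 activations handed in by the caller, so bits.values stays non-negative and can be the first operand of _mm256_maddubs_epi16 / _mm256_dpbusd_epi32, both of which treat that operand as unsigned bytes. A minimal scalar sketch of the identity this relies on (abs_x, sign_of_x and signed_dot are illustrative names, not from the source):

```cpp
#include <cstdint>

// Folding the sign of x into y leaves the dot product unchanged:
//   sum_i x_i * y_i  ==  sum_i |x_i| * (sign(x_i) * y_i)
// so the unsigned magnitudes |x_i| can feed the u8 side of dpbusd/maddubs.
int32_t signed_dot(const uint8_t* abs_x, const int8_t* sign_of_x, const int8_t* y, int n) {
    int32_t acc = 0;
    for (int i = 0; i < n; ++i)
        acc += int32_t(abs_x[i]) * int32_t(sign_of_x[i] * y[i]);  // sign_of_x[i] is +1 or -1
    return acc;
}
```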
@@ -787,15 +773,69 @@ static void mul_mat_qX_K_q8_K_IQ_N(int n, const void * vx, size_t bx, const Data
     }
 }
 
-template <typename Dequantizer, int nrc_y>
+template <int n_sum>
+inline __m256i compute_dot_4(const __m256i * x, const __m256i * y) {
+#ifdef HAVE_FANCY_SIMD
+    auto sumi0 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), x[0], y[0]);
+    auto sumi1 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), x[1], y[1]);
+    auto sumi2 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), x[2], y[2]);
+    auto sumi3 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), x[3], y[3]);
+    sumi0 = _mm256_add_epi32(_mm256_unpacklo_epi32(sumi0, sumi1), _mm256_unpackhi_epi32(sumi0, sumi1));
+    sumi2 = _mm256_add_epi32(_mm256_unpacklo_epi32(sumi2, sumi3), _mm256_unpackhi_epi32(sumi2, sumi3));
+    return _mm256_add_epi32(_mm256_unpacklo_epi64(sumi0, sumi2), _mm256_unpackhi_epi64(sumi0, sumi2));
+#else
+    auto m1 = _mm256_set1_epi16(1);
+    if constexpr (n_sum == 2) {
+        auto sumi0 = _mm256_madd_epi16(m1, _mm256_maddubs_epi16(x[0], y[0]));
+        auto sumi1 = _mm256_madd_epi16(m1, _mm256_maddubs_epi16(x[1], y[1]));
+        auto sumi2 = _mm256_madd_epi16(m1, _mm256_maddubs_epi16(x[2], y[2]));
+        auto sumi3 = _mm256_madd_epi16(m1, _mm256_maddubs_epi16(x[3], y[3]));
+        sumi0 = _mm256_add_epi32(_mm256_unpacklo_epi32(sumi0, sumi1), _mm256_unpackhi_epi32(sumi0, sumi1));
+        sumi2 = _mm256_add_epi32(_mm256_unpacklo_epi32(sumi2, sumi3), _mm256_unpackhi_epi32(sumi2, sumi3));
+        return _mm256_add_epi32(_mm256_unpacklo_epi64(sumi0, sumi2), _mm256_unpackhi_epi64(sumi0, sumi2));
+    }
+    else {
+        auto sumi0 = _mm256_maddubs_epi16(x[0], y[0]);
+        auto sumi1 = _mm256_maddubs_epi16(x[1], y[1]);
+        auto sumi2 = _mm256_maddubs_epi16(x[2], y[2]);
+        auto sumi3 = _mm256_maddubs_epi16(x[3], y[3]);
+        if constexpr (n_sum == 4) {
+            sumi0 = _mm256_add_epi16(_mm256_unpacklo_epi32(sumi0, sumi1), _mm256_unpackhi_epi32(sumi0, sumi1));
+            sumi2 = _mm256_add_epi16(_mm256_unpacklo_epi32(sumi2, sumi3), _mm256_unpackhi_epi32(sumi2, sumi3));
+            sumi0 = _mm256_madd_epi16(m1, sumi0);
+            sumi2 = _mm256_madd_epi16(m1, sumi2);
+            return _mm256_add_epi32(_mm256_unpacklo_epi64(sumi0, sumi2), _mm256_unpackhi_epi64(sumi0, sumi2));
+        }
+        else {
+            auto sumi0 = _mm256_maddubs_epi16(x[0], y[0]);
+            auto sumi1 = _mm256_maddubs_epi16(x[1], y[1]);
+            auto sumi2 = _mm256_maddubs_epi16(x[2], y[2]);
+            auto sumi3 = _mm256_maddubs_epi16(x[3], y[3]);
+            sumi0 = _mm256_add_epi16(_mm256_unpacklo_epi32(sumi0, sumi1), _mm256_unpackhi_epi32(sumi0, sumi1));
+            sumi2 = _mm256_add_epi16(_mm256_unpacklo_epi32(sumi2, sumi3), _mm256_unpackhi_epi32(sumi2, sumi3));
+            sumi0 = _mm256_add_epi16(_mm256_unpacklo_epi64(sumi0, sumi2), _mm256_unpackhi_epi64(sumi0, sumi2));
+            return _mm256_madd_epi16(m1, sumi0);
+        }
+    }
+#endif
+}
+
+template <typename Dequantizer, int nrc_y, int n_sum = 2>
 static void mul_mat_qX_K_q8_2_IQ_N(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
     static_assert(Dequantizer::num_blocks == 8);
+    static_assert(n_sum == 2 || n_sum == 4 || n_sum == 8);
+#ifdef HAVE_FANCY_SIMD
+    constexpr bool use_1_row = nrc_y == 1;
+#else
+    constexpr bool use_1_row = nrc_y == 1 && !std::is_same_v<Dequantizer, DequantizerIQ2XXS>;
+#endif
+
     const int nb = n / QK_K;
     Q8<nrc_y, block_q8_2_x4> q8(info);
     Dequantizer deq(vx, bx);
     __m256 scales[3];
     __m256 accd[nrc_y];
-    __m256i sumi[4];
+    __m256i vy[4];
 
     for (int ix = 0; ix < nrc_x; ++ix) {
 
@@ -806,35 +846,33 @@ static void mul_mat_qX_K_q8_2_IQ_N(int n, const void * vx, size_t bx, const Data
         for (int i = 0; i < nb; ++i) {
 
             deq.new_block_f(i, scales);
-            for (int iy = 0; iy < nrc_y; ++iy) {
-                auto my1 = _mm_cvtepu16_epi32(_mm_loadl_epi64((const __m128i *)(q8.y[iy][2*i+0].d + 4)));
-                auto my2 = _mm_cvtepu16_epi32(_mm_loadl_epi64((const __m128i *)(q8.y[iy][2*i+1].d + 4)));
-                auto my = _mm256_castsi256_ps(_mm256_slli_epi32(MM256_SET_M128I(my2, my1), 16));
-                accd[iy] = _mm256_fmadd_ps(scales[2], my, accd[iy]);
+            if constexpr (!use_1_row) {
+                for (int iy = 0; iy < nrc_y; ++iy) {
+                    auto my1 = _mm_cvtepu16_epi32(_mm_loadl_epi64((const __m128i *)(q8.y[iy][2*i+0].d + 4)));
+                    auto my2 = _mm_cvtepu16_epi32(_mm_loadl_epi64((const __m128i *)(q8.y[iy][2*i+1].d + 4)));
+                    auto my = _mm256_castsi256_ps(_mm256_slli_epi32(MM256_SET_M128I(my2, my1), 16));
+                    accd[iy] = _mm256_fmadd_ps(scales[2], my, accd[iy]);
+                }
             }
 
             for (int j = 0; j < QK_K/128; ++j) {
-                deq.prepare(i, j);
-                auto& values = deq.bits.values;
-                for (int iy = 0; iy < nrc_y; ++iy) {
-                    auto qs = q8.y[iy][2*i+j].qs;
-#ifdef HAVE_FANCY_SIMD
-                    sumi[0] = _mm256_dpbusd_epi32(_mm256_setzero_si256(), values[0], _mm256_loadu_si256((const __m256i*)qs+0));
-                    sumi[1] = _mm256_dpbusd_epi32(_mm256_setzero_si256(), values[1], _mm256_loadu_si256((const __m256i*)qs+1));
-                    sumi[2] = _mm256_dpbusd_epi32(_mm256_setzero_si256(), values[2], _mm256_loadu_si256((const __m256i*)qs+2));
-                    sumi[3] = _mm256_dpbusd_epi32(_mm256_setzero_si256(), values[3], _mm256_loadu_si256((const __m256i*)qs+3));
-#else
-                    sumi[0] = _mm256_madd_epi16(_mm256_set1_epi16(1), _mm256_maddubs_epi16(values[0], _mm256_loadu_si256((const __m256i*)qs+0)));
-                    sumi[1] = _mm256_madd_epi16(_mm256_set1_epi16(1), _mm256_maddubs_epi16(values[1], _mm256_loadu_si256((const __m256i*)qs+1)));
-                    sumi[2] = _mm256_madd_epi16(_mm256_set1_epi16(1), _mm256_maddubs_epi16(values[2], _mm256_loadu_si256((const __m256i*)qs+2)));
-                    sumi[3] = _mm256_madd_epi16(_mm256_set1_epi16(1), _mm256_maddubs_epi16(values[3], _mm256_loadu_si256((const __m256i*)qs+3)));
-#endif
-                    sumi[0] = _mm256_add_epi32(_mm256_unpacklo_epi32(sumi[0], sumi[1]), _mm256_unpackhi_epi32(sumi[0], sumi[1]));
-                    sumi[2] = _mm256_add_epi32(_mm256_unpacklo_epi32(sumi[2], sumi[3]), _mm256_unpackhi_epi32(sumi[2], sumi[3]));
-                    sumi[0] = _mm256_add_epi32(_mm256_unpacklo_epi64(sumi[0], sumi[2]), _mm256_unpackhi_epi64(sumi[0], sumi[2]));
-                    auto d4 = _mm_castsi128_ps(_mm_slli_epi32(_mm_cvtepu16_epi32(_mm_loadl_epi64((const __m128i *)q8.y[iy][2*i+j].d)), 16));
+                if constexpr (use_1_row) {
+                    for (int k = 0; k < 4; ++k) vy[k] = _mm256_loadu_si256((const __m256i*)q8.y[0][2*i+j].qs+k);
+                    deq.prepare(i, j, vy);
+                    auto sumi = compute_dot_4<2*n_sum>(deq.bits.values, vy);
+                    auto d4 = _mm_castsi128_ps(_mm_slli_epi32(_mm_cvtepu16_epi32(_mm_loadl_epi64((const __m128i *)q8.y[0][2*i+j].d)), 16));
                     auto dy = _mm256_set_m128(d4, d4);
-                    accd[iy] = _mm256_fmadd_ps(_mm256_mul_ps(scales[j], dy), _mm256_cvtepi32_ps(sumi[0]), accd[iy]);
+                    accd[0] = _mm256_fmadd_ps(_mm256_mul_ps(scales[j], dy), _mm256_cvtepi32_ps(sumi), accd[0]);
+                } else {
+                    deq.prepare(i, j);
+                    for (int iy = 0; iy < nrc_y; ++iy) {
+                        auto qs = q8.y[iy][2*i+j].qs;
+                        for (int k = 0; k < 4; ++k) vy[k] = _mm256_loadu_si256((const __m256i*)qs+k);
+                        auto sumi = compute_dot_4<n_sum>(deq.bits.values, vy);
+                        auto d4 = _mm_castsi128_ps(_mm_slli_epi32(_mm_cvtepu16_epi32(_mm_loadl_epi64((const __m128i *)q8.y[iy][2*i+j].d)), 16));
+                        auto dy = _mm256_set_m128(d4, d4);
+                        accd[iy] = _mm256_fmadd_ps(_mm256_mul_ps(scales[j], dy), _mm256_cvtepi32_ps(sumi), accd[iy]);
+                    }
                 }
             }
         }
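
The n_sum template parameter only matters on the AVX2 (#else) path of compute_dot_4: as I read it, it is the number of u8·i8 products allowed to accumulate in a signed 16-bit lane before _mm256_madd_epi16 widens them to 32 bits. n_sum == 2 widens each _mm256_maddubs_epi16 result immediately; 4 and 8 defer the widening to save instructions at the risk of 16-bit overflow, which is presumably why the default stays at 2 and only kernels whose quant magnitudes allow it opt into 8. A scalar model of that trade-off (dot_block_model is an illustrative name; the real _mm256_maddubs_epi16 saturates rather than wraps):

```cpp
#include <cstdint>

// Scalar model of compute_dot_4's AVX2 reduction order for a given n_sum:
// every n_sum products are first summed in a 16-bit accumulator (the
// maddubs / add_epi16 stage), then widened and added to a 32-bit total
// (the madd_epi16 stage).
int32_t dot_block_model(const uint8_t* x, const int8_t* y, int n, int n_sum) {
    int32_t total = 0;
    for (int i = 0; i < n; i += n_sum) {
        int16_t acc16 = 0;                                // one 16-bit lane
        for (int k = 0; k < n_sum && i + k < n; ++k)
            acc16 = int16_t(acc16 + x[i + k] * y[i + k]); // can exceed int16 range for large n_sum
        total += acc16;                                   // 16-bit -> 32-bit widening
    }
    return total;
}
```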
@@ -1934,7 +1972,15 @@ bool iqk_set_kernels_iquants(int ne00, int typeA, int typeB, std::array<mul_mat_
 
     if (ggml_type(typeA) == GGML_TYPE_IQ3_S) {
         if (ggml_type(typeB) == GGML_TYPE_Q8_2_X4) {
-            IQK_SET_MUL_MAT_FUNCTIONS_T(mul_mat_qX_K_q8_2_IQ_N, DequantizerIQ3S, kernels);
+            //IQK_SET_MUL_MAT_FUNCTIONS_T(mul_mat_qX_K_q8_2_IQ_N, DequantizerIQ3S, kernels);
+            kernels[0] = mul_mat_qX_K_q8_2_IQ_N<DequantizerIQ3S, 1, 8>;
+            kernels[1] = mul_mat_qX_K_q8_2_IQ_N<DequantizerIQ3S, 2, 8>;
+            kernels[2] = mul_mat_qX_K_q8_2_IQ_N<DequantizerIQ3S, 3, 8>;
+            kernels[3] = mul_mat_qX_K_q8_2_IQ_N<DequantizerIQ3S, 4, 8>;
+            kernels[4] = mul_mat_qX_K_q8_2_IQ_N<DequantizerIQ3S, 5, 8>;
+            kernels[5] = mul_mat_qX_K_q8_2_IQ_N<DequantizerIQ3S, 6, 8>;
+            kernels[6] = mul_mat_qX_K_q8_2_IQ_N<DequantizerIQ3S, 7, 8>;
+            kernels[7] = mul_mat_qX_K_q8_2_IQ_N<DequantizerIQ3S, 8, 8>;
             func16 = nullptr;
             return true;
         }
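
The eight explicit assignments spell out nrc_y = 1..8 with n_sum = 8 for IQ3_S, presumably because IQK_SET_MUL_MAT_FUNCTIONS_T does not pass the extra n_sum template argument. If one wanted to keep this table-driven, a compile-time loop along these lines would be equivalent (a hypothetical sketch, assuming kernels has at least eight entries; set_iq_kernels is not part of the source):

```cpp
#include <cstddef>
#include <utility>

// Hypothetical helper: fill kernels[0..7] with
// mul_mat_qX_K_q8_2_IQ_N<Dequantizer, 1..8, n_sum>.
template <typename Dequantizer, int n_sum, typename Kernels, std::size_t... I>
void set_iq_kernels(Kernels& kernels, std::index_sequence<I...>) {
    ((kernels[I] = mul_mat_qX_K_q8_2_IQ_N<Dequantizer, int(I) + 1, n_sum>), ...);
}

// Usage equivalent to the eight assignments above:
//   set_iq_kernels<DequantizerIQ3S, 8>(kernels, std::make_index_sequence<8>{});
```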
