Skip to content

Commit 5b677c3

Browse files
author
Iwan Kawrakow
committed
Enable next_128() also on AVX2
Despite having just 16 vector registers it is still faster.
1 parent 1287a66 commit 5b677c3

File tree

2 files changed

+19
-13
lines changed

2 files changed

+19
-13
lines changed

ggml/src/iqk/iqk_gemm_ktquants.cpp

Lines changed: 16 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -171,8 +171,7 @@ struct Trellis3 {
171171
}
172172
}
173173
IQK_ALWAYS_INLINE inline void next_128(const uint32_t * val, __m256i * result) const {
174-
#if defined(__AVX512F__) && defined(__AVX512VNNI__) && defined(__AVX512VL__)
175-
// On AVX2 we don't have enough vector registers to do this
174+
// Even though we only have 16 vector registers on AVX2, this is still faster
176175
__m256i aux[16];
177176
auto perm = _mm256_setr_epi32(0, 2, 4, 6, 1, 3, 5, 7);
178177
for (int k = 0; k < 4; ++k) {
@@ -191,9 +190,16 @@ struct Trellis3 {
191190
aux[i] = _mm256_and_si256(aux[i], mask);
192191
}
193192
auto offset = _mm256_set1_epi32(-126);
193+
#if defined(__AVX512F__) && defined(__AVX512VNNI__) && defined(__AVX512VL__)
194194
auto m1 = _mm256_set1_epi32(0x01010101);
195+
#endif
195196
for (int i = 0; i < 16; ++i) {
197+
#if defined(__AVX512F__) && defined(__AVX512VNNI__) && defined(__AVX512VL__)
196198
aux[i] = _mm256_dpbusd_epi32(offset, aux[i], m1);
199+
#else
200+
auto dot = _mm256_maddubs_epi16(aux[i], _mm256_set1_epi32(0x01010101));
201+
aux[i] = _mm256_add_epi32(offset, _mm256_madd_epi16(dot, _mm256_set1_epi16(1)));
202+
#endif
197203
}
198204
for (int k = 0; k < 4; ++k) {
199205
auto v1 = _mm256_packs_epi32(aux[4*k+0], aux[4*k+1]);
@@ -205,13 +211,9 @@ struct Trellis3 {
205211
result[k] = _mm256_sign_epi8(result[k], result[k]);
206212
}
207213
}
208-
#else
209-
for (int k = 0; k < 4; ++k) result[k] = next32(val + 8*k);
210-
#endif
211214
}
212215
IQK_ALWAYS_INLINE inline void next_128(const uint16_t * val, uint32_t v0, __m256i * result) const {
213-
#if defined(__AVX512F__) && defined(__AVX512VNNI__) && defined(__AVX512VL__)
214-
// On AVX2 we don't have enough vector registers to do this
216+
// Even though we only have 16 vector registers on AVX2, this is still faster
215217
__m256i aux[16];
216218
for (int k = 0; k < 4; ++k) {
217219
auto v128 = _mm_add_epi32(_mm_cvtepu16_epi32(_mm_loadl_epi64((const __m128i *)(val + 4*k))), _mm_set1_epi32(v0));
@@ -229,9 +231,16 @@ struct Trellis3 {
229231
aux[i] = _mm256_and_si256(aux[i], mask);
230232
}
231233
auto offset = _mm256_set1_epi32(-126);
234+
#if defined(__AVX512F__) && defined(__AVX512VNNI__) && defined(__AVX512VL__)
232235
auto m1 = _mm256_set1_epi32(0x01010101);
236+
#endif
233237
for (int i = 0; i < 16; ++i) {
238+
#if defined(__AVX512F__) && defined(__AVX512VNNI__) && defined(__AVX512VL__)
234239
aux[i] = _mm256_dpbusd_epi32(offset, aux[i], m1);
240+
#else
241+
auto dot = _mm256_maddubs_epi16(aux[i], _mm256_set1_epi32(0x01010101));
242+
aux[i] = _mm256_add_epi32(offset, _mm256_madd_epi16(dot, _mm256_set1_epi16(1)));
243+
#endif
235244
}
236245
for (int k = 0; k < 4; ++k) {
237246
auto v1 = _mm256_packs_epi32(aux[4*k+0], aux[4*k+1]);
@@ -243,9 +252,6 @@ struct Trellis3 {
243252
result[k] = _mm256_sign_epi8(result[k], result[k]);
244253
}
245254
}
246-
#else
247-
for (int k = 0; k < 4; ++k) result[k] = next32(val + 4*k, v0);
248-
#endif
249255
}
250256
inline __m256i next32(const uint16_t * val, uint32_t v0) const {
251257
const __m256i offset = _mm256_set1_epi32(-126);

ggml/src/iqk/iqk_mul_mat.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -264,9 +264,9 @@ struct MulMat {
264264
case GGML_TYPE_Q6_0 : return nrc_y >= 32 ? GGML_TYPE_Q8_0_R8 : type;
265265
case GGML_TYPE_IQ4_NL : return nrc_y >= 32 ? GGML_TYPE_Q8_0_R8 : type;
266266
case GGML_TYPE_Q8_0 : return nrc_y >= 32 ? GGML_TYPE_Q8_0_R8 : type;
267-
case GGML_TYPE_IQ2_KT : return nrc_y >= 32 ? GGML_TYPE_Q8_0_R8 : type;
268-
case GGML_TYPE_IQ3_KT : return nrc_y >= 32 ? GGML_TYPE_Q8_0_R8 : type;
269-
case GGML_TYPE_IQ4_KT : return nrc_y >= 32 ? GGML_TYPE_Q8_0_R8 : type;
267+
case GGML_TYPE_IQ2_KT : return nrc_y >= 16 ? GGML_TYPE_Q8_0_R8 : type;
268+
case GGML_TYPE_IQ3_KT : return nrc_y >= 16 ? GGML_TYPE_Q8_0_R8 : type;
269+
case GGML_TYPE_IQ4_KT : return nrc_y >= 24 ? GGML_TYPE_Q8_0_R8 : type;
270270
default: break;
271271
}
272272
#else

0 commit comments

Comments
 (0)