Skip to content

Commit 6c0b796

Browse files
author
Iwan Kawrakow
committed
WIP
1 parent 8fcede9 commit 6c0b796

File tree

1 file changed

+45
-24
lines changed

1 file changed

+45
-24
lines changed

ggml/src/iqk/iqk_gemm_ktquants.cpp

Lines changed: 45 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -171,6 +171,8 @@ struct Trellis3 {
171171
}
172172
}
173173
IQK_ALWAYS_INLINE inline void next_128(const uint32_t * val, __m256i * result) const {
174+
#if defined(__AVX512F__) && defined(__AVX512VNNI__) && defined(__AVX512VL__)
175+
// On AVX2 we don't have enough vector egisters to do this
174176
__m256i aux[16];
175177
auto perm = _mm256_setr_epi32(0, 2, 4, 6, 1, 3, 5, 7);
176178
for (int k = 0; k < 4; ++k) {
@@ -203,8 +205,13 @@ struct Trellis3 {
203205
result[k] = _mm256_sign_epi8(result[k], result[k]);
204206
}
205207
}
208+
#else
209+
for (int k = 0; k < 4; ++k) result[k] = next32(val + 8*k);
210+
#endif
206211
}
207212
IQK_ALWAYS_INLINE inline void next_128(const uint16_t * val, uint32_t v0, __m256i * result) const {
213+
#if defined(__AVX512F__) && defined(__AVX512VNNI__) && defined(__AVX512VL__)
214+
// On AVX2 we don't have enough vector egisters to do this
208215
__m256i aux[16];
209216
for (int k = 0; k < 4; ++k) {
210217
auto v128 = _mm_add_epi32(_mm_cvtepu16_epi32(_mm_loadl_epi64((const __m128i *)(val + 4*k))), _mm_set1_epi32(v0));
@@ -236,15 +243,9 @@ struct Trellis3 {
236243
result[k] = _mm256_sign_epi8(result[k], result[k]);
237244
}
238245
}
239-
//for (int k = 0; k < 4; ++k) {
240-
// for (int i = 0; i < 4; ++i) {
241-
// aux[i] = _mm256_and_si256(aux[4*k+i], _mm256_set1_epi32(0x3f3f3f3f));
242-
// aux[i] = _mm256_dpbusd_epi32(offset, _mm256_set1_epi32(0x01010101), aux[i]);
243-
// }
244-
// aux[0] = _mm256_packs_epi32(aux[0], aux[1]);
245-
// aux[2] = _mm256_packs_epi32(aux[2], aux[3]);
246-
// result[k] = _mm256_permutevar8x32_epi32(_mm256_packs_epi16(aux[0], aux[2]), shuffle);
247-
//}
246+
#else
247+
for (int k = 0; k < 4; ++k) result[k] = next32(val + 4*k, v0);
248+
#endif
248249
}
249250
inline __m256i next32(const uint16_t * val, uint32_t v0) const {
250251
const __m256i offset = _mm256_set1_epi32(-126);
@@ -521,7 +522,6 @@ void mul_mat_iq2_kt_q8_2_x4_T(int n, const void * vx, size_t bx, const DataInfo&
521522
scales[1] = _mm256_set_m128(scales_h, scales_h);
522523
for (int i128 = 0; i128 < 2; ++i128) {
523524
trellis.next_128(ql + 16*i128, 4096, xv);
524-
//for (int k = 0; k < 4; ++k) xv[k] = trellis.next32(ql + 16*i128 + 4*k, 4096);
525525
for (int iy = 0; iy < nrc_y; ++iy) {
526526
const block_q8_2_x4& yb = y[iy][2*i+i128];
527527
auto dy4 = _mm_castsi128_ps(_mm_slli_epi32(_mm_cvtepu16_epi32(_mm_loadl_epi64((const __m128i *)yb.d)), 16));
@@ -671,11 +671,6 @@ void mul_mat_iq3_kt_q8_2_x4_T(int n, const void * vx, size_t bx, const DataInfo&
671671
sv[k] = _mm256_or_si256(_mm256_cmpeq_epi8(_mm256_and_si256(sign_bits, mask), mask), mask);
672672
sign_bits = _mm256_srli_epi16(sign_bits, 1);
673673
}
674-
//for (int k = 0; k < 4; ++k) {
675-
// xv[k] = trellis.next32(ql + 16*i128 + 4*k, 4096);
676-
// sv[k] = _mm256_or_si256(_mm256_cmpeq_epi8(_mm256_and_si256(sign_bits, mask), mask), _mm256_set1_epi8(1));
677-
// mask = _mm256_slli_epi16(mask, 1);
678-
//}
679674
for (int iy = 0; iy < nrc_y; ++iy) {
680675
const block_q8_2_x4& yb = y[iy][2*i+i128];
681676
auto dy4 = _mm_castsi128_ps(_mm_slli_epi32(_mm_cvtepu16_epi32(_mm_loadl_epi64((const __m128i *)yb.d)), 16));
@@ -952,7 +947,9 @@ void mul_mat_iq4_kt_q8_2_x4_T(int n, const void * vx, size_t bx, const DataInfo&
952947
}
953948
};
954949

955-
//auto m126 = _mm256_set1_ps(-126.f);
950+
//auto shift1 = _mm256_setr_epi32(8, 8, 8, 8, 20, 20, 20, 20);
951+
//auto shift2 = _mm256_setr_epi32(12, 9, 6, 3, 12, 9, 6, 3);
952+
//__m256i values[8];
956953

957954
for (int ix = 0; ix < nrc_x; ++ix) {
958955
const float * dptr = (const float *)((const char*)vx + ix*bx);
@@ -975,18 +972,42 @@ void mul_mat_iq4_kt_q8_2_x4_T(int n, const void * vx, size_t bx, const DataInfo&
975972
scales[1] = _mm256_set_m128(scales_h, scales_h);
976973
o_helper.vec = _mm256_add_epi32(_mm256_slli_epi32(_mm256_and_si256(vshb, _mm256_set1_epi32(1)), 15), _mm256_set1_epi32(4096));
977974
for (int ib = 0; ib < 4; ++ib) {
978-
for (int j = 0; j < 4; ++j) {
979-
const uint32_t sh1 = shb[ib+0] >> (8 + 6*j);
980-
const uint32_t sh2 = shb[ib+4] >> (8 + 6*j);
981-
values[8*ib+2*j+ 0] = ql[8*ib+2*j+ 0] + ((qh[8*ib+2*j+0] << 8) & 0xf00) + ((sh1 & 7) << 12) + o_helper.val[ib+0];
982-
values[8*ib+2*j+ 1] = ql[8*ib+2*j+ 1] + ((qh[8*ib+2*j+1] << 8) & 0xf00) + ((sh1 & 56) << 9) + o_helper.val[ib+0];
983-
values[8*ib+2*j+32] = ql[8*ib+2*j+32] + ((qh[8*ib+2*j+0] << 4) & 0xf00) + ((sh2 & 7) << 12) + o_helper.val[ib+4];
984-
values[8*ib+2*j+33] = ql[8*ib+2*j+33] + ((qh[8*ib+2*j+1] << 4) & 0xf00) + ((sh2 & 56) << 9) + o_helper.val[ib+4];
975+
// Somehow this is slower.
976+
//auto idxl = _mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i *)(ql + 8*ib)));
977+
//auto idxh = _mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i *)(ql + 8*ib + 32)));
978+
//auto vh = _mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i *)(qh + 8*ib)));
979+
//idxl = _mm256_or_si256(idxl, _mm256_and_si256(_mm256_slli_epi32(vh, 8), _mm256_set1_epi32(0xf00)));
980+
//idxh = _mm256_or_si256(idxh, _mm256_and_si256(_mm256_slli_epi32(vh, 4), _mm256_set1_epi32(0xf00)));
981+
//auto shl = _mm256_sllv_epi32(_mm256_srlv_epi32(_mm256_set1_epi32(shb[ib+0]), shift1), shift2);
982+
//auto shh = _mm256_sllv_epi32(_mm256_srlv_epi32(_mm256_set1_epi32(shb[ib+4]), shift1), shift2);
983+
//idxl = _mm256_or_si256(idxl, _mm256_and_si256(shl, _mm256_set1_epi32(0x7000)));
984+
//idxh = _mm256_or_si256(idxh, _mm256_and_si256(shh, _mm256_set1_epi32(0x7000)));
985+
//values[ib+0] = _mm256_add_epi32(idxl, _mm256_set1_epi32(o_helper.val[ib+0]));
986+
//values[ib+4] = _mm256_add_epi32(idxh, _mm256_set1_epi32(o_helper.val[ib+4]));
987+
for (int j = 0; j < 2; ++j) {
988+
const uint32_t sh1 = shb[ib+0] >> (8 + 12*j);
989+
const uint32_t sh2 = shb[ib+4] >> (8 + 12*j);
990+
//values[8*ib+4*j+ 0] = ql[8*ib+4*j+ 0] + ((qh[8*ib+4*j+0] << 8) & 0xf00) + ((sh1 & 7) << 12) + o_helper.val[ib+0];
991+
//values[8*ib+4*j+ 1] = ql[8*ib+4*j+ 1] + ((qh[8*ib+4*j+1] << 8) & 0xf00) + ((sh1 & 56) << 9) + o_helper.val[ib+0];
992+
//values[8*ib+4*j+ 2] = ql[8*ib+4*j+ 2] + ((qh[8*ib+4*j+2] << 8) & 0xf00) + ((sh1 & 448) << 6) + o_helper.val[ib+0];
993+
//values[8*ib+4*j+ 3] = ql[8*ib+4*j+ 3] + ((qh[8*ib+4*j+3] << 8) & 0xf00) + ((sh1 & 3584) << 3) + o_helper.val[ib+0];
994+
//values[8*ib+4*j+32] = ql[8*ib+4*j+32] + ((qh[8*ib+4*j+0] << 4) & 0xf00) + ((sh2 & 7) << 12) + o_helper.val[ib+4];
995+
//values[8*ib+4*j+33] = ql[8*ib+4*j+33] + ((qh[8*ib+4*j+1] << 4) & 0xf00) + ((sh2 & 56) << 9) + o_helper.val[ib+4];
996+
//values[8*ib+4*j+34] = ql[8*ib+4*j+34] + ((qh[8*ib+4*j+2] << 4) & 0xf00) + ((sh2 & 448) << 6) + o_helper.val[ib+4];
997+
//values[8*ib+4*j+35] = ql[8*ib+4*j+35] + ((qh[8*ib+4*j+3] << 4) & 0xf00) + ((sh2 & 3584) << 3) + o_helper.val[ib+4];
998+
values[8*ib+4*j+ 0] = ql[8*ib+4*j+ 0] + ((qh[8*ib+4*j+0] << 8) & 0xf00) + ((sh1 << 12) & 0x7000) + o_helper.val[ib+0];
999+
values[8*ib+4*j+ 1] = ql[8*ib+4*j+ 1] + ((qh[8*ib+4*j+1] << 8) & 0xf00) + ((sh1 << 9) & 0x7000) + o_helper.val[ib+0];
1000+
values[8*ib+4*j+ 2] = ql[8*ib+4*j+ 2] + ((qh[8*ib+4*j+2] << 8) & 0xf00) + ((sh1 << 6) & 0x7000) + o_helper.val[ib+0];
1001+
values[8*ib+4*j+ 3] = ql[8*ib+4*j+ 3] + ((qh[8*ib+4*j+3] << 8) & 0xf00) + ((sh1 << 3) & 0x7000) + o_helper.val[ib+0];
1002+
values[8*ib+4*j+32] = ql[8*ib+4*j+32] + ((qh[8*ib+4*j+0] << 4) & 0xf00) + ((sh2 << 12) & 0x7000) + o_helper.val[ib+4];
1003+
values[8*ib+4*j+33] = ql[8*ib+4*j+33] + ((qh[8*ib+4*j+1] << 4) & 0xf00) + ((sh2 << 9) & 0x7000) + o_helper.val[ib+4];
1004+
values[8*ib+4*j+34] = ql[8*ib+4*j+34] + ((qh[8*ib+4*j+2] << 4) & 0xf00) + ((sh2 << 6) & 0x7000) + o_helper.val[ib+4];
1005+
values[8*ib+4*j+35] = ql[8*ib+4*j+35] + ((qh[8*ib+4*j+3] << 4) & 0xf00) + ((sh2 << 3) & 0x7000) + o_helper.val[ib+4];
9851006
}
9861007
}
9871008
for (int i128 = 0; i128 < 2; ++i128) {
988-
//for (int k = 0; k < 4; ++k) xv[k] = trellis.next32(values + 32*i128 + 8*k);
9891009
trellis.next_128(values + 32*i128, xv);
1010+
//trellis.next_128(values + 4*i128, xv);
9901011
for (int iy = 0; iy < nrc_y; ++iy) {
9911012
const block_q8_2_x4& yb = y[iy][2*i+i128];
9921013
auto dy4 = _mm_castsi128_ps(_mm_slli_epi32(_mm_cvtepu16_epi32(_mm_loadl_epi64((const __m128i *)yb.d)), 16));

0 commit comments

Comments
 (0)