153 changes: 118 additions & 35 deletions ggml/src/iqk/iqk_gemm_ktquants.cpp
@@ -170,6 +170,89 @@ struct Trellis3 {
return _mm256_permutevar8x32_epi32(aux[0], shuffle);
}
}
IQK_ALWAYS_INLINE inline void next_128(const uint32_t * val, __m256i * result) const {
// Even though we only have 16 vector registers on AVX2, this is still faster
__m256i aux[16];
auto perm = _mm256_setr_epi32(0, 2, 4, 6, 1, 3, 5, 7);
for (int k = 0; k < 4; ++k) {
auto v = _mm256_loadu_si256((const __m256i *)val + k);
v = _mm256_permutevar8x32_epi32(v, perm);
aux[4*k+0] = _mm256_shuffle_epi32(v, 0x00);
aux[4*k+1] = _mm256_shuffle_epi32(v, 0x55);
aux[4*k+2] = _mm256_shuffle_epi32(v, 0xaa);
aux[4*k+3] = _mm256_shuffle_epi32(v, 0xff);
}
for (int i = 0; i < 16; ++i) {
aux[i] = _mm256_mullo_epi32(aux[i], mka);
}
auto mask = _mm256_set1_epi32(0x3f3f3f3f);
for (int i = 0; i < 16; ++i) {
aux[i] = _mm256_and_si256(aux[i], mask);
}
auto offset = _mm256_set1_epi32(-126);
#if defined(__AVX512F__) && defined(__AVX512VNNI__) && defined(__AVX512VL__)
auto m1 = _mm256_set1_epi32(0x01010101);
#endif
for (int i = 0; i < 16; ++i) {
#if defined(__AVX512F__) && defined(__AVX512VNNI__) && defined(__AVX512VL__)
aux[i] = _mm256_dpbusd_epi32(offset, aux[i], m1);
#else
auto dot = _mm256_maddubs_epi16(aux[i], _mm256_set1_epi32(0x01010101));
aux[i] = _mm256_add_epi32(offset, _mm256_madd_epi16(dot, _mm256_set1_epi16(1)));
#endif
}
for (int k = 0; k < 4; ++k) {
auto v1 = _mm256_packs_epi32(aux[4*k+0], aux[4*k+1]);
auto v2 = _mm256_packs_epi32(aux[4*k+2], aux[4*k+3]);
result[k] = _mm256_permutevar8x32_epi32(_mm256_packs_epi16(v1, v2), shuffle);
}
if constexpr (is_abs) {
for (int k = 0; k < 4; ++k) {
result[k] = _mm256_sign_epi8(result[k], result[k]);
}
}
}
IQK_ALWAYS_INLINE inline void next_128(const uint16_t * val, uint32_t v0, __m256i * result) const {
// Even though we only have 16 vector registers on AVX2, this is still faster
__m256i aux[16];
for (int k = 0; k < 4; ++k) {
auto v128 = _mm_add_epi32(_mm_cvtepu16_epi32(_mm_loadl_epi64((const __m128i *)(val + 4*k))), _mm_set1_epi32(v0));
auto v = MM256_SET_M128I(v128, v128);
aux[4*k+0] = _mm256_shuffle_epi32(v, 0x00);
aux[4*k+1] = _mm256_shuffle_epi32(v, 0x55);
aux[4*k+2] = _mm256_shuffle_epi32(v, 0xaa);
aux[4*k+3] = _mm256_shuffle_epi32(v, 0xff);
}
for (int i = 0; i < 16; ++i) {
aux[i] = _mm256_mullo_epi32(aux[i], mka);
}
auto mask = _mm256_set1_epi32(0x3f3f3f3f);
for (int i = 0; i < 16; ++i) {
aux[i] = _mm256_and_si256(aux[i], mask);
}
auto offset = _mm256_set1_epi32(-126);
#if defined(__AVX512F__) && defined(__AVX512VNNI__) && defined(__AVX512VL__)
auto m1 = _mm256_set1_epi32(0x01010101);
#endif
for (int i = 0; i < 16; ++i) {
#if defined(__AVX512F__) && defined(__AVX512VNNI__) && defined(__AVX512VL__)
aux[i] = _mm256_dpbusd_epi32(offset, aux[i], m1);
#else
auto dot = _mm256_maddubs_epi16(aux[i], _mm256_set1_epi32(0x01010101));
aux[i] = _mm256_add_epi32(offset, _mm256_madd_epi16(dot, _mm256_set1_epi16(1)));
#endif
}
for (int k = 0; k < 4; ++k) {
auto v1 = _mm256_packs_epi32(aux[4*k+0], aux[4*k+1]);
auto v2 = _mm256_packs_epi32(aux[4*k+2], aux[4*k+3]);
result[k] = _mm256_permutevar8x32_epi32(_mm256_packs_epi16(v1, v2), shuffle);
}
if constexpr (is_abs) {
for (int k = 0; k < 4; ++k) {
result[k] = _mm256_sign_epi8(result[k], result[k]);
}
}
}
inline __m256i next32(const uint16_t * val, uint32_t v0) const {
const __m256i offset = _mm256_set1_epi32(-126);
__m256i aux[4];
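A note for readers of this hunk: each lane of the new next_128() overloads runs the same per-value recurrence that next32() already used, just for 128 values per call so more of the work stays in registers. A minimal scalar sketch of what one 32-bit lane computes (illustrative only, not part of the diff; it assumes mka holds the trellis multiplier ka broadcast to every lane, as defined elsewhere in Trellis3):

static inline int8_t trellis3_sample(uint32_t seed, uint32_t ka) {
    uint32_t x = seed * ka;          // _mm256_mullo_epi32(aux[i], mka)
    x &= 0x3f3f3f3fu;                // _mm256_and_si256(aux[i], mask)
    int sum = -126;                  // the 'offset' lane
    for (int b = 0; b < 4; ++b)
        sum += (x >> 8*b) & 0xff;    // byte sum: dpbusd on AVX512-VNNI, maddubs+madd fallback otherwise
    return (int8_t)sum;              // packed down to int8 by packs_epi32 / packs_epi16
}

The int8 result always fits: each masked byte is at most 0x3f, so the sum lies in [-126, 126].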
@@ -385,7 +468,7 @@ void mul_mat_iq2_kt_q8_2_x4_T(int n, const void * vx, size_t bx, const DataInfo&
assert(n%QK_K == 0);
const int nb = n/QK_K;

Trellis3<true> trellis;
Trellis3<true, false> trellis;

auto shifts = _mm_set_epi32(0, 0, 4, 0);
auto values = _mm_loadu_si128((const __m128i *)iq4k_values);
Expand Down Expand Up @@ -425,8 +508,6 @@ void mul_mat_iq2_kt_q8_2_x4_T(int n, const void * vx, size_t bx, const DataInfo&
}
};

//auto m126 = _mm256_set1_ps(-126.f);

for (int ix = 0; ix < nrc_x; ++ix) {
const float * dptr = (const float *)((const char*)vx + ix*bx);
auto d = _mm256_set1_ps(dptr[0] * 1.05f);
@@ -446,17 +527,13 @@ void mul_mat_iq2_kt_q8_2_x4_T(int n, const void * vx, size_t bx, const DataInfo&
scales[0] = _mm256_set_m128(scales_l, scales_l);
scales[1] = _mm256_set_m128(scales_h, scales_h);
for (int i128 = 0; i128 < 2; ++i128) {
//for (int k = 0; k < 4; ++k) xv[k] = trellis.next32<true>(values + 32*i128 + 8*k);
for (int k = 0; k < 4; ++k) xv[k] = trellis.next32(ql + 16*i128 + 4*k, 4096);
trellis.next_128(ql + 16*i128, 4096, xv);
for (int iy = 0; iy < nrc_y; ++iy) {
const block_q8_2_x4& yb = y[iy][2*i+i128];
auto dy = _mm256_castsi256_ps(_mm256_slli_epi32(_mm256_cvtepu16_epi32(_mm_loadu_si128((const __m128i *)yb.d)), 16));
dy = _mm256_mul_ps(scales[i128], dy);
auto d8 = _mm256_set_m128(_mm256_castps256_ps128(dy), _mm256_castps256_ps128(dy));
//auto m8 = _mm256_set_m128(_mm256_extractf128_ps(dy, 1), _mm256_extractf128_ps(dy, 1));
auto dy4 = _mm_castsi128_ps(_mm_slli_epi32(_mm_cvtepu16_epi32(_mm_loadl_epi64((const __m128i *)yb.d)), 16));
auto dy8 = _mm256_mul_ps(scales[i128], _mm256_set_m128(dy4, dy4));
compute_dot(yb.qs);
accd[iy] = _mm256_fmadd_ps(d8, sum_4(), accd[iy]);
//accd[iy] = _mm256_fmadd_ps(m8, m126, accd[iy]);
accd[iy] = _mm256_fmadd_ps(dy8, sum_4(), accd[iy]);
}
}
}
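The dy4/dy8 change above loads only the four 16-bit scales that are actually needed and widens them with a 16-bit left shift, i.e. it treats each stored scale as the high half of an fp32 value. A self-contained sketch of that conversion, assuming the q8_2_x4 scales are stored as truncated fp32 (which is what the existing _mm_slli_epi32(..., 16) + _mm_castsi128_ps sequence implies):

#include <cstdint>
#include <cstring>

static inline float q8_2_scale_to_float(uint16_t s) {
    uint32_t u = (uint32_t)s << 16;   // same shift the kernel applies per lane
    float f;
    std::memcpy(&f, &u, sizeof f);    // bit-cast, matching _mm_castsi128_ps
    return f;
}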
@@ -595,18 +672,17 @@ void mul_mat_iq3_kt_q8_2_x4_T(int n, const void * vx, size_t bx, const DataInfo&
scales[1] = _mm256_set_m128(scales_h, scales_h);
auto mask = _mm256_set1_epi8(1);
for (int i128 = 0; i128 < 2; ++i128) {
trellis.next_128(ql + 16*i128, 4096, xv);
for (int k = 0; k < 4; ++k) {
xv[k] = trellis.next32(ql + 16*i128 + 4*k, 4096);
sv[k] = _mm256_or_si256(_mm256_cmpeq_epi8(_mm256_and_si256(sign_bits, mask), mask), _mm256_set1_epi8(1));
mask = _mm256_slli_epi16(mask, 1);
sv[k] = _mm256_or_si256(_mm256_cmpeq_epi8(_mm256_and_si256(sign_bits, mask), mask), mask);
sign_bits = _mm256_srli_epi16(sign_bits, 1);
}
for (int iy = 0; iy < nrc_y; ++iy) {
const block_q8_2_x4& yb = y[iy][2*i+i128];
auto dy = _mm256_castsi256_ps(_mm256_slli_epi32(_mm256_cvtepu16_epi32(_mm_loadu_si128((const __m128i *)yb.d)), 16));
dy = _mm256_mul_ps(scales[i128], dy);
auto d8 = _mm256_set_m128(_mm256_castps256_ps128(dy), _mm256_castps256_ps128(dy));
auto dy4 = _mm_castsi128_ps(_mm_slli_epi32(_mm_cvtepu16_epi32(_mm_loadl_epi64((const __m128i *)yb.d)), 16));
auto dy8 = _mm256_mul_ps(scales[i128], _mm256_set_m128(dy4, dy4));
compute_dot(yb.qs);
accd[iy] = _mm256_fmadd_ps(d8, sum_4(), accd[iy]);
accd[iy] = _mm256_fmadd_ps(dy8, sum_4(), accd[iy]);
}
}
}
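In the reworked iq3_kt loop above, sv[k] ends up holding +1 or -1 per byte: when the tested sign bit is set, the cmpeq yields 0xff and the OR with mask keeps it at -1; otherwise the OR leaves +1. A scalar sketch of the per-byte effect once that sign vector is applied to a trellis value (illustrative; the actual application happens inside compute_dot, outside this hunk):

static inline int8_t apply_sign(int8_t v, bool sign_bit) {
    return sign_bit ? (int8_t)-v : v;   // what an sv byte of -1 / +1 amounts to
}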
@@ -877,8 +953,6 @@ void mul_mat_iq4_kt_q8_2_x4_T(int n, const void * vx, size_t bx, const DataInfo&
}
};

//auto m126 = _mm256_set1_ps(-126.f);

for (int ix = 0; ix < nrc_x; ++ix) {
const float * dptr = (const float *)((const char*)vx + ix*bx);
auto d = _mm256_set1_ps(dptr[0]);
@@ -900,27 +974,27 @@ void mul_mat_iq4_kt_q8_2_x4_T(int n, const void * vx, size_t bx, const DataInfo&
scales[1] = _mm256_set_m128(scales_h, scales_h);
o_helper.vec = _mm256_add_epi32(_mm256_slli_epi32(_mm256_and_si256(vshb, _mm256_set1_epi32(1)), 15), _mm256_set1_epi32(4096));
for (int ib = 0; ib < 4; ++ib) {
for (int j = 0; j < 4; ++j) {
const uint32_t sh1 = shb[ib+0] >> (8 + 6*j);
const uint32_t sh2 = shb[ib+4] >> (8 + 6*j);
values[8*ib+2*j+ 0] = ql[8*ib+2*j+ 0] + ((qh[8*ib+2*j+0] << 8) & 0xf00) + ((sh1 & 7) << 12) + o_helper.val[ib+0];
values[8*ib+2*j+ 1] = ql[8*ib+2*j+ 1] + ((qh[8*ib+2*j+1] << 8) & 0xf00) + ((sh1 & 56) << 9) + o_helper.val[ib+0];
values[8*ib+2*j+32] = ql[8*ib+2*j+32] + ((qh[8*ib+2*j+0] << 4) & 0xf00) + ((sh2 & 7) << 12) + o_helper.val[ib+4];
values[8*ib+2*j+33] = ql[8*ib+2*j+33] + ((qh[8*ib+2*j+1] << 4) & 0xf00) + ((sh2 & 56) << 9) + o_helper.val[ib+4];
for (int j = 0; j < 2; ++j) {
const uint32_t sh1 = shb[ib+0] >> (8 + 12*j);
const uint32_t sh2 = shb[ib+4] >> (8 + 12*j);
values[8*ib+4*j+ 0] = ql[8*ib+4*j+ 0] + ((qh[8*ib+4*j+0] << 8) & 0xf00) + ((sh1 << 12) & 0x7000) + o_helper.val[ib+0];
values[8*ib+4*j+ 1] = ql[8*ib+4*j+ 1] + ((qh[8*ib+4*j+1] << 8) & 0xf00) + ((sh1 << 9) & 0x7000) + o_helper.val[ib+0];
values[8*ib+4*j+ 2] = ql[8*ib+4*j+ 2] + ((qh[8*ib+4*j+2] << 8) & 0xf00) + ((sh1 << 6) & 0x7000) + o_helper.val[ib+0];
values[8*ib+4*j+ 3] = ql[8*ib+4*j+ 3] + ((qh[8*ib+4*j+3] << 8) & 0xf00) + ((sh1 << 3) & 0x7000) + o_helper.val[ib+0];
values[8*ib+4*j+32] = ql[8*ib+4*j+32] + ((qh[8*ib+4*j+0] << 4) & 0xf00) + ((sh2 << 12) & 0x7000) + o_helper.val[ib+4];
values[8*ib+4*j+33] = ql[8*ib+4*j+33] + ((qh[8*ib+4*j+1] << 4) & 0xf00) + ((sh2 << 9) & 0x7000) + o_helper.val[ib+4];
values[8*ib+4*j+34] = ql[8*ib+4*j+34] + ((qh[8*ib+4*j+2] << 4) & 0xf00) + ((sh2 << 6) & 0x7000) + o_helper.val[ib+4];
values[8*ib+4*j+35] = ql[8*ib+4*j+35] + ((qh[8*ib+4*j+3] << 4) & 0xf00) + ((sh2 << 3) & 0x7000) + o_helper.val[ib+4];
}
}
for (int i128 = 0; i128 < 2; ++i128) {
//for (int k = 0; k < 4; ++k) xv[k] = trellis.next32<true>(values + 32*i128 + 8*k);
for (int k = 0; k < 4; ++k) xv[k] = trellis.next32(values + 32*i128 + 8*k);
trellis.next_128(values + 32*i128, xv);
for (int iy = 0; iy < nrc_y; ++iy) {
const block_q8_2_x4& yb = y[iy][2*i+i128];
auto dy = _mm256_castsi256_ps(_mm256_slli_epi32(_mm256_cvtepu16_epi32(_mm_loadu_si128((const __m128i *)yb.d)), 16));
dy = _mm256_mul_ps(scales[i128], dy);
auto d8 = _mm256_set_m128(_mm256_castps256_ps128(dy), _mm256_castps256_ps128(dy));
//auto m8 = _mm256_set_m128(_mm256_extractf128_ps(dy, 1), _mm256_extractf128_ps(dy, 1));
auto dy4 = _mm_castsi128_ps(_mm_slli_epi32(_mm_cvtepu16_epi32(_mm_loadl_epi64((const __m128i *)yb.d)), 16));
auto dy8 = _mm256_mul_ps(scales[i128], _mm256_set_m128(dy4, dy4));
compute_dot(yb.qs);
accd[iy] = _mm256_fmadd_ps(d8, sum_4(), accd[iy]);
//accd[iy] = _mm256_fmadd_ps(m8, m126, accd[iy]);
accd[iy] = _mm256_fmadd_ps(dy8, sum_4(), accd[iy]);
}
}
}
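The rewritten ib/j loops above assemble the same trellis index as before, only unrolled over four values per j step instead of two. A scalar sketch of the bit layout being built (a reading aid, not part of the diff; names follow the kernel's variables):

static inline uint32_t iq4_kt_index(uint8_t lo, uint8_t hi_nibble, uint32_t sh3, uint32_t block_offset) {
    // bits  0..7  : ql byte
    // bits  8..11 : one nibble of the matching qh byte (selected by the << 8 / << 4 shift)
    // bits 12..14 : three bits taken from the per-block shift word shb
    // block_offset: o_helper.val[..] = 4096 + ((shb & 1) << 15)
    return lo + ((uint32_t)(hi_nibble & 0xf) << 8) + ((sh3 & 7) << 12) + block_offset;
}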
Expand Down Expand Up @@ -1020,6 +1094,9 @@ bool iqk_set_kernels_ktquants(int ne00, int typeA, int typeB, std::array<mul_mat
if (typeA == GGML_TYPE_IQ4_KT) {
if (typeB == GGML_TYPE_Q8_2_X4) {
IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_iq4_kt_q8_2_x4_T, kernels);
#ifdef HAVE_FANCY_SIMD
func16 = mul_mat_iq4_kt_q8_2_x4_T<16>;
#endif
return true;
}
return false;
@@ -1028,6 +1105,9 @@ bool iqk_set_kernels_ktquants(int ne00, int typeA, int typeB, std::array<mul_mat
if (typeA == GGML_TYPE_IQ2_KT) {
if (typeB == GGML_TYPE_Q8_2_X4) {
IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_iq2_kt_q8_2_x4_T, kernels);
#ifdef HAVE_FANCY_SIMD
func16 = mul_mat_iq2_kt_q8_2_x4_T<16>;
#endif
return true;
}
return false;
@@ -1036,6 +1116,9 @@ bool iqk_set_kernels_ktquants(int ne00, int typeA, int typeB, std::array<mul_mat
if (typeA == GGML_TYPE_IQ3_KT) {
if (typeB == GGML_TYPE_Q8_2_X4) {
IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_iq3_kt_q8_2_x4_T, kernels);
#ifdef HAVE_FANCY_SIMD
func16 = mul_mat_iq3_kt_q8_2_x4_T<16>;
#endif
return true;
}
return false;
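The three #ifdef HAVE_FANCY_SIMD additions above register a 16-column instantiation of each KT kernel as func16. A minimal sketch of the dispatch pattern this is assumed to feed (the real MulMat struct lives in iqk_mul_mat.cpp and is not shown in this diff; the member names below are illustrative only):

#include <algorithm>
#include <array>
#include <cstddef>

struct DataInfo;  // defined in the iqk sources
// Assumed kernel signature, matching the kernels registered above.
typedef void (*mul_mat_t)(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x);

struct KernelsSketch {
    std::array<mul_mat_t, 8> funcs{}; // kernels for 1..8 right-hand-side columns
    mul_mat_t func16 = nullptr;       // optional fast path handling 16 columns per call
    mul_mat_t pick(int nrc_y) const {
        if (nrc_y >= 16 && func16) return func16;   // take the wide path when registered
        return funcs[std::min(nrc_y, 8) - 1];
    }
};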
6 changes: 3 additions & 3 deletions ggml/src/iqk/iqk_mul_mat.cpp
@@ -264,9 +264,9 @@ struct MulMat {
case GGML_TYPE_Q6_0 : return nrc_y >= 32 ? GGML_TYPE_Q8_0_R8 : type;
case GGML_TYPE_IQ4_NL : return nrc_y >= 32 ? GGML_TYPE_Q8_0_R8 : type;
case GGML_TYPE_Q8_0 : return nrc_y >= 32 ? GGML_TYPE_Q8_0_R8 : type;
case GGML_TYPE_IQ2_KT : return nrc_y >= 32 ? GGML_TYPE_Q8_0_R8 : type;
case GGML_TYPE_IQ3_KT : return nrc_y >= 32 ? GGML_TYPE_Q8_0_R8 : type;
case GGML_TYPE_IQ4_KT : return nrc_y >= 32 ? GGML_TYPE_Q8_0_R8 : type;
case GGML_TYPE_IQ2_KT : return nrc_y >= 16 ? GGML_TYPE_Q8_0_R8 : type;
case GGML_TYPE_IQ3_KT : return nrc_y >= 16 ? GGML_TYPE_Q8_0_R8 : type;
case GGML_TYPE_IQ4_KT : return nrc_y >= 24 ? GGML_TYPE_Q8_0_R8 : type;
default: break;
}
#else