ggml/src/iqk/iqk_gemm_kquants.cpp: 183 additions & 4 deletions
@@ -810,10 +810,11 @@ static void mul_mat_qX_K_q8_2_X4_T(int n, const void * vx, size_t bx, const Data
auto d4_2 = _mm_cvtepu16_epi32(_mm_loadl_epi64((const __m128i *)(q8.y[iy][2*i+1].d)));
auto dy = _mm256_castsi256_ps(_mm256_slli_epi32(MM256_SET_M128I(d4_2, d4_1), 16));
_mm256_storeu_ps(d8 + 8*iy, dy);
-auto m4_1 = _mm_cvtepu16_epi32(_mm_loadl_epi64((const __m128i *)(q8.y[iy][2*i+0].d+4)));
-auto m4_2 = _mm_cvtepu16_epi32(_mm_loadl_epi64((const __m128i *)(q8.y[iy][2*i+1].d+4)));
-auto my = _mm256_castsi256_ps(_mm256_slli_epi32(MM256_SET_M128I(m4_2, m4_1), 16));
-accd[iy] = _mm256_fmadd_ps(my, mins, accd[iy]);
+auto m4_1 = _mm_cvtepi16_epi32(_mm_loadl_epi64((const __m128i *)(q8.y[iy][2*i+0].d+4)));
+auto m4_2 = _mm_cvtepi16_epi32(_mm_loadl_epi64((const __m128i *)(q8.y[iy][2*i+1].d+4)));
+auto myi = MM256_SET_M128I(m4_2, m4_1);
+auto my = _mm256_mul_ps(dy, _mm256_cvtepi32_ps(myi));
+accd[iy] = _mm256_fmadd_ps(my, mins, accd[iy]);
}

auto all_scales = _mm256_mul_ps(_mm256_set1_ps(deq.d), _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i *)utmp))));
@@ -2017,6 +2018,91 @@ typedef struct {
int8_t qs[8*QK8_1];
} block_q8_1_r8;

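// Repacks 8 interleaved rows of Q2_K into q8_k_r8: each super-block of 256 values is dequantized
// to f32 (2-bit quants, 4-bit sub-block scales and mins), a new fp16 scale is derived from the
// per-row maximum |value|, and the values are requantized to int8 in the interleaved r8 layout.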
void iqk_convert_q2_k_q8_k_r8(int n, const void * vx, size_t bx, void * vy, int nrc_x) {
GGML_ASSERT(n%QK_K == 0);
GGML_ASSERT(nrc_x%8 == 0);

int nb = n/QK_K;

const block_q2_K * x8[8];

block_q8_k_r8 * y = (block_q8_k_r8 *)vy;

float f_values[QK_K];
uint32_t block[8];

__m256i xv[4];

auto ml = _mm256_set1_epi8(0x03);
auto sign_bit = _mm256_set1_ps(-0.0f);
auto perm = _mm256_setr_epi32(0, 4, 1, 5, 2, 6, 3, 7);
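// perm undoes the 128-bit lane interleaving introduced by the two packs steps below,
// restoring the 32 quantized bytes of each sub-block to natural order.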

for (int ix = 0; ix < nrc_x; ix += 8) {
for (int k = 0; k < 8; ++k) x8[k] = (const block_q2_K *)((const char *)vx + (ix + k)*bx);
for (int i = 0; i < nb; ++i) {
for (int k = 0; k < 8; ++k) {
auto vd = _mm256_set1_ps(GGML_FP16_TO_FP32(x8[k][i].d));
auto vm = _mm256_mul_ps(_mm256_set1_ps(GGML_FP16_TO_FP32(x8[k][i].dmin)), _mm256_set1_ps(-1.f));
auto block_max = _mm256_setzero_ps();
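// Each 32-byte load covers 128 quants; shifts of 0/2/4/6 extract the 2-bit values, and every
// group of 16 values gets a 4-bit scale (low nibble of scales[]) and a 4-bit min (high nibble,
// applied through vm = -dmin).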
for (int i128 = 0; i128 < 2; ++i128) {
auto bits = _mm256_loadu_si256((const __m256i *)x8[k][i].qs+i128);
xv[0] = _mm256_and_si256(bits, ml);
xv[1] = _mm256_and_si256(_mm256_srli_epi16(bits, 2), ml);
xv[2] = _mm256_and_si256(_mm256_srli_epi16(bits, 4), ml);
xv[3] = _mm256_and_si256(_mm256_srli_epi16(bits, 6), ml);
for (int l = 0; l < 4; ++l) {
auto q1 = _mm256_cvtepi8_epi16(_mm256_castsi256_si128(xv[l]));
auto q2 = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(xv[l], 1));
q1 = _mm256_mullo_epi16(q1, _mm256_set1_epi16(x8[k][i].scales[8*i128 + 2*l + 0] & 0xf));
q2 = _mm256_mullo_epi16(q2, _mm256_set1_epi16(x8[k][i].scales[8*i128 + 2*l + 1] & 0xf));
auto m1 = _mm256_mul_ps(vm, _mm256_set1_ps(x8[k][i].scales[8*i128 + 2*l + 0] >> 4));
auto m2 = _mm256_mul_ps(vm, _mm256_set1_ps(x8[k][i].scales[8*i128 + 2*l + 1] >> 4));
auto v0 = _mm256_fmadd_ps(_mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_castsi256_si128(q1))), vd, m1);
auto v1 = _mm256_fmadd_ps(_mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_extracti128_si256(q1, 1))), vd, m1);
auto v2 = _mm256_fmadd_ps(_mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_castsi256_si128(q2))), vd, m2);
auto v3 = _mm256_fmadd_ps(_mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_extracti128_si256(q2, 1))), vd, m2);
auto max = _mm256_max_ps(_mm256_max_ps(_mm256_andnot_ps(sign_bit, v0), _mm256_andnot_ps(sign_bit, v1)),
_mm256_max_ps(_mm256_andnot_ps(sign_bit, v2), _mm256_andnot_ps(sign_bit, v3)));
block_max = _mm256_max_ps(block_max, max);
_mm256_storeu_ps(f_values + 128*i128 + 32*l + 0, v0);
_mm256_storeu_ps(f_values + 128*i128 + 32*l + 8, v1);
_mm256_storeu_ps(f_values + 128*i128 + 32*l + 16, v2);
_mm256_storeu_ps(f_values + 128*i128 + 32*l + 24, v3);
}
}
auto max4 = _mm_max_ps(_mm256_extractf128_ps(block_max, 1), _mm256_castps256_ps128(block_max));
max4 = _mm_max_ps(max4, _mm_movehl_ps(max4, max4));
max4 = _mm_max_ss(max4, _mm_movehdup_ps(max4));
float d = _mm_cvtss_f32(max4) / 127.f;
auto id = _mm256_set1_ps(d != 0.0f ? 1/d : 0.0f);
y[i].d[k] = GGML_FP32_TO_FP16(d);
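// Requantize the 256 dequantized values with 1/d and store them interleaved: for every group of
// four consecutive values, the eight rows are laid out next to each other (4 bytes per row).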
for (int ib32 = 0; ib32 < 8; ++ib32) {
auto v0 = _mm256_loadu_ps(f_values + 32*ib32 + 0);
auto v1 = _mm256_loadu_ps(f_values + 32*ib32 + 8);
auto v2 = _mm256_loadu_ps(f_values + 32*ib32 + 16);
auto v3 = _mm256_loadu_ps(f_values + 32*ib32 + 24);
auto i0 = _mm256_cvtps_epi32(_mm256_round_ps(_mm256_mul_ps(v0, id), _MM_ROUND_NEAREST));
auto i1 = _mm256_cvtps_epi32(_mm256_round_ps(_mm256_mul_ps(v1, id), _MM_ROUND_NEAREST));
auto i2 = _mm256_cvtps_epi32(_mm256_round_ps(_mm256_mul_ps(v2, id), _MM_ROUND_NEAREST));
auto i3 = _mm256_cvtps_epi32(_mm256_round_ps(_mm256_mul_ps(v3, id), _MM_ROUND_NEAREST));
i0 = _mm256_packs_epi32(i0, i1);
i2 = _mm256_packs_epi32(i2, i3);
i0 = _mm256_packs_epi16(i0, i2);
i0 = _mm256_permutevar8x32_epi32(i0, perm);

_mm256_storeu_si256((__m256i *)block, i0);
auto q8 = (uint32_t *)y[i].qs + 64*ib32;
for (int l = 0; l < 4; ++l) {
q8[8*l + k + 0] = block[l + 0];
q8[8*l + k + 32] = block[l + 4];
}
}
}
}
y += nb;
}
}

void iqk_convert_q4_k_q8_1_r8(int n, const void * vx, size_t bx, void * vy, int nrc_x) {
GGML_ASSERT(n%QK_K == 0);
GGML_ASSERT(nrc_x%8 == 0);
@@ -2429,6 +2515,97 @@ void iqk_convert_q3_k_q8_k_r8(int n, const void * vx, size_t bx, void * vy, int
}
}

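// Helper: qx holds one super-block (8 x 32 int8 values) for row k and scales the per-16 sub-block
// scales. The values are widened to int16 and scaled, the maximum |value| determines a new block
// scale dnew = max*d0 (clamped to >= 1), and the requantized int8 values are scattered into the
// interleaved q8_k_r8 layout. Returns dnew.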
inline float convert_to_q8_k_r8(int k, float d0, const __m256i * qx, const int16_t * scales, uint32_t * block, int8_t * q8_k) {
auto max_i16 = _mm256_setzero_si256();
__m256i qs[16];
for (int ib32 = 0; ib32 < 8; ++ib32) {
qs[2*ib32+0] = _mm256_cvtepi8_epi16(_mm256_castsi256_si128(qx[ib32]));
qs[2*ib32+1] = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(qx[ib32], 1));
qs[2*ib32+0] = _mm256_mullo_epi16(qs[2*ib32+0], _mm256_set1_epi16(scales[2*ib32+0]));
qs[2*ib32+1] = _mm256_mullo_epi16(qs[2*ib32+1], _mm256_set1_epi16(scales[2*ib32+1]));
max_i16 = _mm256_max_epi16(max_i16, _mm256_sign_epi16(qs[2*ib32+0], qs[2*ib32+0]));
max_i16 = _mm256_max_epi16(max_i16, _mm256_sign_epi16(qs[2*ib32+1], qs[2*ib32+1]));
}
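// Horizontal reduction of the per-lane int16 maxima down to a single scalar maximum.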
auto max_q32 = _mm256_cvtepi16_epi32(_mm_max_epi16(_mm256_castsi256_si128(max_i16), _mm256_extracti128_si256(max_i16, 1)));
auto imax4 = _mm_max_epi32(_mm256_castsi256_si128(max_q32), _mm256_extracti128_si256(max_q32, 1));
auto max4 = _mm_cvtepi32_ps(imax4);
max4 = _mm_max_ps(max4, _mm_movehl_ps(max4, max4));
max4 = _mm_max_ss(max4, _mm_movehdup_ps(max4));
bool needs_scaling = true;
float dnew = _mm_cvtss_f32(max4) * d0;
if (dnew < 1.f) {
dnew = 1.f; needs_scaling = false;
}
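// When max*d0 < 1 the int16 products already fit in int8, so they are stored as-is and dnew is
// clamped to 1; otherwise everything is rescaled by 1/dnew.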
auto scale = _mm256_set1_ps(std::abs(dnew) > 1e-9f ? 1/dnew : 0.f);
for (int ib32 = 0; ib32 < 8; ++ib32) {
if (needs_scaling) {
auto i0 = _mm256_cvtepi16_epi32(_mm256_castsi256_si128(qs[2*ib32+0]));
auto i1 = _mm256_cvtepi16_epi32(_mm256_extracti128_si256(qs[2*ib32+0], 1));
auto i2 = _mm256_cvtepi16_epi32(_mm256_castsi256_si128(qs[2*ib32+1]));
auto i3 = _mm256_cvtepi16_epi32(_mm256_extracti128_si256(qs[2*ib32+1], 1));
i0 = _mm256_cvtps_epi32(_mm256_round_ps(_mm256_mul_ps(scale, _mm256_cvtepi32_ps(i0)), _MM_ROUND_NEAREST));
i1 = _mm256_cvtps_epi32(_mm256_round_ps(_mm256_mul_ps(scale, _mm256_cvtepi32_ps(i1)), _MM_ROUND_NEAREST));
i2 = _mm256_cvtps_epi32(_mm256_round_ps(_mm256_mul_ps(scale, _mm256_cvtepi32_ps(i2)), _MM_ROUND_NEAREST));
i3 = _mm256_cvtps_epi32(_mm256_round_ps(_mm256_mul_ps(scale, _mm256_cvtepi32_ps(i3)), _MM_ROUND_NEAREST));
i0 = _mm256_packs_epi32(i0, i1);
i2 = _mm256_packs_epi32(i2, i3);
i0 = _mm256_packs_epi16(i0, i2);
i0 = _mm256_permutevar8x32_epi32(i0, _mm256_setr_epi32(0, 4, 1, 5, 2, 6, 3, 7));
_mm256_storeu_si256((__m256i *)block, i0);
} else {
// Byte order after packs_epi16 is 0...7, 16...23, 8...15, 24...31 (lane interleave);
// the unpacklo/unpackhi below restores the natural order.
auto i0 = _mm256_packs_epi16(qs[2*ib32+0], qs[2*ib32+1]);
auto i0_l = _mm256_castsi256_si128(i0);
auto i0_h = _mm256_extracti128_si256(i0, 1);
_mm_storeu_si128((__m128i *)block+0, _mm_unpacklo_epi64(i0_l, i0_h));
_mm_storeu_si128((__m128i *)block+1, _mm_unpackhi_epi64(i0_l, i0_h));
}
auto qs = (uint32_t *)q8_k + 64*ib32;
for (int l = 0; l < 8; ++l) {
qs[8*l + k] = block[l];
}
}
return dnew;
}

// TODO: move this to iqk_gemm_iquants
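// Repacks 8 interleaved rows of IQ4_XS to q8_k_r8: the 6-bit block scales are decoded (low nibble
// from scales_l, two high bits from scales_h, minus 32), the 4-bit quants are mapped through the
// iq4k_values table, and convert_to_q8_k_r8 produces the per-row scales stored as fp16.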
void iqk_convert_iq4_xs_q8_k_r8(int n, const void * vx, size_t bx, void * vy, int nrc_x) {
GGML_ASSERT(n%QK_K == 0);
GGML_ASSERT(nrc_x%8 == 0);

int nb = n/QK_K;

const block_iq4_xs * x8[8];

block_q8_k_r8 * y = (block_q8_k_r8 *)vy;

auto values128 = _mm_loadu_si128((const __m128i *)iq4k_values);
auto values = MM256_SET_M128I(values128, values128);

int16_t ls[16];
float dnew[8];
uint32_t block[8];
__m256i xv[8];

for (int ix = 0; ix < nrc_x; ix += 8) {
for (int k = 0; k < 8; ++k) x8[k] = (const block_iq4_xs *)((const char *)vx + (ix + k)*bx);
for (int i = 0; i < nb; ++i) {
for (int k = 0; k < 8; ++k) {
float d = GGML_FP16_TO_FP32(x8[k][i].d);
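// Decode the per-32 block scale (6 bits, biased by 32) and expand the 32 4-bit indices
// through the value lookup table.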
for (int ib32 = 0; ib32 < 8; ++ib32) {
ls[2*ib32+0] = ls[2*ib32+1] = (((x8[k][i].scales_l[ib32/2] >> 4*(ib32%2)) & 0xf) | (((x8[k][i].scales_h >> 2*ib32) & 3) << 4)) - 32;
auto bits = _mm_loadu_si128((const __m128i *)x8[k][i].qs + ib32);
xv[ib32] = _mm256_and_si256(MM256_SET_M128I(_mm_srli_epi16(bits, 4), bits), _mm256_set1_epi8(0xf));
xv[ib32] = _mm256_shuffle_epi8(values, xv[ib32]);
}
dnew[k] = d * convert_to_q8_k_r8(k, 1.f/127, xv, ls, block, y[i].qs);
}
_mm_storeu_si128((__m128i *)y[i].d, _mm256_cvtps_ph(_mm256_loadu_ps(dnew), _MM_ROUND_NEAREST));
}
y += nb;
}
}


} // namespace

@@ -2516,10 +2693,12 @@ bool iqk_set_kernels_kquants(int ne00, int typeA, int typeB, std::array<mul_mat_

bool iqk_convert_kquants_q8X_r8(int type, int n, const void * vx, size_t bx, void * vy, int nrc_x) {
switch (ggml_type(type)) {
case GGML_TYPE_Q2_K: iqk_convert_q2_k_q8_k_r8(n, vx, bx, vy, nrc_x); break;
case GGML_TYPE_Q3_K: iqk_convert_q3_k_q8_k_r8(n, vx, bx, vy, nrc_x); break;
case GGML_TYPE_Q4_K: iqk_convert_q4_k_q8_1_r8(n, vx, bx, vy, nrc_x); break;
case GGML_TYPE_Q5_K: iqk_convert_q5_k_q8_1_r8(n, vx, bx, vy, nrc_x); break;
case GGML_TYPE_Q6_K: iqk_convert_q6_k_q8_0_r8(n, vx, bx, vy, nrc_x); break;
case GGML_TYPE_IQ4_XS: iqk_convert_iq4_xs_q8_k_r8(n, vx, bx, vy, nrc_x); break;
default: return false;
}
return true;