8 changes: 8 additions & 0 deletions ggml/src/ggml.c
@@ -976,7 +976,11 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
.from_float = quantize_row_q4_K,
.from_float_ref = (ggml_from_float_t) quantize_row_q4_K_ref,
.vec_dot = ggml_vec_dot_q4_K_q8_K,
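// On AVX2 the iqk GEMM path quantizes the activations to Q8_2_X4 for Q4_K (and Q5_K below) instead of Q8_K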
#ifdef __AVX2__
.vec_dot_type = GGML_TYPE_Q8_2_X4,
#else
.vec_dot_type = GGML_TYPE_Q8_K,
#endif
.nrows = 1,
.row_meta_size = 0,
},
@@ -1002,7 +1006,11 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
.from_float = quantize_row_q5_K,
.from_float_ref = (ggml_from_float_t) quantize_row_q5_K_ref,
.vec_dot = ggml_vec_dot_q5_K_q8_K,
#ifdef __AVX2__
.vec_dot_type = GGML_TYPE_Q8_2_X4,
#else
.vec_dot_type = GGML_TYPE_Q8_K,
#endif
.nrows = 1,
.row_meta_size = 0,
},
297 changes: 295 additions & 2 deletions ggml/src/iqk/iqk_gemm_kquants.cpp
@@ -719,6 +719,147 @@ static void mul_mat_qY_K_q8_K_T(int n, const void * vx, size_t bx, const DataInfo

#endif

// inline __m256i process_mins_and_scales(const uint8_t * data, float c, int i, const Q8& q8, __m256 * accd) {
// make_q4_scales(data, utmp);
// const __m256i mins_and_scales = _mm256_cvtepu8_epi16(_mm_set_epi32(utmp[3], utmp[2], utmp[1], utmp[0]));
// const __m128i mins128 = _mm256_extracti128_si256(mins_and_scales, 1);
// accum_mins(mins128, q8, i, c, accd);
// const __m128i sc128 = _mm256_extracti128_si256(mins_and_scales, 0);
// return MM256_SET_M128I(sc128, sc128);
// }
//
// inline void new_block(int i, const Q8& q8, __m256 * accd, __m512i * scales) {
// d = GGML_FP16_TO_FP32(x[i].d);
// bits.prepare(x[i].qs);
// auto all_scales = s8k.process_mins_and_scales_64(x[i].scales, -GGML_FP16_TO_FP32(x[i].dmin), i, q8, accd);
// scales[0] = _mm512_shuffle_epi8(all_scales, s8k.shuffles512[0]);
// scales[1] = _mm512_shuffle_epi8(all_scales, s8k.shuffles512[1]);
// }


struct Q4Bits_AVX2 {
inline void prepare(const uint8_t * q4, int j) {
auto q4bits = _mm256_loadu_si256((const __m256i*)q4 + 2*j+0);
values[0] = _mm256_and_si256(q4bits, ml);
values[1] = _mm256_and_si256(_mm256_srli_epi16(q4bits, 4), ml);
q4bits = _mm256_loadu_si256((const __m256i*)q4 + 2*j+1);
values[2] = _mm256_and_si256(q4bits, ml);
values[3] = _mm256_and_si256(_mm256_srli_epi16(q4bits, 4), ml);
}
__m256i values[4];
const __m256i ml = _mm256_set1_epi8(0xf);
};

struct DequantizerQ4K_AVX2 final : public BaseDequantizer<block_q4_K> {
DequantizerQ4K_AVX2(const void * vx, size_t bx) : BaseDequantizer(vx, bx) {}
inline void prepare(int i, int j) {
bits.prepare(x[i].qs, j);
}
Q4Bits_AVX2 bits;
};

struct DequantizerQ5K_AVX2 final : public BaseDequantizer<block_q5_K> {
DequantizerQ5K_AVX2(const void * vx, size_t bx) : BaseDequantizer(vx, bx) {}
inline void prepare(int i, int j) {
bits.prepare(x[i].qs, j);
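// j == 0: load the high (5th) bits for the first 128 quants; j == 1: shift right by 4 so bits 4..7 line up instead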
hbits = j == 0 ? _mm256_loadu_si256((const __m256i *)x[i].qh) : _mm256_srli_epi16(hbits, 4);
apply_hbits();
}
inline void apply_hbits() {
bits.values[0] = _mm256_or_si256(bits.values[0], _mm256_and_si256(_mm256_slli_epi16(hbits, 4), mh));
bits.values[1] = _mm256_or_si256(bits.values[1], _mm256_and_si256(_mm256_slli_epi16(hbits, 3), mh));
bits.values[2] = _mm256_or_si256(bits.values[2], _mm256_and_si256(_mm256_slli_epi16(hbits, 2), mh));
bits.values[3] = _mm256_or_si256(bits.values[3], _mm256_and_si256(_mm256_slli_epi16(hbits, 1), mh));
}

const __m256i mh = _mm256_set1_epi8(0x10);
Q4Bits_AVX2 bits;
__m256i hbits;
};

template <typename Dequantizer, int nrc_y>
static void mul_mat_qX_K_q8_2_X4_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
assert(n % QK_K == 0);
const int nb = n / QK_K;

Q8<nrc_y, block_q8_2_x4> q8(info);

Dequantizer deq(vx, bx);

uint32_t utmp[4];
__m256 accd[nrc_y];
__m256 scales[2];
float d8[8*nrc_y];

for (int ix = 0; ix < nrc_x; ++ix) {

for (int iy = 0; iy < nrc_y; ++iy) accd[iy] = _mm256_setzero_ps();

deq.new_row(ix);

for (int i = 0; i < nb; ++i) {

deq.d = GGML_FP16_TO_FP32(deq.x[i].d);
auto vm = _mm256_cvtph_ps(_mm_set1_epi16(deq.x[i].dmin));
make_q4_scales(deq.x[i].scales, utmp);
auto mins = _mm256_mul_ps(vm, _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i *)(utmp + 2)))));
mins = _mm256_mul_ps(_mm256_set1_ps(-1.f), mins);
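// mins = -dmin * the 6-bit block mins; multiplied below by the per-block activation sums (my) this applies the K-quant min correction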
for (int iy = 0; iy < nrc_y; ++iy) {
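// block_q8_2_x4 stores its per-32-block scales (d[0..3]) and scale*sum(q) terms (d[4..7]) as bf16: shifting left by 16 rebuilds the fp32 values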
auto d4_1 = _mm_cvtepu16_epi32(_mm_loadl_epi64((const __m128i *)(q8.y[iy][2*i+0].d)));
auto d4_2 = _mm_cvtepu16_epi32(_mm_loadl_epi64((const __m128i *)(q8.y[iy][2*i+1].d)));
auto dy = _mm256_castsi256_ps(_mm256_slli_epi32(MM256_SET_M128I(d4_2, d4_1), 16));
_mm256_storeu_ps(d8 + 8*iy, dy);
auto m4_1 = _mm_cvtepu16_epi32(_mm_loadl_epi64((const __m128i *)(q8.y[iy][2*i+0].d+4)));
auto m4_2 = _mm_cvtepu16_epi32(_mm_loadl_epi64((const __m128i *)(q8.y[iy][2*i+1].d+4)));
auto my = _mm256_castsi256_ps(_mm256_slli_epi32(MM256_SET_M128I(m4_2, m4_1), 16));
accd[iy] = _mm256_fmadd_ps(my, mins, accd[iy]);
}

auto all_scales = _mm256_mul_ps(_mm256_set1_ps(deq.d), _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i *)utmp))));
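// duplicate the low/high four sub-block scales into both 128-bit halves; scales[0] is used for j == 0, scales[1] for j == 1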
scales[0] = _mm256_set_m128(_mm256_castps256_ps128(all_scales), _mm256_castps256_ps128(all_scales));
auto scales_h = _mm256_extractf128_ps(all_scales, 1);
scales[1] = _mm256_set_m128(scales_h, scales_h);

for (int j = 0; j < QK_K/128; ++j) {

deq.prepare(i, j);

for (int iy = 0; iy < nrc_y; ++iy) {
const block_q8_2_x4& y = q8.y[iy][2*i+j];
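// dot the four 32-quant sub-blocks against the activations, then transpose-add so each 32-bit lane of sumi1 holds one sub-block's partial sum in the order matching scales[j] and dy4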
#ifdef HAVE_FANCY_SIMD
auto sumi1 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), deq.bits.values[0], _mm256_loadu_si256((const __m256i*)y.qs+0));
auto sumi2 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), deq.bits.values[1], _mm256_loadu_si256((const __m256i*)y.qs+1));
auto sumi3 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), deq.bits.values[2], _mm256_loadu_si256((const __m256i*)y.qs+2));
auto sumi4 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), deq.bits.values[3], _mm256_loadu_si256((const __m256i*)y.qs+3));
sumi1 = _mm256_add_epi32(_mm256_unpacklo_epi32(sumi1, sumi2), _mm256_unpackhi_epi32(sumi1, sumi2));
sumi3 = _mm256_add_epi32(_mm256_unpacklo_epi32(sumi3, sumi4), _mm256_unpackhi_epi32(sumi3, sumi4));
sumi1 = _mm256_add_epi32(_mm256_unpacklo_epi64(sumi1, sumi3), _mm256_unpackhi_epi64(sumi1, sumi3));
#else
auto sumi1 = _mm256_maddubs_epi16(deq.bits.values[0], _mm256_loadu_si256((const __m256i*)y.qs+0));
auto sumi2 = _mm256_maddubs_epi16(deq.bits.values[1], _mm256_loadu_si256((const __m256i*)y.qs+1));
auto sumi3 = _mm256_maddubs_epi16(deq.bits.values[2], _mm256_loadu_si256((const __m256i*)y.qs+2));
auto sumi4 = _mm256_maddubs_epi16(deq.bits.values[3], _mm256_loadu_si256((const __m256i*)y.qs+3));
sumi1 = _mm256_add_epi16(_mm256_unpacklo_epi32(sumi1, sumi2), _mm256_unpackhi_epi32(sumi1, sumi2));
sumi3 = _mm256_add_epi16(_mm256_unpacklo_epi32(sumi3, sumi4), _mm256_unpackhi_epi32(sumi3, sumi4));
sumi1 = _mm256_add_epi16(_mm256_unpacklo_epi64(sumi1, sumi3), _mm256_unpackhi_epi64(sumi1, sumi3));
sumi1 = _mm256_madd_epi16(_mm256_set1_epi16(1), sumi1);
#endif
auto dy4 = _mm_loadu_ps(d8 + 8*iy + 4*j);
auto d4d8 = _mm256_mul_ps(scales[j], _mm256_set_m128(dy4, dy4));
accd[iy] = _mm256_fmadd_ps(d4d8, _mm256_cvtepi32_ps(sumi1), accd[iy]);
}

}

}

for (int iy = 0; iy < nrc_y; ++iy) {
info.store(ix, iy, hsum_float_8(accd[iy]));
}

}
}

template <int nrc_y>
static void mul_mat_iq4_xs_r8_q8_k_avx2(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
GGML_ASSERT(nrc_x%8 == 0);
@@ -1702,6 +1843,146 @@ static void mul_mat_q8_KV_r8_q8_KV(int n, const void * vx, size_t bx, const Data
}
}

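// Interleaved 8-row repacking target for the Q4_K/Q5_K conversions below:
// d[0..7] = per-row block scales, d[8..15] = per-row -dmin*min terms, qs = 8 x 32 quants interleaved in 4-byte groups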
typedef struct {
ggml_half d[16];
int8_t qs[8*QK8_1];
} block_q8_1_r8;

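// Repack 8 rows of Q4_K at a time into block_q8_1_r8: the 4-bit quants are stored as-is and each row's (d*scale, -dmin*min) pair is folded into d[]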
void iqk_convert_q4_k_q8_1_r8(int n, const void * vx, size_t bx, void * vy, int nrc_x) {
GGML_ASSERT(n%QK_K == 0);
GGML_ASSERT(nrc_x%8 == 0);

int nb = n/QK_K;

const block_q4_K * x8[8];

block_q8_1_r8 * y = (block_q8_1_r8 *)vy;

ggml_half dh[16];
uint16_t all_ls[128];

uint32_t utmp[4];
const uint8_t * u8 = (const uint8_t *)utmp;
uint32_t block[8];

for (int ix = 0; ix < nrc_x; ix += 8) {
for (int k = 0; k < 8; ++k) x8[k] = (const block_q4_K *)((const char *)vx + (ix + k)*bx);
for (int i = 0; i < nb; ++i) {
for (int k = 0; k < 8; ++k) {
dh[k+0] = x8[k][i].d;
dh[k+8] = x8[k][i].dmin;
make_q4_scales(x8[k][i].scales, utmp);
auto qs = x8[k][i].qs;
for (int ib64 = 0; ib64 < 4; ++ib64) {
all_ls[8*(2*ib64 + 0) + k ] = u8[2*ib64+0];
all_ls[8*(2*ib64 + 1) + k ] = u8[2*ib64+1];
all_ls[8*(2*ib64 + 0) + k + 64] = u8[2*ib64+8];
all_ls[8*(2*ib64 + 1) + k + 64] = u8[2*ib64+9];
auto bits = _mm256_loadu_si256((const __m256i *)qs+ib64);
auto values1 = _mm256_and_si256(bits, _mm256_set1_epi8(0xf));
auto values2 = _mm256_and_si256(_mm256_srli_epi16(bits, 4), _mm256_set1_epi8(0xf));
_mm256_storeu_si256((__m256i *)block, values1);
auto q8 = (uint32_t *)y[2*ib64+0].qs;
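// scatter row k into the interleaved block: the first 16 quants of the row land in the first 128 bytes of qs, the last 16 in the second, as 4-byte groups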
for (int l = 0; l < 4; ++l) {
q8[8*l + k + 0] = block[l + 0];
q8[8*l + k + 32] = block[l + 4];
}
_mm256_storeu_si256((__m256i *)block, values2);
q8 = (uint32_t *)y[2*ib64+1].qs;
for (int l = 0; l < 4; ++l) {
q8[8*l + k + 0] = block[l + 0];
q8[8*l + k + 32] = block[l + 4];
}
}
}
auto vd = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)dh+0));
auto vm = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)dh+1));
vm = _mm256_mul_ps(_mm256_set1_ps(-1.f), vm);
for (int ib32 = 0; ib32 < QK_K/32; ++ib32) {
auto iscales16 = _mm_loadu_si128((const __m128i *)all_ls + ib32);
auto iscales32 = _mm256_cvtepi16_epi32(iscales16);
auto scales = _mm256_mul_ps(vd, _mm256_cvtepi32_ps(iscales32));
_mm_storeu_si128((__m128i *)y[ib32].d+0, _mm256_cvtps_ph(scales, _MM_FROUND_TO_NEAREST_INT));
iscales16 = _mm_loadu_si128((const __m128i *)all_ls + ib32 + 8);
iscales32 = _mm256_cvtepi16_epi32(iscales16);
scales = _mm256_mul_ps(vm, _mm256_cvtepi32_ps(iscales32));
_mm_storeu_si128((__m128i *)y[ib32].d+1, _mm256_cvtps_ph(scales, _MM_FROUND_TO_NEAREST_INT));
}
y += QK_K/32;
}
}
}

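// Same repacking for Q5_K: the high bits are merged in first, so 5-bit quants end up in the q8_1_r8 layout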
void iqk_convert_q5_k_q8_1_r8(int n, const void * vx, size_t bx, void * vy, int nrc_x) {
GGML_ASSERT(n%QK_K == 0);
GGML_ASSERT(nrc_x%8 == 0);

int nb = n/QK_K;

const block_q5_K * x8[8];

block_q8_1_r8 * y = (block_q8_1_r8 *)vy;

ggml_half dh[16];
uint16_t all_ls[128];

uint32_t utmp[4];
const uint8_t * u8 = (const uint8_t *)utmp;
uint32_t block[8];

for (int ix = 0; ix < nrc_x; ix += 8) {
for (int k = 0; k < 8; ++k) x8[k] = (const block_q5_K *)((const char *)vx + (ix + k)*bx);
for (int i = 0; i < nb; ++i) {
for (int k = 0; k < 8; ++k) {
dh[k+0] = x8[k][i].d;
dh[k+8] = x8[k][i].dmin;
make_q4_scales(x8[k][i].scales, utmp);
auto qs = x8[k][i].qs;
auto hbits = _mm256_loadu_si256((const __m256i *)x8[k][i].qh);
for (int ib64 = 0; ib64 < 4; ++ib64) {
all_ls[8*(2*ib64 + 0) + k ] = u8[2*ib64+0];
all_ls[8*(2*ib64 + 1) + k ] = u8[2*ib64+1];
all_ls[8*(2*ib64 + 0) + k + 64] = u8[2*ib64+8];
all_ls[8*(2*ib64 + 1) + k + 64] = u8[2*ib64+9];
auto bits = _mm256_loadu_si256((const __m256i *)qs+ib64);
auto values1 = _mm256_and_si256(bits, _mm256_set1_epi8(0xf));
auto values2 = _mm256_and_si256(_mm256_srli_epi16(bits, 4), _mm256_set1_epi8(0xf));
values1 = _mm256_or_si256(values1, _mm256_and_si256(_mm256_set1_epi8(0x10), _mm256_slli_epi16(hbits, 4)));
values2 = _mm256_or_si256(values2, _mm256_and_si256(_mm256_set1_epi8(0x10), _mm256_slli_epi16(hbits, 3)));
hbits = _mm256_srli_epi16(hbits, 2);
_mm256_storeu_si256((__m256i *)block, values1);
auto q8 = (uint32_t *)y[2*ib64+0].qs;
for (int l = 0; l < 4; ++l) {
q8[8*l + k + 0] = block[l + 0];
q8[8*l + k + 32] = block[l + 4];
}
_mm256_storeu_si256((__m256i *)block, values2);
q8 = (uint32_t *)y[2*ib64+1].qs;
for (int l = 0; l < 4; ++l) {
q8[8*l + k + 0] = block[l + 0];
q8[8*l + k + 32] = block[l + 4];
}
}
}
auto vd = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)dh+0));
auto vm = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)dh+1));
vm = _mm256_mul_ps(_mm256_set1_ps(-1.f), vm);
for (int ib32 = 0; ib32 < QK_K/32; ++ib32) {
auto iscales16 = _mm_loadu_si128((const __m128i *)all_ls + ib32);
auto iscales32 = _mm256_cvtepi16_epi32(iscales16);
auto scales = _mm256_mul_ps(vd, _mm256_cvtepi32_ps(iscales32));
_mm_storeu_si128((__m128i *)y[ib32].d+0, _mm256_cvtps_ph(scales, _MM_FROUND_TO_NEAREST_INT));
iscales16 = _mm_loadu_si128((const __m128i *)all_ls + ib32 + 8);
iscales32 = _mm256_cvtepi16_epi32(iscales16);
scales = _mm256_mul_ps(vm, _mm256_cvtepi32_ps(iscales32));
_mm_storeu_si128((__m128i *)y[ib32].d+1, _mm256_cvtps_ph(scales, _MM_FROUND_TO_NEAREST_INT));
}
y += QK_K/32;
}
}
}


} // namespace

bool iqk_set_kernels_kquants(int ne00, int typeA, int typeB, std::array<mul_mat_t, IQK_MAX_NY>& kernels, mul_mat_t& func16) {
@@ -1710,6 +1991,7 @@ bool iqk_set_kernels_kquants(int ne00, int typeA, int typeB, std::array<mul_mat_
auto expected_type_B = etypeA == GGML_TYPE_IQ4_XS_R8 || etypeA == GGML_TYPE_Q4_K_R4 || etypeA == GGML_TYPE_Q5_K_R4 ? GGML_TYPE_Q8_K32
: etypeA == GGML_TYPE_Q8_K_R8 ? GGML_TYPE_Q8_KR8
: etypeA == GGML_TYPE_Q8_KV || etypeA == GGML_TYPE_Q8_KV_R8 ? GGML_TYPE_Q8_KV
: etypeA == GGML_TYPE_Q4_K || etypeA == GGML_TYPE_Q5_K ? GGML_TYPE_Q8_2_X4
: GGML_TYPE_Q8_K;

if (ne00%QK_K != 0 || ggml_type(typeB) != expected_type_B) {
@@ -1726,10 +2008,12 @@ bool iqk_set_kernels_kquants(int ne00, int typeA, int typeB, std::array<mul_mat_
set_functions<DequantizerQ3K>(kernels);
break;
case GGML_TYPE_Q4_K:
set_functions<DequantizerQ4K>(kernels);
IQK_SET_MUL_MAT_FUNCTIONS_T(mul_mat_qX_K_q8_2_X4_T, DequantizerQ4K_AVX2, kernels);
//set_functions<DequantizerQ4K>(kernels);
break;
case GGML_TYPE_Q5_K:
set_functions<DequantizerQ5K>(kernels);
IQK_SET_MUL_MAT_FUNCTIONS_T(mul_mat_qX_K_q8_2_X4_T, DequantizerQ5K_AVX2, kernels);
//set_functions<DequantizerQ5K>(kernels);
break;
case GGML_TYPE_Q6_K:
set_functions<DequantizerQ6K>(kernels);
@@ -1778,6 +2062,15 @@ bool iqk_set_kernels_kquants(int ne00, int typeA, int typeB, std::array<mul_mat_

}

bool iqk_convert_kquants_q8X_r8(int type, int n, const void * vx, size_t bx, void * vy, int nrc_x) {
switch (ggml_type(type)) {
case GGML_TYPE_Q4_K: iqk_convert_q4_k_q8_1_r8(n, vx, bx, vy, nrc_x); break;
case GGML_TYPE_Q5_K: iqk_convert_q5_k_q8_1_r8(n, vx, bx, vy, nrc_x); break;
default: return false;
}
return true;
}

#else
// --------------------------------- __aarch64__ --------------------------------------

2 changes: 2 additions & 0 deletions ggml/src/iqk/iqk_gemm_kquants.h
@@ -10,4 +10,6 @@ bool iqk_set_kernels_kquants(int ne00, int typeA, int typeB, std::array<mul_mat_

void iqk_gemm_q8kv_fa(int D, int nq, int type_k, const char * k, size_t stride_k, DataInfo& info, int k_step);

bool iqk_convert_kquants_q8X_r8(int type, int n, const void * vx, size_t bx, void * vy, int nrc_x);

#endif