Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions ggml/src/ggml.c
Original file line number Diff line number Diff line change
Expand Up @@ -1123,7 +1123,11 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
.from_float = quantize_row_iq3_xxs,
.from_float_ref = (ggml_from_float_t)quantize_row_iq3_xxs_ref,
.vec_dot = ggml_vec_dot_iq3_xxs_q8_K,
#ifdef __AVX2__
.vec_dot_type = GGML_TYPE_Q8_2_X4,
#else
.vec_dot_type = GGML_TYPE_Q8_K,
#endif
.nrows = 1,
.row_meta_size = 0,
},
Expand Down
72 changes: 71 additions & 1 deletion ggml/src/iqk/iqk_gemm_iquants.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -472,6 +472,15 @@ struct DequantizerIQ3XXS final : public BaseDequantizer<block_iq3_xxs> {
auto scales16 = prepare_scales(i);
scales[0] = MM256_SET_M128I(scales16, scales16);
}
// Produce the float scale vectors for super-block `i`.
//
// The eight 16-bit values returned by prepare_scales(i) are widened to
// 32-bit, converted to float and multiplied by `d`. The caller must supply
// room for three __m256 entries:
//   scales[0] - scaled values 0..3, duplicated into both 128-bit halves
//   scales[1] - scaled values 4..7, duplicated into both 128-bit halves
//   scales[2] - all eight scaled values multiplied by -minv
inline void new_block_f(int i, __m256 * scales) {
    // prepare_scales() runs before `d` is read; presumably it refreshes `d`
    // for block i, so keep this ordering -- TODO confirm against the helper.
    const auto packed16 = prepare_scales(i);
    const auto widened  = _mm256_cvtepi16_epi32(packed16);
    const auto scaled   = _mm256_mul_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(widened));

    const auto lower4 = _mm256_castps256_ps128(scaled);
    const auto upper4 = _mm256_extractf128_ps(scaled, 1);

    scales[0] = _mm256_set_m128(lower4, lower4);
    scales[1] = _mm256_set_m128(upper4, upper4);
    scales[2] = _mm256_mul_ps(scaled, _mm256_set1_ps(-minv));
}
inline float new_block(int i, __m256i * scales, __m256i& mins) {
auto scales16 = prepare_scales(i);
mins = scb.shuffle(scales16);
Expand Down Expand Up @@ -1771,6 +1780,58 @@ void iqk_convert_iq2_xxs_q8_0_r8(int n, const void * vx, size_t bx, void * vy, i
}
}

void iqk_convert_iq3_xxs_q8_0_r8(int n, const void * vx, size_t bx, void * vy, int nrc_x) {
GGML_ASSERT(n%QK_K == 0);
GGML_ASSERT(nrc_x%8 == 0);

int nb = n/QK_K;

const block_iq3_xxs * x8[8];

block_q8_0_r8 * y = (block_q8_0_r8 *)vy;

ggml_half dh[8];
uint16_t all_ls[64];
EvenSignHelper esh;

uint32_t block[8];
uint32_t aux32;

for (int ix = 0; ix < nrc_x; ix += 8) {
for (int k = 0; k < 8; ++k) x8[k] = (const block_iq3_xxs *)((const char *)vx + (ix + k)*bx);
for (int i = 0; i < nb; ++i) {
// TODO: simdify
for (int k = 0; k < 8; ++k) {
dh[k] = x8[k][i].d;
auto qs = x8[k][i].qs;
auto sas = qs + QK_K/4;
for (int ib32 = 0; ib32 < 8; ++ib32) {
std::memcpy(&aux32, sas + 4*ib32, sizeof(uint32_t));
all_ls[8*ib32 + k] = (2*(aux32 >> 28) + 1);
auto value = _mm256_set_epi32(iq3xxs_grid[qs[7]], iq3xxs_grid[qs[6]], iq3xxs_grid[qs[5]], iq3xxs_grid[qs[4]],
iq3xxs_grid[qs[3]], iq3xxs_grid[qs[2]], iq3xxs_grid[qs[1]], iq3xxs_grid[qs[0]]);
esh.sign_value(aux32, value);
_mm256_storeu_si256((__m256i *)block, value);
auto q8 = (uint32_t *)y[ib32].qs;
for (int l = 0; l < 4; ++l) {
q8[8*l + k + 0] = block[l + 0];
q8[8*l + k + 32] = block[l + 4];
}
qs += 8;
}
}
auto vd = _mm256_mul_ps(_mm256_set1_ps(0.25f), _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)dh)));
for (int ib32 = 0; ib32 < QK_K/32; ++ib32) {
auto iscales16 = _mm_loadu_si128((const __m128i *)all_ls + ib32);
auto iscales32 = _mm256_cvtepi16_epi32(iscales16);
auto scales = _mm256_mul_ps(vd, _mm256_cvtepi32_ps(iscales32));
_mm_storeu_si128((__m128i *)y[ib32].d, _mm256_cvtps_ph(scales, _MM_FROUND_TO_NEAREST_INT));
}
y += QK_K/32;
}
}
}

template <typename Dequantizer> void set_functions(std::array<mul_mat_t, IQK_MAX_NY>& funcs) {
funcs[0] = mul_mat_qX_K_q8_K_IQ<Dequantizer, 1>;
funcs[1] = mul_mat_qX_K_q8_K_IQ<Dequantizer, 2>;
Expand All @@ -1791,7 +1852,15 @@ bool iqk_set_kernels_iquants(int ne00, int typeA, int typeB, std::array<mul_mat_
if (ggml_type(typeA) == GGML_TYPE_IQ2_XXS) {
if (ggml_type(typeB) == GGML_TYPE_Q8_2_X4) {
IQK_SET_MUL_MAT_FUNCTIONS_T(mul_mat_qX_K_q8_2_IQ_N, DequantizerIQ2XXS, kernels);
//IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_iq2_xxs_q8_2_IQ_N, kernels);
func16 = nullptr;
return true;
}
return false;
}

if (ggml_type(typeA) == GGML_TYPE_IQ3_XXS) {
if (ggml_type(typeB) == GGML_TYPE_Q8_2_X4) {
IQK_SET_MUL_MAT_FUNCTIONS_T(mul_mat_qX_K_q8_2_IQ_N, DequantizerIQ3XXS, kernels);
func16 = nullptr;
return true;
}
Expand Down Expand Up @@ -1856,6 +1925,7 @@ bool iqk_convert_iquants_q80_r8(int type, int n, const void * vx, size_t bx, voi
if (n%QK_K != 0 || nrc_x%8 != 0) return false;
switch (ggml_type(type)) {
case GGML_TYPE_IQ2_XXS: iqk_convert_iq2_xxs_q8_0_r8(n, vx, bx, vy, nrc_x); break;
case GGML_TYPE_IQ3_XXS: iqk_convert_iq3_xxs_q8_0_r8(n, vx, bx, vy, nrc_x); break;
default: return false;
}
return true;
Expand Down
1 change: 1 addition & 0 deletions ggml/src/iqk/iqk_mul_mat.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -240,6 +240,7 @@ struct MulMat {
case GGML_TYPE_IQ3_KT: return nrc_y >= 32 ? GGML_TYPE_F32 : type;
case GGML_TYPE_IQ4_KT: return nrc_y >= 32 ? GGML_TYPE_F32 : type;
case GGML_TYPE_IQ2_XXS: return nrc_y >= 32 ? GGML_TYPE_Q8_0_R8 : type;
case GGML_TYPE_IQ3_XXS: return nrc_y >= 32 ? GGML_TYPE_Q8_0_R8 : type;
default: break;
}
#else
Expand Down