Skip to content

Commit c8f98b7

Browse files
committed
IKL_pr_up_to_542+541
1 parent 2576f83 commit c8f98b7

File tree

5 files changed

+189
-66
lines changed

5 files changed

+189
-66
lines changed

ggml/src/iqk/iqk_gemm_iqk_quants.cpp

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3908,6 +3908,23 @@ void mul_mat_iq5_k_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& in
39083908

39093909
}
39103910

3911+
bool iqk_convert_iqk_quants_q80_r8([[maybe_unused]] int type, int n, [[maybe_unused]] const void * vx, [[maybe_unused]] size_t bx, [[maybe_unused]] void * vy, int nrc_x) {
3912+
if (n%QK_K != 0 || nrc_x%8 != 0) return false;
3913+
return false;
3914+
//switch (ggml_type(type)) {
3915+
// case GGML_TYPE_IQ2_KS : iqk_convert_iq2_ks_q8_k_r8(n, vx, bx, vy, nrc_x); break;
3916+
// case GGML_TYPE_IQ2_K : iqk_convert_iq2_k_q8_k_r8 (n, vx, bx, vy, nrc_x); break;
3917+
// case GGML_TYPE_IQ3_K : iqk_convert_iq3_k_q8_k_r8 (n, vx, bx, vy, nrc_x); break;
3918+
// case GGML_TYPE_IQ4_KS : iqk_convert_iq4_ks_q8_k_r8(n, vx, bx, vy, nrc_x); break;
3919+
// case GGML_TYPE_IQ4_K : iqk_convert_iq4_k_q8_k_r8 (n, vx, bx, vy, nrc_x); break;
3920+
// case GGML_TYPE_IQ5_KS : iqk_convert_iq5_ks_q8_k_r8(n, vx, bx, vy, nrc_x); break;
3921+
// case GGML_TYPE_IQ5_K : iqk_convert_iq5_k_q8_k_r8 (n, vx, bx, vy, nrc_x); break;
3922+
// case GGML_TYPE_IQ6_K : iqk_convert_iq6_k_q8_k_r8 (n, vx, bx, vy, nrc_x); break;
3923+
// default: return false;
3924+
//}
3925+
//return true;
3926+
}
3927+
39113928
bool iqk_set_kernels_iqk_quants(int ne00, int typeA, int typeB, std::array<mul_mat_t, IQK_MAX_NY>& kernels, [[maybe_unused]] mul_mat_t& func16) {
39123929

39133930
if (ne00%QK_K != 0 || ggml_type(typeB) != GGML_TYPE_Q8_K) {

ggml/src/iqk/iqk_gemm_iquants.cpp

Lines changed: 33 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1839,37 +1839,34 @@ static void mul_mat_iq3_s_r4_q8_k(int n, const void * vx, size_t bx, const DataI
18391839
}
18401840
}
18411841

1842-
inline float convert_to_q8_k_r8(int k, int d0, const __m256i * qx, const int16_t * scales, uint32_t * block, int8_t * q8_k) {
1842+
inline float convert_to_q8_k_r8(int k, float d0, const __m256i * qx, const int16_t * scales, uint32_t * block, int8_t * q8_k) {
18431843
auto max_i16 = _mm256_setzero_si256();
1844+
__m256i qs[16];
18441845
for (int ib32 = 0; ib32 < 8; ++ib32) {
1845-
auto q16_l = _mm256_cvtepi8_epi16(_mm256_castsi256_si128(qx[ib32]));
1846-
auto q16_h = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(qx[ib32], 1));
1847-
q16_l = _mm256_mullo_epi16(q16_l, _mm256_set1_epi16(scales[2*ib32+0]));
1848-
q16_h = _mm256_mullo_epi16(q16_h, _mm256_set1_epi16(scales[2*ib32+1]));
1849-
max_i16 = _mm256_max_epi16(max_i16, _mm256_sign_epi16(q16_l, q16_l));
1850-
max_i16 = _mm256_max_epi16(max_i16, _mm256_sign_epi16(q16_h, q16_h));
1846+
qs[2*ib32+0] = _mm256_cvtepi8_epi16(_mm256_castsi256_si128(qx[ib32]));
1847+
qs[2*ib32+1] = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(qx[ib32], 1));
1848+
qs[2*ib32+0] = _mm256_mullo_epi16(qs[2*ib32+0], _mm256_set1_epi16(scales[2*ib32+0]));
1849+
qs[2*ib32+1] = _mm256_mullo_epi16(qs[2*ib32+1], _mm256_set1_epi16(scales[2*ib32+1]));
1850+
max_i16 = _mm256_max_epi16(max_i16, _mm256_sign_epi16(qs[2*ib32+0], qs[2*ib32+0]));
1851+
max_i16 = _mm256_max_epi16(max_i16, _mm256_sign_epi16(qs[2*ib32+1], qs[2*ib32+1]));
18511852
}
18521853
auto max_q32 = _mm256_cvtepi16_epi32(_mm_max_epi16(_mm256_castsi256_si128(max_i16), _mm256_extracti128_si256(max_i16, 1)));
18531854
auto imax4 = _mm_max_epi32(_mm256_castsi256_si128(max_q32), _mm256_extracti128_si256(max_q32, 1));
18541855
auto max4 = _mm_cvtepi32_ps(imax4);
18551856
max4 = _mm_max_ps(max4, _mm_movehl_ps(max4, max4));
18561857
max4 = _mm_max_ss(max4, _mm_movehdup_ps(max4));
18571858
bool needs_scaling = true;
1858-
float dnew = _mm_cvtss_f32(max4) / d0;
1859+
float dnew = _mm_cvtss_f32(max4) * d0;
18591860
if (dnew < 1.f) {
18601861
dnew = 1.f; needs_scaling = false;
18611862
}
18621863
auto scale = _mm256_set1_ps(std::abs(dnew) > 1e-9f ? 1/dnew : 0.f);
18631864
for (int ib32 = 0; ib32 < 8; ++ib32) {
1864-
auto q16_l = _mm256_cvtepi8_epi16(_mm256_castsi256_si128(qx[ib32]));
1865-
auto q16_h = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(qx[ib32], 1));
1866-
q16_l = _mm256_mullo_epi16(q16_l, _mm256_set1_epi16(scales[2*ib32+0]));
1867-
q16_h = _mm256_mullo_epi16(q16_h, _mm256_set1_epi16(scales[2*ib32+1]));
18681865
if (needs_scaling) {
1869-
auto i0 = _mm256_cvtepi16_epi32(_mm256_castsi256_si128(q16_l));
1870-
auto i1 = _mm256_cvtepi16_epi32(_mm256_extracti128_si256(q16_l, 1));
1871-
auto i2 = _mm256_cvtepi16_epi32(_mm256_castsi256_si128(q16_h));
1872-
auto i3 = _mm256_cvtepi16_epi32(_mm256_extracti128_si256(q16_h, 1));
1866+
auto i0 = _mm256_cvtepi16_epi32(_mm256_castsi256_si128(qs[2*ib32+0]));
1867+
auto i1 = _mm256_cvtepi16_epi32(_mm256_extracti128_si256(qs[2*ib32+0], 1));
1868+
auto i2 = _mm256_cvtepi16_epi32(_mm256_castsi256_si128(qs[2*ib32+1]));
1869+
auto i3 = _mm256_cvtepi16_epi32(_mm256_extracti128_si256(qs[2*ib32+1], 1));
18731870
i0 = _mm256_cvtps_epi32(_mm256_round_ps(_mm256_mul_ps(scale, _mm256_cvtepi32_ps(i0)), _MM_ROUND_NEAREST));
18741871
i1 = _mm256_cvtps_epi32(_mm256_round_ps(_mm256_mul_ps(scale, _mm256_cvtepi32_ps(i1)), _MM_ROUND_NEAREST));
18751872
i2 = _mm256_cvtps_epi32(_mm256_round_ps(_mm256_mul_ps(scale, _mm256_cvtepi32_ps(i2)), _MM_ROUND_NEAREST));
@@ -1881,7 +1878,7 @@ inline float convert_to_q8_k_r8(int k, int d0, const __m256i * qx, const int16_t
18811878
_mm256_storeu_si256((__m256i *)block, i0);
18821879
} else {
18831880
// 0, 1, 2, 3, 4, 5, 6, 7, 8, 16, 17, 18, 19, 20, 21, 22, 23, 9, 10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31
1884-
auto i0 = _mm256_packs_epi16(q16_l, q16_h);
1881+
auto i0 = _mm256_packs_epi16(qs[2*ib32+0], qs[2*ib32+1]);
18851882
auto i0_l = _mm256_castsi256_si128(i0);
18861883
auto i0_h = _mm256_extracti128_si256(i0, 1);
18871884
_mm_storeu_si128((__m128i *)block+0, _mm_unpacklo_epi64(i0_l, i0_h));
@@ -1976,7 +1973,7 @@ void iqk_convert_iq2_xxs_q8_k_r8(int n, const void * vx, size_t bx, void * vy, i
19761973
values[ib32] = _mm256_set_epi64x(iq2xxs_grid[aux8[3]], iq2xxs_grid[aux8[2]], iq2xxs_grid[aux8[1]], iq2xxs_grid[aux8[0]]);
19771974
esh.sign_value(aux32[1], values[ib32]);
19781975
}
1979-
float dnew = convert_to_q8_k_r8(k, 124, values, ls, block, y[i].qs);
1976+
float dnew = convert_to_q8_k_r8(k, 1.f/124, values, ls, block, y[i].qs);
19801977
y[i].d[k] = GGML_FP32_TO_FP16(d*dnew);
19811978
}
19821979
}
@@ -2020,7 +2017,7 @@ void iqk_convert_iq2_xs_q8_k_r8(int n, const void * vx, size_t bx, void * vy, in
20202017
DequantizerIQ2XS::sign_values_helper(q2l, sign_helper, qx+0);
20212018
DequantizerIQ2XS::sign_values_helper(q2h, sign_helper, qx+4);
20222019
#endif
2023-
float dnew = convert_to_q8_k_r8(k, 124, qx, helper.val, block, y[i].qs);
2020+
float dnew = convert_to_q8_k_r8(k, 1.f/124, qx, helper.val, block, y[i].qs);
20242021
y[i].d[k] = GGML_FP32_TO_FP16(d*dnew);
20252022
}
20262023
}
@@ -2333,7 +2330,7 @@ void iqk_convert_iq2_s_q8_k_r8(int n, const void * vx, size_t bx, void * vy, int
23332330
helper.vec = DequantizerIQ2S::make_scales(x8[k][i].scales);
23342331
DequantizerIQ2S::prepare(x8[k][i].qs+ 0, x8[k][i].qh+0, (const uint16_t *)(x8[k][i].qs + QK_K/8) + 0, sh, qx+0);
23352332
DequantizerIQ2S::prepare(x8[k][i].qs+16, x8[k][i].qh+4, (const uint16_t *)(x8[k][i].qs + QK_K/8) + 8, sh, qx+4);
2336-
float dnew = convert_to_q8_k_r8(k, 124, qx, helper.val, block, y[i].qs);
2333+
float dnew = convert_to_q8_k_r8(k, 1.f/124, qx, helper.val, block, y[i].qs);
23372334
y[i].d[k] = GGML_FP32_TO_FP16(d*dnew);
23382335
}
23392336
}
@@ -2480,7 +2477,6 @@ void iqk_convert_iq3_xxs_q8_k_r8(int n, const void * vx, size_t bx, void * vy, i
24802477
for (int ix = 0; ix < nrc_x; ix += 8) {
24812478
for (int k = 0; k < 8; ++k) x8[k] = (const block_iq3_xxs *)((const char *)vx + (ix + k)*bx);
24822479
for (int i = 0; i < nb; ++i) {
2483-
// TODO: simdify
24842480
for (int k = 0; k < 8; ++k) {
24852481
float d = 0.25f * GGML_FP16_TO_FP32(x8[k][i].d);
24862482
auto qs = x8[k][i].qs;
@@ -2494,7 +2490,7 @@ void iqk_convert_iq3_xxs_q8_k_r8(int n, const void * vx, size_t bx, void * vy, i
24942490
esh.sign_value(aux32, values[ib32]);
24952491
qs += 8;
24962492
}
2497-
float dnew = convert_to_q8_k_r8(k, 124, values, ls, block, y[i].qs);
2493+
float dnew = convert_to_q8_k_r8(k, 1.f/124, values, ls, block, y[i].qs);
24982494
y[i].d[k] = GGML_FP32_TO_FP16(d*dnew);
24992495
}
25002496
}
@@ -2589,7 +2585,7 @@ void iqk_convert_iq3_s_q8_k_r8(int n, const void * vx, size_t bx, void * vy, int
25892585
ls[2*ib32 + 0] = (2*((x8[k][i].scales[ib32/2] >> 4*(ib32%2)) & 0xf) + 1);
25902586
ls[2*ib32 + 1] = ls[2*ib32 + 0];
25912587
}
2592-
float dnew = convert_to_q8_k_r8(k, 127, values, ls, block, y[i].qs);
2588+
float dnew = convert_to_q8_k_r8(k, 1.f/127, values, ls, block, y[i].qs);
25932589
y[i].d[k] = GGML_FP32_TO_FP16(d*dnew);
25942590
}
25952591
}
@@ -2666,7 +2662,6 @@ template <typename Dequantizer> void set_functions(std::array<mul_mat_t, IQK_MAX
26662662

26672663
bool iqk_set_kernels_iquants(int ne00, int typeA, int typeB, std::array<mul_mat_t, IQK_MAX_NY>& kernels, mul_mat_t& func16) {
26682664

2669-
// if (ne00%QK_K != 0 || ggml_type(typeB) != GGML_TYPE_Q8_K) {
26702665
if (ne00%QK_K != 0) return false;
26712666

26722667
//if (ggml_type(typeA) == GGML_TYPE_IQ2_XXS) {
@@ -3355,6 +3350,20 @@ static void mul_mat_iq3_s_r4_q8_k(int n, const void * vx, size_t bx, const DataI
33553350

33563351
}
33573352

3353+
bool iqk_convert_iquants_q80_r8([[maybe_unused]] int type, int n, [[maybe_unused]] const void * vx, [[maybe_unused]] size_t bx, [[maybe_unused]] void * vy, int nrc_x) {
3354+
if (n%QK_K != 0 || nrc_x%8 != 0) return false;
3355+
return false;
3356+
//switch (ggml_type(type)) {
3357+
// case GGML_TYPE_IQ2_XXS: iqk_convert_iq2_xxs_q8_k_r8(n, vx, bx, vy, nrc_x); break;
3358+
// case GGML_TYPE_IQ2_XS : iqk_convert_iq2_xs_q8_k_r8 (n, vx, bx, vy, nrc_x); break;
3359+
// case GGML_TYPE_IQ2_S : iqk_convert_iq2_s_q8_k_r8 (n, vx, bx, vy, nrc_x); break;
3360+
// case GGML_TYPE_IQ3_XXS: iqk_convert_iq3_xxs_q8_k_r8(n, vx, bx, vy, nrc_x); break;
3361+
// case GGML_TYPE_IQ3_S : iqk_convert_iq3_s_q8_k_r8 (n, vx, bx, vy, nrc_x); break;
3362+
// default: return false;
3363+
//}
3364+
//return true;
3365+
}
3366+
33583367
bool iqk_set_kernels_iquants(int ne00, int typeA, int typeB, std::array<mul_mat_t, IQK_MAX_NY>& kernels, mul_mat_t& func16) {
33593368

33603369
if (ne00%QK_K != 0 || ggml_type(typeB) != GGML_TYPE_Q8_K) {

ggml/src/iqk/iqk_gemm_kquants.cpp

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3704,6 +3704,20 @@ void mul_mat_q8_KV_r8_q8_KV(int n, const void * vx, size_t bx, const DataInfo& i
37043704

37053705
}
37063706

3707+
bool iqk_convert_kquants_q8X_r8([[maybe_unused]] int type, [[maybe_unused]] int n, [[maybe_unused]] const void * vx, [[maybe_unused]] size_t bx, [[maybe_unused]] void * vy, [[maybe_unused]] int nrc_x) {
3708+
return false;
3709+
//switch (ggml_type(type)) {
3710+
// case GGML_TYPE_Q2_K: iqk_convert_q2_k_q8_k_r8(n, vx, bx, vy, nrc_x); break;
3711+
// case GGML_TYPE_Q3_K: iqk_convert_q3_k_q8_k_r8(n, vx, bx, vy, nrc_x); break;
3712+
// case GGML_TYPE_Q4_K: iqk_convert_q4_k_q8_1_r8(n, vx, bx, vy, nrc_x); break;
3713+
// case GGML_TYPE_Q5_K: iqk_convert_q5_k_q8_1_r8(n, vx, bx, vy, nrc_x); break;
3714+
// case GGML_TYPE_Q6_K: iqk_convert_q6_k_q8_0_r8(n, vx, bx, vy, nrc_x); break;
3715+
// case GGML_TYPE_IQ4_XS: iqk_convert_iq4_xs_q8_k_r8(n, vx, bx, vy, nrc_x); break;
3716+
// default: return false;
3717+
//}
3718+
//return true;
3719+
}
3720+
37073721
bool iqk_set_kernels_kquants(int ne00, int typeA, int typeB, std::array<mul_mat_t, IQK_MAX_NY>& kernels, [[maybe_unused]] mul_mat_t& func16) {
37083722

37093723
auto etypeA = ggml_type(typeA);

0 commit comments

Comments
 (0)