
Commit ccdecbb

Iwan Kawrakow (ikawrakow) authored and committed
Much faster CPU prompt processing (part 3) (ikawrakow#534)
* Repack q4_0 and q8_0 to q8_0_R8

  q8_0 is fine, but I observe a very significant PPL increase for q4_0.
  Best guess: precision loss with the 32 bit <-> 16 bit scale conversions.

* Change q8_2_x4 to store int16_t sums

  With that q4_0 now works. I need to check all quants that use q8_2_x4!

* q5_0, and use a dequantizing template

* q6_0: 129 t/s -> 296 t/s. q6_0_r4 is at 244 t/s.

* iq4_nl: 137 t/s -> 293 t/s. iq4_nl_r4 is at 251 t/s.

* q4_1: 135 t/s -> 262 t/s

* q5_1: 125 t/s -> 253 t/s

* iq4_xs: 178 t/s -> 363 t/s. iq4_xs_r4 is at 275 t/s.

* q2_K: 202 t/s -> 364 t/s. q2_k_r4 is at 247 t/s.

---------

Co-authored-by: Iwan Kawrakow <[email protected]>
1 parent 62ae359 commit ccdecbb
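
Note on the first two bullets: the q4_0 regression is blamed on the 32 bit <-> 16 bit conversions, and the fix is to have q8_2_x4 carry exact int16_t quant sums instead of 16-bit floats. A minimal standalone sketch (not code from this commit; the bf16-style truncation is an assumption) of why a sum of 32 int8 quants survives an int16_t but not a 16-bit float:

#include <cstdint>
#include <cstdio>
#include <cstring>

// Keep only the upper 16 bits of a float, i.e. a bf16-style round trip (assumption:
// this mirrors the 32 bit <-> 16 bit conversion mentioned in the commit message).
static float round_trip_16bit(float x) {
    uint32_t u;
    std::memcpy(&u, &x, sizeof u);
    u &= 0xffff0000u;              // only an 8-bit mantissa survives
    float y;
    std::memcpy(&y, &u, sizeof y);
    return y;
}

int main() {
    // Sum of the 32 int8 quants of one sub-block; |sum| <= 32*127 = 4064 always fits an int16_t.
    int8_t q[32]; for (auto & v : q) v = 125; q[0] = 126;
    int sum = 0; for (auto v : q) sum += v;                                        // 4001
    std::printf("exact sum            : %d\n", sum);
    std::printf("stored as int16_t    : %d\n", (int)(int16_t)sum);                 // 4001, exact
    std::printf("through 16-bit float : %g\n", round_trip_16bit((float)sum));      // 4000, rounded
    return 0;
}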

File tree

5 files changed: 441 additions, 42 deletions


ggml/src/iqk/iqk_gemm_kquants.cpp

Lines changed: 183 additions & 4 deletions
@@ -810,10 +810,11 @@ static void mul_mat_qX_K_q8_2_X4_T(int n, const void * vx, size_t bx, const Data
             auto d4_2 = _mm_cvtepu16_epi32(_mm_loadl_epi64((const __m128i *)(q8.y[iy][2*i+1].d)));
             auto dy = _mm256_castsi256_ps(_mm256_slli_epi32(MM256_SET_M128I(d4_2, d4_1), 16));
             _mm256_storeu_ps(d8 + 8*iy, dy);
-            auto m4_1 = _mm_cvtepu16_epi32(_mm_loadl_epi64((const __m128i *)(q8.y[iy][2*i+0].d+4)));
-            auto m4_2 = _mm_cvtepu16_epi32(_mm_loadl_epi64((const __m128i *)(q8.y[iy][2*i+1].d+4)));
-            auto my = _mm256_castsi256_ps(_mm256_slli_epi32(MM256_SET_M128I(m4_2, m4_1), 16));
-            accd[iy] = _mm256_fmadd_ps(my, mins, accd[iy]);
+            auto m4_1 = _mm_cvtepi16_epi32(_mm_loadl_epi64((const __m128i *)(q8.y[iy][2*i+0].d+4)));
+            auto m4_2 = _mm_cvtepi16_epi32(_mm_loadl_epi64((const __m128i *)(q8.y[iy][2*i+1].d+4)));
+            auto myi = MM256_SET_M128I(m4_2, m4_1);
+            auto my = _mm256_mul_ps(dy, _mm256_cvtepi32_ps(myi));
+            accd[iy] = _mm256_fmadd_ps(my, mins, accd[iy]);
         }
 
         auto all_scales = _mm256_mul_ps(_mm256_set1_ps(deq.d), _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i *)utmp))));
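
What this hunk changes: the four 16-bit values at d+4 used to be widened into the high half of a float (read as bf16 numbers, exactly like the scales dy two lines above); since q8_2_x4 now stores exact int16_t quant sums there, they are sign-extended with _mm_cvtepi16_epi32 and multiplied by dy before being folded into the accumulator together with the mins. A scalar sketch of the new path, under the assumption that each q8_2_x4 sub-block keeps its 16-bit scale in d[0..3] and its int16_t quant sum in d[4..7]:

#include <cstdint>
#include <cstring>

// Per sub-block j of a q8_2_x4 block: my = dy * sum, later accumulated as accd += my * mins.
static inline float mins_term(const uint16_t * d, int j, float dy) {
    int16_t sum;                                   // exact sum of the 32 int8 quants (new format)
    std::memcpy(&sum, d + 4 + j, sizeof sum);      // the old code reinterpreted these bits as bf16
    return dy * (float)sum;
}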
@@ -2017,6 +2018,91 @@ typedef struct {
     int8_t qs[8*QK8_1];
 } block_q8_1_r8;
 
+void iqk_convert_q2_k_q8_k_r8(int n, const void * vx, size_t bx, void * vy, int nrc_x) {
+    GGML_ASSERT(n%QK_K == 0);
+    GGML_ASSERT(nrc_x%8 == 0);
+
+    int nb = n/QK_K;
+
+    const block_q2_K * x8[8];
+
+    block_q8_k_r8 * y = (block_q8_k_r8 *)vy;
+
+    float f_values[QK_K];
+    uint32_t block[8];
+
+    __m256i xv[4];
+
+    auto ml = _mm256_set1_epi8(0x03);
+    auto sign_bit = _mm256_set1_ps(-0.0f);
+    auto perm = _mm256_setr_epi32(0, 4, 1, 5, 2, 6, 3, 7);
+
+    for (int ix = 0; ix < nrc_x; ix += 8) {
+        for (int k = 0; k < 8; ++k) x8[k] = (const block_q2_K *)((const char *)vx + (ix + k)*bx);
+        for (int i = 0; i < nb; ++i) {
+            for (int k = 0; k < 8; ++k) {
+                auto vd = _mm256_set1_ps(GGML_FP16_TO_FP32(x8[k][i].d));
+                auto vm = _mm256_mul_ps(_mm256_set1_ps(GGML_FP16_TO_FP32(x8[k][i].dmin)), _mm256_set1_ps(-1.f));
+                auto block_max = _mm256_setzero_ps();
+                for (int i128 = 0; i128 < 2; ++i128) {
+                    auto bits = _mm256_loadu_si256((const __m256i *)x8[k][i].qs+i128);
+                    xv[0] = _mm256_and_si256(bits, ml);
+                    xv[1] = _mm256_and_si256(_mm256_srli_epi16(bits, 2), ml);
+                    xv[2] = _mm256_and_si256(_mm256_srli_epi16(bits, 4), ml);
+                    xv[3] = _mm256_and_si256(_mm256_srli_epi16(bits, 6), ml);
+                    for (int l = 0; l < 4; ++l) {
+                        auto q1 = _mm256_cvtepi8_epi16(_mm256_castsi256_si128(xv[l]));
+                        auto q2 = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(xv[l], 1));
+                        q1 = _mm256_mullo_epi16(q1, _mm256_set1_epi16(x8[k][i].scales[8*i128 + 2*l + 0] & 0xf));
+                        q2 = _mm256_mullo_epi16(q2, _mm256_set1_epi16(x8[k][i].scales[8*i128 + 2*l + 1] & 0xf));
+                        auto m1 = _mm256_mul_ps(vm, _mm256_set1_ps(x8[k][i].scales[8*i128 + 2*l + 0] >> 4));
+                        auto m2 = _mm256_mul_ps(vm, _mm256_set1_ps(x8[k][i].scales[8*i128 + 2*l + 1] >> 4));
+                        auto v0 = _mm256_fmadd_ps(_mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_castsi256_si128(q1))), vd, m1);
+                        auto v1 = _mm256_fmadd_ps(_mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_extracti128_si256(q1, 1))), vd, m1);
+                        auto v2 = _mm256_fmadd_ps(_mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_castsi256_si128(q2))), vd, m2);
+                        auto v3 = _mm256_fmadd_ps(_mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_extracti128_si256(q2, 1))), vd, m2);
+                        auto max = _mm256_max_ps(_mm256_max_ps(_mm256_andnot_ps(sign_bit, v0), _mm256_andnot_ps(sign_bit, v1)),
+                                                 _mm256_max_ps(_mm256_andnot_ps(sign_bit, v2), _mm256_andnot_ps(sign_bit, v3)));
+                        block_max = _mm256_max_ps(block_max, max);
+                        _mm256_storeu_ps(f_values + 128*i128 + 32*l + 0, v0);
+                        _mm256_storeu_ps(f_values + 128*i128 + 32*l + 8, v1);
+                        _mm256_storeu_ps(f_values + 128*i128 + 32*l + 16, v2);
+                        _mm256_storeu_ps(f_values + 128*i128 + 32*l + 24, v3);
+                    }
+                }
+                auto max4 = _mm_max_ps(_mm256_extractf128_ps(block_max, 1), _mm256_castps256_ps128(block_max));
+                max4 = _mm_max_ps(max4, _mm_movehl_ps(max4, max4));
+                max4 = _mm_max_ss(max4, _mm_movehdup_ps(max4));
+                float d = _mm_cvtss_f32(max4/127.f);
+                auto id = _mm256_set1_ps(d != 0.0f ? 1/d : 0.0f);
+                y[i].d[k] = GGML_FP32_TO_FP16(d);
+                for (int ib32 = 0; ib32 < 8; ++ib32) {
+                    auto v0 = _mm256_loadu_ps(f_values + 32*ib32 + 0);
+                    auto v1 = _mm256_loadu_ps(f_values + 32*ib32 + 8);
+                    auto v2 = _mm256_loadu_ps(f_values + 32*ib32 + 16);
+                    auto v3 = _mm256_loadu_ps(f_values + 32*ib32 + 24);
+                    auto i0 = _mm256_cvtps_epi32(_mm256_round_ps(_mm256_mul_ps(v0, id), _MM_ROUND_NEAREST));
+                    auto i1 = _mm256_cvtps_epi32(_mm256_round_ps(_mm256_mul_ps(v1, id), _MM_ROUND_NEAREST));
+                    auto i2 = _mm256_cvtps_epi32(_mm256_round_ps(_mm256_mul_ps(v2, id), _MM_ROUND_NEAREST));
+                    auto i3 = _mm256_cvtps_epi32(_mm256_round_ps(_mm256_mul_ps(v3, id), _MM_ROUND_NEAREST));
+                    i0 = _mm256_packs_epi32(i0, i1);
+                    i2 = _mm256_packs_epi32(i2, i3);
+                    i0 = _mm256_packs_epi16(i0, i2);
+                    i0 = _mm256_permutevar8x32_epi32(i0, perm);
+
+                    _mm256_storeu_si256((__m256i *)block, i0);
+                    auto q8 = (uint32_t *)y[i].qs + 64*ib32;
+                    for (int l = 0; l < 4; ++l) {
+                        q8[8*l + k + 0] = block[l + 0];
+                        q8[8*l + k + 32] = block[l + 4];
+                    }
+                }
+            }
+        }
+        y += nb;
+    }
+}
+
 void iqk_convert_q4_k_q8_1_r8(int n, const void * vx, size_t bx, void * vy, int nrc_x) {
     GGML_ASSERT(n%QK_K == 0);
     GGML_ASSERT(nrc_x%8 == 0);
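
For readability, here is a scalar version of the q2_K decoding that the new iqk_convert_q2_k_q8_k_r8 vectorizes, based on the standard ggml block_q2_K layout (16 byte-scales with the scale in the low nibble and the min in the high nibble, 64 bytes of 2-bit quants, fp16 super-block scales d and dmin). This is an illustration, not code from the PR:

// y must hold QK_K = 256 floats; matches v = q*(sc & 0xf)*d - dmin*(sc >> 4) in the AVX2 loop above.
static void dequantize_q2_K_block_scalar(const block_q2_K * x, float * y) {
    const float d    = GGML_FP16_TO_FP32(x->d);
    const float dmin = GGML_FP16_TO_FP32(x->dmin);
    for (int i128 = 0; i128 < 2; ++i128) {                      // two halves of 128 values
        const uint8_t * qs = x->qs + 32*i128;
        for (int l = 0; l < 4; ++l) {                           // four 2-bit planes per half
            for (int j = 0; j < 32; ++j) {
                const uint8_t sc = x->scales[8*i128 + 2*l + j/16];
                const int q = (qs[j] >> 2*l) & 3;
                y[128*i128 + 32*l + j] = d*(sc & 0xf)*q - dmin*(sc >> 4);
            }
        }
    }
}

The rest of the function then takes the per-row absolute maximum, stores d = max/127 as the row's fp16 scale in y[i].d[k], quantizes the 256 floats to int8 with 1/d, and interleaves the eight rows four bytes at a time into the q8_k_r8 layout.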
@@ -2429,6 +2515,97 @@ void iqk_convert_q3_k_q8_k_r8(int n, const void * vx, size_t bx, void * vy, int
     }
 }
 
+inline float convert_to_q8_k_r8(int k, float d0, const __m256i * qx, const int16_t * scales, uint32_t * block, int8_t * q8_k) {
+    auto max_i16 = _mm256_setzero_si256();
+    __m256i qs[16];
+    for (int ib32 = 0; ib32 < 8; ++ib32) {
+        qs[2*ib32+0] = _mm256_cvtepi8_epi16(_mm256_castsi256_si128(qx[ib32]));
+        qs[2*ib32+1] = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(qx[ib32], 1));
+        qs[2*ib32+0] = _mm256_mullo_epi16(qs[2*ib32+0], _mm256_set1_epi16(scales[2*ib32+0]));
+        qs[2*ib32+1] = _mm256_mullo_epi16(qs[2*ib32+1], _mm256_set1_epi16(scales[2*ib32+1]));
+        max_i16 = _mm256_max_epi16(max_i16, _mm256_sign_epi16(qs[2*ib32+0], qs[2*ib32+0]));
+        max_i16 = _mm256_max_epi16(max_i16, _mm256_sign_epi16(qs[2*ib32+1], qs[2*ib32+1]));
+    }
+    auto max_q32 = _mm256_cvtepi16_epi32(_mm_max_epi16(_mm256_castsi256_si128(max_i16), _mm256_extracti128_si256(max_i16, 1)));
+    auto imax4 = _mm_max_epi32(_mm256_castsi256_si128(max_q32), _mm256_extracti128_si256(max_q32, 1));
+    auto max4 = _mm_cvtepi32_ps(imax4);
+    max4 = _mm_max_ps(max4, _mm_movehl_ps(max4, max4));
+    max4 = _mm_max_ss(max4, _mm_movehdup_ps(max4));
+    bool needs_scaling = true;
+    float dnew = _mm_cvtss_f32(max4) * d0;
+    if (dnew < 1.f) {
+        dnew = 1.f; needs_scaling = false;
+    }
+    auto scale = _mm256_set1_ps(std::abs(dnew) > 1e-9f ? 1/dnew : 0.f);
+    for (int ib32 = 0; ib32 < 8; ++ib32) {
+        if (needs_scaling) {
+            auto i0 = _mm256_cvtepi16_epi32(_mm256_castsi256_si128(qs[2*ib32+0]));
+            auto i1 = _mm256_cvtepi16_epi32(_mm256_extracti128_si256(qs[2*ib32+0], 1));
+            auto i2 = _mm256_cvtepi16_epi32(_mm256_castsi256_si128(qs[2*ib32+1]));
+            auto i3 = _mm256_cvtepi16_epi32(_mm256_extracti128_si256(qs[2*ib32+1], 1));
+            i0 = _mm256_cvtps_epi32(_mm256_round_ps(_mm256_mul_ps(scale, _mm256_cvtepi32_ps(i0)), _MM_ROUND_NEAREST));
+            i1 = _mm256_cvtps_epi32(_mm256_round_ps(_mm256_mul_ps(scale, _mm256_cvtepi32_ps(i1)), _MM_ROUND_NEAREST));
+            i2 = _mm256_cvtps_epi32(_mm256_round_ps(_mm256_mul_ps(scale, _mm256_cvtepi32_ps(i2)), _MM_ROUND_NEAREST));
+            i3 = _mm256_cvtps_epi32(_mm256_round_ps(_mm256_mul_ps(scale, _mm256_cvtepi32_ps(i3)), _MM_ROUND_NEAREST));
+            i0 = _mm256_packs_epi32(i0, i1);
+            i2 = _mm256_packs_epi32(i2, i3);
+            i0 = _mm256_packs_epi16(i0, i2);
+            i0 = _mm256_permutevar8x32_epi32(i0, _mm256_setr_epi32(0, 4, 1, 5, 2, 6, 3, 7));
+            _mm256_storeu_si256((__m256i *)block, i0);
+        } else {
+            // 0, 1, 2, 3, 4, 5, 6, 7, 8, 16, 17, 18, 19, 20, 21, 22, 23, 9, 10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31
+            auto i0 = _mm256_packs_epi16(qs[2*ib32+0], qs[2*ib32+1]);
+            auto i0_l = _mm256_castsi256_si128(i0);
+            auto i0_h = _mm256_extracti128_si256(i0, 1);
+            _mm_storeu_si128((__m128i *)block+0, _mm_unpacklo_epi64(i0_l, i0_h));
+            _mm_storeu_si128((__m128i *)block+1, _mm_unpackhi_epi64(i0_l, i0_h));
+        }
+        auto qs = (uint32_t *)q8_k + 64*ib32;
+        for (int l = 0; l < 8; ++l) {
+            qs[8*l + k] = block[l];
+        }
+    }
+    return dnew;
+}
+
+// TODO: move this to iqk_gemm_iquants
+void iqk_convert_iq4_xs_q8_k_r8(int n, const void * vx, size_t bx, void * vy, int nrc_x) {
+    GGML_ASSERT(n%QK_K == 0);
+    GGML_ASSERT(nrc_x%8 == 0);
+
+    int nb = n/QK_K;
+
+    const block_iq4_xs * x8[8];
+
+    block_q8_k_r8 * y = (block_q8_k_r8 *)vy;
+
+    auto values128 = _mm_loadu_si128((const __m128i *)iq4k_values);
+    auto values = MM256_SET_M128I(values128, values128);
+
+    int16_t ls[16];
+    float dnew[8];
+    uint32_t block[8];
+    __m256i xv[8];
+
+    for (int ix = 0; ix < nrc_x; ix += 8) {
+        for (int k = 0; k < 8; ++k) x8[k] = (const block_iq4_xs *)((const char *)vx + (ix + k)*bx);
+        for (int i = 0; i < nb; ++i) {
+            for (int k = 0; k < 8; ++k) {
+                float d = GGML_FP16_TO_FP32(x8[k][i].d);
+                for (int ib32 = 0; ib32 < 8; ++ib32) {
+                    ls[2*ib32+0] = ls[2*ib32+1] = (((x8[k][i].scales_l[ib32/2] >> 4*(ib32%2)) & 0xf) | (((x8[k][i].scales_h >> 2*ib32) & 3) << 4)) - 32;
+                    auto bits = _mm_loadu_si128((const __m128i *)x8[k][i].qs + ib32);
+                    xv[ib32] = _mm256_and_si256(MM256_SET_M128I(_mm_srli_epi16(bits, 4), bits), _mm256_set1_epi8(0xf));
+                    xv[ib32] = _mm256_shuffle_epi8(values, xv[ib32]);
+                }
+                dnew[k] = d * convert_to_q8_k_r8(k, 1.f/127, xv, ls, block, y[i].qs);
+            }
+            _mm_storeu_si128((__m128i *)y[i].d, _mm256_cvtps_ph(_mm256_loadu_ps(dnew), _MM_ROUND_NEAREST));
+        }
+        y += nb;
+    }
+}
+
 
 } // namespace
 
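convert_to_q8_k_r8 takes eight 32-value sub-blocks already widened to int8 (qx), applies the per-16-value int16 scales, finds the absolute maximum of the scaled values, and returns dnew = max*d0 (clamped to at least 1, in which case no rescaling is done); it then writes the int8 results for row k into the interleaved q8_k_r8 block. The iq4_xs converter feeds it after decoding scales and values; a scalar sketch of that decode (illustration only, following the scale/nibble arithmetic visible above and the standard ggml block_iq4_xs layout):

// Unpack one block_iq4_xs super-block: 8 sub-blocks of 32 values, each with a 6-bit scale
// split across scales_l (4 bits) and scales_h (2 bits), biased by 32, and 4-bit indices
// into the iq4k_values lookup table.
static void unpack_iq4_xs_scalar(const block_iq4_xs * x, int16_t ls[16], int8_t q[256]) {
    for (int ib32 = 0; ib32 < 8; ++ib32) {
        const int sc = (((x->scales_l[ib32/2] >> 4*(ib32%2)) & 0xf)
                      | (((x->scales_h >> 2*ib32) & 3) << 4)) - 32;
        ls[2*ib32+0] = ls[2*ib32+1] = (int16_t)sc;           // same scale for both halves of 16
        const uint8_t * qs = x->qs + 16*ib32;
        for (int j = 0; j < 16; ++j) {
            q[32*ib32 + j +  0] = iq4k_values[qs[j] & 0xf];  // low nibbles -> values 0..15
            q[32*ib32 + j + 16] = iq4k_values[qs[j] >>  4];  // high nibbles -> values 16..31
        }
    }
}

The returned dnew is multiplied by the fp16 super-block scale d, and the eight per-row results are converted back to fp16 and stored in y[i].d.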
@@ -2514,10 +2691,12 @@ bool iqk_set_kernels_kquants(int ne00, int typeA, int typeB, std::array<mul_mat_
 
 bool iqk_convert_kquants_q8X_r8(int type, int n, const void * vx, size_t bx, void * vy, int nrc_x) {
     switch (ggml_type(type)) {
+        case GGML_TYPE_Q2_K: iqk_convert_q2_k_q8_k_r8(n, vx, bx, vy, nrc_x); break;
         case GGML_TYPE_Q3_K: iqk_convert_q3_k_q8_k_r8(n, vx, bx, vy, nrc_x); break;
         case GGML_TYPE_Q4_K: iqk_convert_q4_k_q8_1_r8(n, vx, bx, vy, nrc_x); break;
         case GGML_TYPE_Q5_K: iqk_convert_q5_k_q8_1_r8(n, vx, bx, vy, nrc_x); break;
         case GGML_TYPE_Q6_K: iqk_convert_q6_k_q8_0_r8(n, vx, bx, vy, nrc_x); break;
+        case GGML_TYPE_IQ4_XS: iqk_convert_iq4_xs_q8_k_r8(n, vx, bx, vy, nrc_x); break;
         default: return false;
     }
     return true;
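
A hypothetical call site (the names n_per_row, src_rows, row_size, and scratch are made up for illustration): the dispatcher repacks a group of eight rows into an interleaved q8 scratch buffer before the GEMM kernels run, and returns false for types that have no converter yet:

if (!iqk_convert_kquants_q8X_r8(GGML_TYPE_IQ4_XS, n_per_row, src_rows, row_size, scratch, 8)) {
    // no repacked path for this type -> fall back to the regular per-row kernels
}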

0 commit comments
