Commit 3b94f0a

ikawrakow and Iwan Kawrakow authored
AVX512+AVXVNNI GEMM implementation for quants using Q8_K for activations (#710)
* q8_k_r16: basics

* q8_k_r16: iq4_xs now uses q8_k_r16 on Zen4+

  PP performance is about the same as using q8_k_r8 on the Ryzen-7950X, so we expect nice gains on Zen5, and we don't need to worry about using 2 different q8_k_r8 implementations for fancy SIMD.

* q8_k_r16: iq2_xxs now uses q8_k_r16 on Zen4+
* q8_k_r16: iq2_xs now uses q8_k_r16 on Zen4+
* q8_k_r16: iq2_s now uses q8_k_r16 on Zen4+
* q8_k_r16: iq3_xxs now uses q8_k_r16 on Zen4+
* q8_k_r16: iq3_s now uses q8_k_r16 on Zen4+
* q8_k_r16: iq1_s and iq1_m now use q8_k_r16 on Zen4+
* q8_k_r16: q2_K and q3_K now use q8_k_r16 on Zen4+
* q8_k_r16: iq2_ks and iq2_k now use q8_k_r16 on Zen4+
* q8_k_r16: iq2_kl now uses q8_k_r16 on Zen4+
* q8_k_r16: iq3_ks and iq3_k now use q8_k_r16 on Zen4+
* q8_k_r16: iq4_kss, iq4_ks, and iq4_k now use q8_k_r16 on Zen4+
* q8_k_r16: iq5_ks, iq5_k, and iq6_k now use q8_k_r16 on Zen4+
* Fix AVX2
* Just always set num_rows to 16

---------

Co-authored-by: Iwan Kawrakow <[email protected]>
1 parent c8e4d66 commit 3b94f0a
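The 16-row packing is sized to AVX512: one 512-bit register holds 16 int32 accumulators, so with 16 interleaved weight rows a single AVX512-VNNI instruction can advance the dot product of every row at once. An illustrative sketch of that inner step (my example, not code from this commit):

#include <immintrin.h>

// One VNNI step for 16 interleaved rows: wx holds 4 consecutive quant bytes
// from each of 16 rows (64 bytes total); y4 is 4 activation bytes broadcast
// to all 16 lanes. dpbusd computes acc[r] += sum_j wx[4r+j] * y4[j], i.e.
// one partial dot product per interleaved row in a single instruction.
static inline __m512i gemm_step_16rows(__m512i acc, __m512i wx, int32_t y4) {
    return _mm512_dpbusd_epi32(acc, wx, _mm512_set1_epi32(y4));
}

Note that dpbusd treats the wx bytes as unsigned, which is why the conversion routines below store the quants with a +128 offset on AVX512 builds.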

File tree

12 files changed, +826 -489 lines changed

ggml/include/ggml.h

Lines changed: 2 additions & 0 deletions
@@ -475,6 +475,7 @@ extern "C" {
         GGML_TYPE_IQ5_K_R4 = 340,
         GGML_TYPE_IQ4_KS_R4 = 344,
         GGML_TYPE_IQ5_KS_R4 = 352,
+        GGML_TYPE_Q8_K_R16 = 397,
         GGML_TYPE_Q8_KV_R8 = 398,
         GGML_TYPE_Q8_K_R8 = 399,
         GGML_TYPE_COUNT,
@@ -571,6 +572,7 @@ extern "C" {
         GGML_FTYPE_MOSTLY_IQ5_K_R4 = 333, // except 1d tensors
         GGML_FTYPE_MOSTLY_IQ4_KS_R4 = 337, // except 1d tensors
         GGML_FTYPE_MOSTLY_IQ5_KS_R4 = 341, // except 1d tensors
+        GGML_FTYPE_MOSTLY_Q8_K_R16 = 397, // except 1d tensors
         GGML_FTYPE_MOSTLY_Q8_KV_R8 = 398, // except 1d tensors
         GGML_FTYPE_MOSTLY_Q8_K_R8 = 399, // except 1d tensors
     };

ggml/src/ggml-common.h

Lines changed: 6 additions & 0 deletions
@@ -421,6 +421,12 @@ typedef struct {
 } block_q8_k_r8;
 static_assert(sizeof(block_q8_k_r8) == 8*sizeof(ggml_half) + 8*QK_K, "wrong q8_k_r8 block size/padding");

+typedef struct {
+    ggml_half d[16];       // delta
+    int8_t    qs[16*QK_K]; // quants, stored as unsigned ints
+} block_q8_k_r16;
+static_assert(sizeof(block_q8_k_r16) == 16*sizeof(ggml_half) + 16*QK_K, "wrong q8_k_r16 block size/padding");
+
 // (Almost) "true" 2-bit quantization.
 // Due to the need to use blocks as per ggml design, it ends up using
 // 2.0625 bpw because of the 16-bit scale for each block of 256.
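The new block interleaves one QK_K super-block from 16 consecutive rows, so the per-row cost stays at 2 + 256 bytes, exactly as for block_q8_k_r8. A quick standalone sanity check (my snippet, not part of the commit), which also matches the type_size = sizeof(block_q8_k_r16)/16 entry in the ggml.c hunk below:

#include <cstdint>

typedef uint16_t ggml_half;   // fp16 storage, as in ggml
constexpr int QK_K = 256;

typedef struct {
    ggml_half d[16];          // one scale per interleaved row
    int8_t    qs[16*QK_K];    // 16 rows x 256 quants, interleaved
} block_q8_k_r16;

// Per logical row, Q8_K_R16 costs the same as one row's share of Q8_K_R8:
static_assert(sizeof(block_q8_k_r16)/16 == sizeof(ggml_half) + QK_K,
              "258 bytes per row of 256 quants");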

ggml/src/ggml-quants.c

Lines changed: 1 addition & 0 deletions
@@ -15461,6 +15461,7 @@ bool ggml_validate_row_data(enum ggml_type type, const void * data, size_t nbyte
         case GGML_TYPE_IQ5_KS_R4:break;
         case GGML_TYPE_Q8_KV_R8: break;
         case GGML_TYPE_Q8_K_R8: break;
+        case GGML_TYPE_Q8_K_R16: break;
         case GGML_TYPE_Q8_KV: break;
         case GGML_TYPE_BF16_R16: break;
         case GGML_TYPE_Q4_0_4_4:

ggml/src/ggml.c

Lines changed: 34 additions & 1 deletion
@@ -1071,6 +1071,19 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
         .nrows = 1,
         .row_meta_size = 0,
     },
+    [GGML_TYPE_Q8_K_R16] = {
+        .type_name = "q8_k_r16",
+        .blck_size = QK_K,
+        .type_size = sizeof(block_q8_k_r16)/16,
+        .is_quantized = true,
+        .to_float = (ggml_to_float_t) dequantize_row_q8_k_r16,
+        .from_float = quantize_row_q8_k_r16,
+        .from_float_ref = (ggml_from_float_t) quantize_row_q8_k_r16_ref,
+        .vec_dot = vec_dot_q8_k_r16_q8_k,
+        .vec_dot_type = GGML_TYPE_Q8_K,
+        .nrows = 1,
+        .row_meta_size = 0,
+    },
     [GGML_TYPE_IQ2_XXS] = {
         .type_name = "iq2_xxs",
         .blck_size = QK_K,
@@ -1934,7 +1947,7 @@ ggml_type_traits_t ggml_internal_get_type_traits(enum ggml_type type) {
 }

 static inline int ggml_packed_rows(enum ggml_type type) {
-    return type == GGML_TYPE_BF16_R16 ? 16
+    return type == GGML_TYPE_BF16_R16 || type == GGML_TYPE_Q8_K_R16 ? 16
          : type == GGML_TYPE_Q8_K_R8 || type == GGML_TYPE_Q8_KV_R8 ||
            type == GGML_TYPE_Q8_0_R8 || type == GGML_TYPE_Q4_0_R8 ||
            type == GGML_TYPE_IQ4_XS_R8 ? 8
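ggml_packed_rows reports how many logical rows share one packed "row" of a type, so Q8_K_R16 tensors need a row count divisible by 16 and rows are addressed per 16-row group. A hedged sketch of the addressing consequence (my illustration; ggml_packed_rows itself is file-local to ggml.c):

// For a packed type with r = ggml_packed_rows(type) and row_size bytes per
// logical row, rows i, i+1, ..., up to the next multiple of r live
// interleaved in the same packed group, which starts at this byte offset:
static size_t packed_group_offset(size_t i, size_t row_size, int r) {
    return (i / r) * (r * row_size);
}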
@@ -4617,6 +4630,7 @@ enum ggml_type ggml_ftype_to_ggml_type(enum ggml_ftype ftype) {
         case GGML_FTYPE_MOSTLY_Q6_K: wtype = GGML_TYPE_Q6_K; break;
         case GGML_FTYPE_MOSTLY_Q6_K_R4: wtype = GGML_TYPE_Q6_K_R4; break;
         case GGML_FTYPE_MOSTLY_Q8_K_R8: wtype = GGML_TYPE_Q8_K_R8; break;
+        case GGML_FTYPE_MOSTLY_Q8_K_R16: wtype = GGML_TYPE_Q8_K_R16; break;
         case GGML_FTYPE_MOSTLY_Q8_KV_R8: wtype = GGML_TYPE_Q8_KV_R8; break;
         case GGML_FTYPE_MOSTLY_IQ2_XXS: wtype = GGML_TYPE_IQ2_XXS; break;
         case GGML_FTYPE_MOSTLY_IQ2_XXS_R4: wtype = GGML_TYPE_IQ2_XXS_R4;break;
@@ -11542,6 +11556,7 @@ static void ggml_compute_forward_add(
         case GGML_TYPE_Q6_K:
         case GGML_TYPE_Q6_K_R4:
         case GGML_TYPE_Q8_K_R8:
+        case GGML_TYPE_Q8_K_R16:
         case GGML_TYPE_Q8_KV_R8:
         case GGML_TYPE_IQ2_XXS:
         case GGML_TYPE_IQ2_XXS_R4:
@@ -12094,6 +12109,7 @@ static void ggml_compute_forward_add1(
         case GGML_TYPE_Q6_K:
         case GGML_TYPE_Q6_K_R4:
         case GGML_TYPE_Q8_K_R8:
+        case GGML_TYPE_Q8_K_R16:
         case GGML_TYPE_Q8_KV_R8:
         case GGML_TYPE_IQ2_XXS:
         case GGML_TYPE_IQ2_XXS_R4:
@@ -12272,6 +12288,7 @@ static void ggml_compute_forward_acc(
         case GGML_TYPE_Q6_K:
         case GGML_TYPE_Q6_K_R4:
         case GGML_TYPE_Q8_K_R8:
+        case GGML_TYPE_Q8_K_R16:
         case GGML_TYPE_Q8_KV_R8:
         case GGML_TYPE_IQ2_XXS:
         case GGML_TYPE_IQ2_XXS_R4:
@@ -14966,6 +14983,17 @@ static void ggml_compute_forward_mul_mat(
 #endif

 #if GGML_USE_IQK_MULMAT
+    if (ith == 0) {
+        static bool first_time = true;
+        if (first_time) {
+            first_time = false;
+#ifdef HAVE_FANCY_SIMD
+            printf("======================================= HAVE_FANCY_SIMD is defined\n");
+#else
+            printf("======================================= HAVE_FANCY_SIMD is NOT defined\n");
+#endif
+        }
+    }
     if (dst->type == GGML_TYPE_F32) {
         if (iqk_mul_mat_4d(ne01, ne11, ne00,
                 ne02, ne03, ne12, ne13, nb02, nb03, nb12, nb13, nb2/sizeof(float), nb3/sizeof(float),
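The printf block is a one-shot runtime diagnostic showing which path the build selected. As an assumption (the macro is defined in the iqk sources, not in this diff), HAVE_FANCY_SIMD gates on the full AVX512 feature set, roughly:

// Assumed gate; check the iqk sources for the authoritative definition.
#if defined(__AVX512F__) && defined(__AVX512VNNI__) && defined(__AVX512VL__) \
 && defined(__AVX512BW__) && defined(__AVX512DQ__)
#define HAVE_FANCY_SIMD
#endif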
@@ -15872,6 +15900,7 @@ static void ggml_compute_forward_out_prod(
         case GGML_TYPE_Q6_K:
         case GGML_TYPE_Q6_K_R4:
         case GGML_TYPE_Q8_K_R8:
+        case GGML_TYPE_Q8_K_R16:
         case GGML_TYPE_Q8_KV_R8:
         case GGML_TYPE_IQ2_XXS:
         case GGML_TYPE_IQ2_XXS_R4:
@@ -16290,6 +16319,7 @@ static void ggml_compute_forward_set(
         case GGML_TYPE_Q6_K:
         case GGML_TYPE_Q6_K_R4:
         case GGML_TYPE_Q8_K_R8:
+        case GGML_TYPE_Q8_K_R16:
         case GGML_TYPE_Q8_KV_R8:
         case GGML_TYPE_IQ2_XXS:
         case GGML_TYPE_IQ2_XXS_R4:
@@ -16614,6 +16644,7 @@ static void ggml_compute_forward_get_rows(
         case GGML_TYPE_Q6_K:
         case GGML_TYPE_Q6_K_R4:
         case GGML_TYPE_Q8_K_R8:
+        case GGML_TYPE_Q8_K_R16:
         case GGML_TYPE_Q8_KV_R8:
         case GGML_TYPE_IQ2_XXS:
         case GGML_TYPE_IQ2_XXS_R4:
@@ -17274,6 +17305,7 @@ static void ggml_compute_forward_clamp(
         case GGML_TYPE_Q6_K:
         case GGML_TYPE_Q6_K_R4:
         case GGML_TYPE_Q8_K_R8:
+        case GGML_TYPE_Q8_K_R16:
         case GGML_TYPE_Q8_KV_R8:
         case GGML_TYPE_Q8_KR8:
         case GGML_TYPE_IQ2_XXS:
@@ -24380,6 +24412,7 @@ size_t ggml_quantize_chunk(
         case GGML_TYPE_Q6_K: result = quantize_q6_K(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
         case GGML_TYPE_Q6_K_R4: result = quantize_q6_k_r4(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
         case GGML_TYPE_Q8_K_R8: result = quantize_q8_k_r8(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
+        case GGML_TYPE_Q8_K_R16:result = quantize_q8_k_r16(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
         case GGML_TYPE_Q8_KV_R8:result = quantize_q8_KV_r8(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
         case GGML_TYPE_IQ2_XXS: result = quantize_iq2_xxs(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
         case GGML_TYPE_IQ2_XXS_R4:result = quantize_iq2_xxs_r4(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
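quantize_q8_k_r16 slots into the same chunked quantization API as the other repacked types. A hedged usage sketch (my driver code, assuming the standard ggml_quantize_chunk signature):

// Repack nrows x n_per_row floats into Q8_K_R16; nrows must be a multiple
// of 16 because rows are stored interleaved in groups of 16.
size_t repack_to_q8_k_r16(const float * src, void * dst, int64_t nrows, int64_t n_per_row) {
    GGML_ASSERT(nrows % 16 == 0);
    return ggml_quantize_chunk(GGML_TYPE_Q8_K_R16, src, dst, /*start=*/0, nrows, n_per_row, /*imatrix=*/NULL);
}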

ggml/src/iqk/iqk_common.h

Lines changed: 54 additions & 0 deletions
@@ -556,6 +556,60 @@ inline void iqk_transpose_8x8(__m256 * m) {
     }
 }

+template <int nr = 8>
+static inline float convert_to_q8_k_r8(int k, float d0, const __m256i * qx, const int16_t * scales, uint32_t * block, int8_t * q8_k) {
+    auto max_i16 = _mm256_setzero_si256();
+    __m256i qs[16];
+    for (int ib32 = 0; ib32 < 8; ++ib32) {
+        qs[2*ib32+0] = _mm256_cvtepi8_epi16(_mm256_castsi256_si128(qx[ib32]));
+        qs[2*ib32+1] = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(qx[ib32], 1));
+        qs[2*ib32+0] = _mm256_mullo_epi16(qs[2*ib32+0], _mm256_set1_epi16(scales[2*ib32+0]));
+        qs[2*ib32+1] = _mm256_mullo_epi16(qs[2*ib32+1], _mm256_set1_epi16(scales[2*ib32+1]));
+        max_i16 = _mm256_max_epi16(max_i16, _mm256_sign_epi16(qs[2*ib32+0], qs[2*ib32+0]));
+        max_i16 = _mm256_max_epi16(max_i16, _mm256_sign_epi16(qs[2*ib32+1], qs[2*ib32+1]));
+    }
+    auto max_q32 = _mm256_cvtepi16_epi32(_mm_max_epi16(_mm256_castsi256_si128(max_i16), _mm256_extracti128_si256(max_i16, 1)));
+    auto imax4 = _mm_max_epi32(_mm256_castsi256_si128(max_q32), _mm256_extracti128_si256(max_q32, 1));
+    auto max4 = _mm_cvtepi32_ps(imax4);
+    max4 = _mm_max_ps(max4, _mm_movehl_ps(max4, max4));
+    max4 = _mm_max_ss(max4, _mm_movehdup_ps(max4));
+    bool needs_scaling = true;
+    float dnew = _mm_cvtss_f32(max4) * d0;
+    if (dnew < 1.f) {
+        dnew = 1.f; needs_scaling = false;
+    }
+    auto scale = _mm256_set1_ps(std::abs(dnew) > 1e-9f ? 1/dnew : 0.f);
+    for (int ib32 = 0; ib32 < 8; ++ib32) {
+        if (needs_scaling) {
+            auto i0 = _mm256_cvtepi16_epi32(_mm256_castsi256_si128(qs[2*ib32+0]));
+            auto i1 = _mm256_cvtepi16_epi32(_mm256_extracti128_si256(qs[2*ib32+0], 1));
+            auto i2 = _mm256_cvtepi16_epi32(_mm256_castsi256_si128(qs[2*ib32+1]));
+            auto i3 = _mm256_cvtepi16_epi32(_mm256_extracti128_si256(qs[2*ib32+1], 1));
+            i0 = _mm256_cvtps_epi32(_mm256_round_ps(_mm256_mul_ps(scale, _mm256_cvtepi32_ps(i0)), _MM_ROUND_NEAREST));
+            i1 = _mm256_cvtps_epi32(_mm256_round_ps(_mm256_mul_ps(scale, _mm256_cvtepi32_ps(i1)), _MM_ROUND_NEAREST));
+            i2 = _mm256_cvtps_epi32(_mm256_round_ps(_mm256_mul_ps(scale, _mm256_cvtepi32_ps(i2)), _MM_ROUND_NEAREST));
+            i3 = _mm256_cvtps_epi32(_mm256_round_ps(_mm256_mul_ps(scale, _mm256_cvtepi32_ps(i3)), _MM_ROUND_NEAREST));
+            i0 = _mm256_packs_epi32(i0, i1);
+            i2 = _mm256_packs_epi32(i2, i3);
+            i0 = _mm256_packs_epi16(i0, i2);
+            i0 = _mm256_permutevar8x32_epi32(i0, _mm256_setr_epi32(0, 4, 1, 5, 2, 6, 3, 7));
+            _mm256_storeu_si256((__m256i *)block, i0);
+        } else {
+            // 0, 1, 2, 3, 4, 5, 6, 7, 8, 16, 17, 18, 19, 20, 21, 22, 23, 9, 10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31
+            auto i0 = _mm256_packs_epi16(qs[2*ib32+0], qs[2*ib32+1]);
+            auto i0_l = _mm256_castsi256_si128(i0);
+            auto i0_h = _mm256_extracti128_si256(i0, 1);
+            _mm_storeu_si128((__m128i *)block+0, _mm_unpacklo_epi64(i0_l, i0_h));
+            _mm_storeu_si128((__m128i *)block+1, _mm_unpackhi_epi64(i0_l, i0_h));
+        }
+        auto qs = (uint32_t *)q8_k + 8*nr*ib32;
+        for (int l = 0; l < 8; ++l) {
+            qs[nr*l + k] = block[l];
+        }
+    }
+    return dnew;
+}
+
 #else
 // ------------------------------------ __aarch64__ --------------------------------------------------
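The template is the old q8_k_r8 conversion generalized over the interleave stride: nr = 8 reproduces the previous behavior, nr = 16 produces the q8_k_r16 layout. For readers who don't think in intrinsics, a scalar model of what it computes (my reference sketch, not code from the commit):

#include <algorithm>
#include <cmath>
#include <cstdint>

// Scalar model of convert_to_q8_k_r8<nr>: apply the per-16 block scales to
// the 256 quants of row k, rescale so everything fits int8, then scatter
// 4-byte groups with stride nr so that nr rows end up interleaved in q8_k.
template <int nr>
float convert_to_q8_k_r_ref(int k, float d0, const int8_t * q /*[256]*/,
                            const int16_t * scales /*[16]*/, int8_t * q8_k) {
    int16_t v[256]; int amax = 0;
    for (int j = 0; j < 256; ++j) {
        v[j] = q[j] * scales[j/16];
        amax = std::max(amax, std::abs(int(v[j])));
    }
    float dnew = amax * d0;              // d0 = 1.f/126 keeps results inside int8
    float id = dnew > 1 ? 1/dnew : 1.f;  // only shrink when the max overflows
    if (dnew < 1) dnew = 1;
    for (int ib32 = 0; ib32 < 8; ++ib32)     // eight 32-quant sub-blocks
        for (int l = 0; l < 8; ++l)          // eight 4-byte groups each
            for (int j = 0; j < 4; ++j)
                q8_k[4*(8*nr*ib32 + nr*l + k) + j] =
                    (int8_t)std::nearbyint(id * v[32*ib32 + 4*l + j]);
    return dnew;                             // caller folds dnew into the row scale
}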

ggml/src/iqk/iqk_gemm_1bit.cpp

Lines changed: 41 additions & 70 deletions
@@ -1668,82 +1668,34 @@ static void mul_mat_iq2_bn_r4_q8_k16(int n, const void * vx, size_t bx, const Da
 }
 #endif

-inline float convert_to_q8_k_r8(int k, int d0, const __m256i * qx, const int16_t * scales, uint32_t * block, int8_t * q8_k) {
-    auto max_i16 = _mm256_setzero_si256();
-    for (int ib32 = 0; ib32 < 8; ++ib32) {
-        auto q16_l = _mm256_cvtepi8_epi16(_mm256_castsi256_si128(qx[ib32]));
-        auto q16_h = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(qx[ib32], 1));
-        q16_l = _mm256_mullo_epi16(q16_l, _mm256_set1_epi16(scales[2*ib32+0]));
-        q16_h = _mm256_mullo_epi16(q16_h, _mm256_set1_epi16(scales[2*ib32+1]));
-        max_i16 = _mm256_max_epi16(max_i16, _mm256_sign_epi16(q16_l, q16_l));
-        max_i16 = _mm256_max_epi16(max_i16, _mm256_sign_epi16(q16_h, q16_h));
-    }
-    auto max_q32 = _mm256_cvtepi16_epi32(_mm_max_epi16(_mm256_castsi256_si128(max_i16), _mm256_extracti128_si256(max_i16, 1)));
-    auto imax4 = _mm_max_epi32(_mm256_castsi256_si128(max_q32), _mm256_extracti128_si256(max_q32, 1));
-    auto max4 = _mm_cvtepi32_ps(imax4);
-    max4 = _mm_max_ps(max4, _mm_movehl_ps(max4, max4));
-    max4 = _mm_max_ss(max4, _mm_movehdup_ps(max4));
-    bool needs_scaling = true;
-    float dnew = _mm_cvtss_f32(max4) / d0;
-    if (dnew < 1.f) {
-        dnew = 1.f; needs_scaling = false;
-    }
-    auto scale = _mm256_set1_ps(std::abs(dnew) > 1e-9f ? 1/dnew : 0.f);
-    for (int ib32 = 0; ib32 < 8; ++ib32) {
-        auto q16_l = _mm256_cvtepi8_epi16(_mm256_castsi256_si128(qx[ib32]));
-        auto q16_h = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(qx[ib32], 1));
-        q16_l = _mm256_mullo_epi16(q16_l, _mm256_set1_epi16(scales[2*ib32+0]));
-        q16_h = _mm256_mullo_epi16(q16_h, _mm256_set1_epi16(scales[2*ib32+1]));
-        if (needs_scaling) {
-            auto i0 = _mm256_cvtepi16_epi32(_mm256_castsi256_si128(q16_l));
-            auto i1 = _mm256_cvtepi16_epi32(_mm256_extracti128_si256(q16_l, 1));
-            auto i2 = _mm256_cvtepi16_epi32(_mm256_castsi256_si128(q16_h));
-            auto i3 = _mm256_cvtepi16_epi32(_mm256_extracti128_si256(q16_h, 1));
-            i0 = _mm256_cvtps_epi32(_mm256_round_ps(_mm256_mul_ps(scale, _mm256_cvtepi32_ps(i0)), _MM_ROUND_NEAREST));
-            i1 = _mm256_cvtps_epi32(_mm256_round_ps(_mm256_mul_ps(scale, _mm256_cvtepi32_ps(i1)), _MM_ROUND_NEAREST));
-            i2 = _mm256_cvtps_epi32(_mm256_round_ps(_mm256_mul_ps(scale, _mm256_cvtepi32_ps(i2)), _MM_ROUND_NEAREST));
-            i3 = _mm256_cvtps_epi32(_mm256_round_ps(_mm256_mul_ps(scale, _mm256_cvtepi32_ps(i3)), _MM_ROUND_NEAREST));
-            i0 = _mm256_packs_epi32(i0, i1);
-            i2 = _mm256_packs_epi32(i2, i3);
-            i0 = _mm256_packs_epi16(i0, i2);
-            i0 = _mm256_permutevar8x32_epi32(i0, _mm256_setr_epi32(0, 4, 1, 5, 2, 6, 3, 7));
-            _mm256_storeu_si256((__m256i *)block, i0);
-        } else {
-            // 0, 1, 2, 3, 4, 5, 6, 7, 8, 16, 17, 18, 19, 20, 21, 22, 23, 9, 10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31
-            auto i0 = _mm256_packs_epi16(q16_l, q16_h);
-            auto i0_l = _mm256_castsi256_si128(i0);
-            auto i0_h = _mm256_extracti128_si256(i0, 1);
-            _mm_storeu_si128((__m128i *)block+0, _mm_unpacklo_epi64(i0_l, i0_h));
-            _mm_storeu_si128((__m128i *)block+1, _mm_unpackhi_epi64(i0_l, i0_h));
-        }
-        auto qs = (uint32_t *)q8_k + 64*ib32;
-        for (int l = 0; l < 8; ++l) {
-            qs[8*l + k] = block[l];
-        }
-    }
-    return dnew;
-}
-
 void iqk_convert_iq1_s_q8_k_r8(int n, const void * vx, size_t bx, void * vy, int nrc_x) {
+#ifdef HAVE_FANCY_SIMD
+    constexpr int k_nr = 16;
+    using block_q8_k_r = block_q8_k_r16;
+#else
+    constexpr int k_nr = 8;
+    using block_q8_k_r = block_q8_k_r8;
+#endif
+
     GGML_ASSERT(n%QK_K == 0);
-    GGML_ASSERT(nrc_x%8 == 0);
+    GGML_ASSERT(nrc_x%k_nr == 0);

     int nb = n/QK_K;

-    const block_iq1_s * x8[8];
+    const block_iq1_s * x8[k_nr];

-    block_q8_k_r8 * y = (block_q8_k_r8 *)vy;
+    block_q8_k_r * y = (block_q8_k_r *)vy;

     int16_t ls[16];

     uint32_t block[8];

     __m256i qx[8];

-    for (int ix = 0; ix < nrc_x; ix += 8) {
-        for (int k = 0; k < 8; ++k) x8[k] = (const block_iq1_s *)((const char *)vx + (ix + k)*bx);
+    for (int ix = 0; ix < nrc_x; ix += k_nr) {
+        for (int k = 0; k < k_nr; ++k) x8[k] = (const block_iq1_s *)((const char *)vx + (ix + k)*bx);
         for (int i = 0; i < nb; ++i) {
-            for (int k = 0; k < 8; ++k) {
+            for (int k = 0; k < k_nr; ++k) {
                 float d = 0.125f * GGML_FP16_TO_FP32(x8[k][i].d);
                 auto qs = x8[k][i].qs;
                 auto qh = x8[k][i].qh;
@@ -1759,23 +1711,36 @@ void iqk_convert_iq1_s_q8_k_r8(int n, const void * vx, size_t bx, void * vy, int
                 qx[ib32] = value;
                 qs += 4;
             }
-            float dnew = convert_to_q8_k_r8(k, 126, qx, ls, block, y[i].qs);
+            float dnew = convert_to_q8_k_r8<k_nr>(k, 1.f/126, qx, ls, block, y[i].qs);
             y[i].d[k] = GGML_FP32_TO_FP16(d*dnew);
         }
+#ifdef HAVE_FANCY_SIMD
+        for (int l = 0; l < 64; ++l) {
+            auto v = _mm512_xor_si512(_mm512_loadu_si512((const __m512i *)y[i].qs + l), _mm512_set1_epi8(-128));
+            _mm512_storeu_si512((__m512i *)y[i].qs + l, v);
+        }
+#endif
         }
         y += nb;
     }
 }

 void iqk_convert_iq1_m_q8_k_r8(int n, const void * vx, size_t bx, void * vy, int nrc_x) {
+#ifdef HAVE_FANCY_SIMD
+    constexpr int k_nr = 16;
+    using block_q8_k_r = block_q8_k_r16;
+#else
+    constexpr int k_nr = 8;
+    using block_q8_k_r = block_q8_k_r8;
+#endif
     GGML_ASSERT(n%QK_K == 0);
-    GGML_ASSERT(nrc_x%8 == 0);
+    GGML_ASSERT(nrc_x%k_nr == 0);

     int nb = n/QK_K;

-    const block_iq1_m * x8[8];
+    const block_iq1_m * x8[k_nr];

-    block_q8_k_r8 * y = (block_q8_k_r8 *)vy;
+    block_q8_k_r * y = (block_q8_k_r *)vy;

     int16_t ls[16];

@@ -1785,10 +1750,10 @@ void iqk_convert_iq1_m_q8_k_r8(int n, const void * vx, size_t bx, void * vy, int

     auto mask = _mm256_setr_epi32(0x00000008, 0x00000008, 0x00000080, 0x00000080, 0x00080000, 0x00080000, 0x00800000, 0x00800000);

-    for (int ix = 0; ix < nrc_x; ix += 8) {
-        for (int k = 0; k < 8; ++k) x8[k] = (const block_iq1_m *)((const char *)vx + (ix + k)*bx);
+    for (int ix = 0; ix < nrc_x; ix += k_nr) {
+        for (int k = 0; k < k_nr; ++k) x8[k] = (const block_iq1_m *)((const char *)vx + (ix + k)*bx);
         for (int i = 0; i < nb; ++i) {
-            for (int k = 0; k < 8; ++k) {
+            for (int k = 0; k < k_nr; ++k) {
                 const uint16_t * sc = (const uint16_t *)x8[k][i].scales;
                 iq1m_scale_t scale;
                 scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000);
@@ -1816,9 +1781,15 @@ void iqk_convert_iq1_m_q8_k_r8(int n, const void * vx, size_t bx, void * vy, int
                 qs += 4;
                 qh += 2;
             }
-            float dnew = convert_to_q8_k_r8(k, 126, qx, ls, block, y[i].qs);
+            float dnew = convert_to_q8_k_r8<k_nr>(k, 1.f/126, qx, ls, block, y[i].qs);
             y[i].d[k] = GGML_FP32_TO_FP16(d*dnew);
         }
+#ifdef HAVE_FANCY_SIMD
+        for (int l = 0; l < 64; ++l) {
+            auto v = _mm512_xor_si512(_mm512_loadu_si512((const __m512i *)y[i].qs + l), _mm512_set1_epi8(-128));
+            _mm512_storeu_si512((__m512i *)y[i].qs + l, v);
+        }
+#endif
         }
         y += nb;
     }
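These #ifdef HAVE_FANCY_SIMD epilogues are why the block_q8_k_r16 comment says the quants are "stored as unsigned ints": AVX512-VNNI's dpbusd multiplies unsigned bytes from its first operand by signed bytes from its second, so the signed int8 values are flipped into offset binary (x XOR 0x80 equals x + 128) right after conversion. A standalone sketch of the identity the GEMM then compensates for (my illustration, not code from this commit):

#include <cstdint>

// Signed int8 -> biased uint8: flipping the sign bit adds 128, and
// dot(x + 128, y) = dot(x, y) + 128 * sum(y), so the signed dot product is
// recovered from the per-block activation sums that Q8_K already carries.
int32_t dot_from_biased(const int8_t * x, const int8_t * y, int n) {
    int32_t dot = 0, sumy = 0;
    for (int j = 0; j < n; ++j) {
        uint8_t u = (uint8_t)(x[j] ^ 0x80);  // what the epilogue stores
        dot  += (int32_t)u * y[j];           // what dpbusd accumulates
        sumy += y[j];
    }
    return dot - 128 * sumy;                 // == sum of x[j]*y[j]
}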
