Skip to content

Commit 3f54b49

Browse files
ikawrakowIwan Kawrakow
andauthored
Faster iq1_s GEMM via repacking to Q8_0_R8 (#517)
TG is slightly faster too - 24.4 vs 23.1 t/s on the Ryzen-5975WX Co-authored-by: Iwan Kawrakow <[email protected]>
1 parent 69af3f5 commit 3f54b49

File tree

4 files changed

+159
-9
lines changed

4 files changed

+159
-9
lines changed

ggml/src/ggml.c

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1205,7 +1205,11 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
12051205
.from_float = quantize_row_iq1_s,
12061206
.from_float_ref = (ggml_from_float_t)quantize_row_iq1_s_ref,
12071207
.vec_dot = ggml_vec_dot_iq1_s_q8_K,
1208+
#ifdef __AVX2__
1209+
.vec_dot_type = GGML_TYPE_Q8_2_X4,
1210+
#else
12081211
.vec_dot_type = GGML_TYPE_Q8_K,
1212+
#endif
12091213
.nrows = 1,
12101214
.row_meta_size = 0,
12111215
},

ggml/src/iqk/iqk_gemm_1bit.cpp

Lines changed: 147 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -865,6 +865,80 @@ void mul_mat_iq1_s_q8_K(int n, const void * vx, size_t bx, const DataInfo& info,
865865
}
866866
}
867867

868+
template <int nrc_y>
869+
void mul_mat_iq1_s_q8_2_x4(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
870+
GGML_ASSERT(n%QK_K == 0);
871+
Q8<nrc_y, block_q8_2_x4> q8(info);
872+
__m256i qx[4];
873+
__m256 scales[2];
874+
__m256 acc[nrc_y] = {};
875+
auto delta_mask = _mm_set1_epi16(-32768); // to avoid stupid overflow warnings when using 0x8000
876+
for (int ix = 0; ix < nrc_x; ++ix) {
877+
auto iq1s = (const block_iq1_s *)((const char *)vx + ix*bx);
878+
for (int ibl = 0; ibl < n/QK_K; ++ibl) {
879+
float d = GGML_FP16_TO_FP32(iq1s[ibl].d);
880+
auto qhb = _mm_loadu_si128((const __m128i *)iq1s[ibl].qh);
881+
auto scales128 = _mm_and_si128(_mm_srli_epi16(qhb, 12), _mm_set1_epi16(7));
882+
scales128 = _mm_add_epi16(_mm_slli_epi16(scales128, 1), _mm_set1_epi16(1));
883+
auto all_scales = _mm256_mul_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(scales128)));
884+
#ifdef HAVE_FANCY_SIMD
885+
auto mask = _mm_cmpeq_epi16_mask(_mm_and_si128(qhb, delta_mask), delta_mask);
886+
auto deltas128 = _mm_mask_blend_epi16(mask, _mm_set1_epi16(-7), _mm_set1_epi16(-9));
887+
#else
888+
auto mask = _mm_cmpeq_epi16(_mm_and_si128(qhb, delta_mask), delta_mask);
889+
auto deltas128 = _mm_or_si128(_mm_and_si128(mask, _mm_set1_epi16(-9)), _mm_andnot_si128(mask, _mm_set1_epi16(-7)));
890+
#endif
891+
auto deltas = _mm256_mul_ps(all_scales, _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(deltas128)));
892+
for (int iy = 0; iy < nrc_y; ++iy) {
893+
auto my1 = _mm_cvtepu16_epi32(_mm_loadl_epi64((const __m128i *)(q8.y[iy][2*ibl+0].d + 4)));
894+
auto my2 = _mm_cvtepu16_epi32(_mm_loadl_epi64((const __m128i *)(q8.y[iy][2*ibl+1].d + 4)));
895+
auto my = _mm256_castsi256_ps(_mm256_slli_epi32(MM256_SET_M128I(my2, my1), 16));
896+
acc[iy] = _mm256_fmadd_ps(deltas, my, acc[iy]);
897+
}
898+
all_scales = _mm256_mul_ps(_mm256_set1_ps(8.f), all_scales);
899+
auto scales_l = _mm256_castps256_ps128(all_scales);
900+
auto scales_h = _mm256_extractf128_ps(all_scales, 1);
901+
scales[0] = _mm256_set_m128(scales_l, scales_l);
902+
scales[1] = _mm256_set_m128(scales_h, scales_h);
903+
const uint8_t * qs = iq1s[ibl].qs;
904+
const uint16_t * qh = iq1s[ibl].qh;
905+
for (int i128 = 0; i128 < QK_K/128; ++i128) {
906+
qx[0] = _mm256_set_epi64x(iq1s_grid_us[qs[3] | ((qh[0] >> 1) & 0x700)], iq1s_grid_us[qs[2] | ((qh[0] << 2) & 0x700)],
907+
iq1s_grid_us[qs[1] | ((qh[0] << 5) & 0x700)], iq1s_grid_us[qs[0] | ((qh[0] << 8) & 0x700)]);
908+
qx[1] = _mm256_set_epi64x(iq1s_grid_us[qs[7] | ((qh[1] >> 1) & 0x700)], iq1s_grid_us[qs[6] | ((qh[1] << 2) & 0x700)],
909+
iq1s_grid_us[qs[5] | ((qh[1] << 5) & 0x700)], iq1s_grid_us[qs[4] | ((qh[1] << 8) & 0x700)]);
910+
qs += 8;
911+
qx[2] = _mm256_set_epi64x(iq1s_grid_us[qs[3] | ((qh[2] >> 1) & 0x700)], iq1s_grid_us[qs[2] | ((qh[2] << 2) & 0x700)],
912+
iq1s_grid_us[qs[1] | ((qh[2] << 5) & 0x700)], iq1s_grid_us[qs[0] | ((qh[2] << 8) & 0x700)]);
913+
qx[3] = _mm256_set_epi64x(iq1s_grid_us[qs[7] | ((qh[3] >> 1) & 0x700)], iq1s_grid_us[qs[6] | ((qh[3] << 2) & 0x700)],
914+
iq1s_grid_us[qs[5] | ((qh[3] << 5) & 0x700)], iq1s_grid_us[qs[4] | ((qh[3] << 8) & 0x700)]);
915+
qs += 8; qh += 4;
916+
for (int iy = 0; iy < nrc_y; ++iy) {
917+
auto& ybl = q8.y[iy][2*ibl+i128];
918+
auto sumi1 = _mm256_maddubs_epi16(qx[0], _mm256_loadu_si256((const __m256i *)ybl.qs+0));
919+
auto sumi2 = _mm256_maddubs_epi16(qx[1], _mm256_loadu_si256((const __m256i *)ybl.qs+1));
920+
auto sumi3 = _mm256_maddubs_epi16(qx[2], _mm256_loadu_si256((const __m256i *)ybl.qs+2));
921+
auto sumi4 = _mm256_maddubs_epi16(qx[3], _mm256_loadu_si256((const __m256i *)ybl.qs+3));
922+
// 0,0,1,1, 0,0,1,1, 0,0,1,1, 0,0,1,1 as int16_t
923+
sumi1 = _mm256_add_epi16(_mm256_unpacklo_epi32(sumi1, sumi2), _mm256_unpackhi_epi32(sumi1, sumi2));
924+
// 2,2,3,3, 2,2,3,3, 2,2,3,3, 2,2,3,3 as int16_t
925+
sumi3 = _mm256_add_epi16(_mm256_unpacklo_epi32(sumi3, sumi4), _mm256_unpackhi_epi32(sumi3, sumi4));
926+
sumi1 = _mm256_add_epi16(_mm256_unpacklo_epi64(sumi1, sumi3), _mm256_unpackhi_epi64(sumi1, sumi3));
927+
// 0, 1, 2, 3, 0, 1, 2, 3 as int322_t
928+
sumi1 = _mm256_madd_epi16(_mm256_set1_epi16(1), sumi1);
929+
auto d4 = _mm_castsi128_ps(_mm_slli_epi32(_mm_cvtepu16_epi32(_mm_loadl_epi64((const __m128i *)ybl.d)), 16));
930+
auto dy = _mm256_set_m128(d4, d4);
931+
acc[iy] = _mm256_fmadd_ps(_mm256_mul_ps(scales[i128], dy), _mm256_cvtepi32_ps(sumi1), acc[iy]);
932+
}
933+
}
934+
}
935+
for (int iy = 0; iy < nrc_y; ++iy) {
936+
info.store(ix, iy, 0.125f*hsum_float_8(acc[iy]));
937+
acc[iy] = _mm256_setzero_ps();
938+
}
939+
}
940+
}
941+
868942
template <int nrc_y>
869943
static void mul_mat_iq1_s_r4_q8_1(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
870944
GGML_ASSERT(nrc_x%4 == 0);
@@ -1533,23 +1607,79 @@ static void mul_mat_iq2_bn_r4_q8_k16(int n, const void * vx, size_t bx, const Da
15331607
}
15341608
#endif
15351609

1610+
void iqk_convert_iq1_s_q8_0_r8(int n, const void * vx, size_t bx, void * vy, int nrc_x) {
1611+
GGML_ASSERT(n%QK_K == 0);
1612+
GGML_ASSERT(nrc_x%8 == 0);
1613+
1614+
int nb = n/QK_K;
1615+
1616+
const block_iq1_s * x8[8];
1617+
1618+
block_q8_0_r8 * y = (block_q8_0_r8 *)vy;
1619+
1620+
ggml_half dh[8];
1621+
uint16_t all_ls[64];
1622+
1623+
uint32_t block[8];
1624+
1625+
for (int ix = 0; ix < nrc_x; ix += 8) {
1626+
for (int k = 0; k < 8; ++k) x8[k] = (const block_iq1_s *)((const char *)vx + (ix + k)*bx);
1627+
for (int i = 0; i < nb; ++i) {
1628+
for (int k = 0; k < 8; ++k) {
1629+
dh[k] = x8[k][i].d;
1630+
auto qs = x8[k][i].qs;
1631+
auto qh = x8[k][i].qh;
1632+
__m256i value;
1633+
for (int ib32 = 0; ib32 < 8; ++ib32) {
1634+
all_ls[8*ib32 + k] = (2*((qh[ib32] >> 12) & 7) + 1);
1635+
value = _mm256_set_epi64x(iq1s_grid[qs[3] | ((qh[ib32] >> 1) & 0x700)], iq1s_grid[qs[2] | ((qh[ib32] << 2) & 0x700)],
1636+
iq1s_grid[qs[1] | ((qh[ib32] << 5) & 0x700)], iq1s_grid[qs[0] | ((qh[ib32] << 8) & 0x700)]);
1637+
value = _mm256_slli_epi16(_mm256_add_epi8(value, _mm256_set1_epi8(1)), 3);
1638+
int8_t delta = qh[ib32] & 0x8000 ? -9 : -7;
1639+
value = _mm256_add_epi8(value, _mm256_set1_epi8(delta));
1640+
_mm256_storeu_si256((__m256i *)block, value);
1641+
auto q8 = (uint32_t *)y[ib32].qs;
1642+
for (int l = 0; l < 4; ++l) {
1643+
q8[8*l + k + 0] = block[l + 0];
1644+
q8[8*l + k + 32] = block[l + 4];
1645+
}
1646+
qs += 4;
1647+
}
1648+
}
1649+
auto vd = _mm256_mul_ps(_mm256_set1_ps(0.125f), _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)dh)));
1650+
for (int ib32 = 0; ib32 < QK_K/32; ++ib32) {
1651+
auto iscales16 = _mm_loadu_si128((const __m128i *)all_ls + ib32);
1652+
auto iscales32 = _mm256_cvtepi16_epi32(iscales16);
1653+
auto scales = _mm256_mul_ps(vd, _mm256_cvtepi32_ps(iscales32));
1654+
_mm_storeu_si128((__m128i *)y[ib32].d, _mm256_cvtps_ph(scales, _MM_FROUND_TO_NEAREST_INT));
1655+
}
1656+
y += QK_K/32;
1657+
}
1658+
}
1659+
}
15361660

15371661
} // namespace
15381662

15391663
bool iqk_set_kernels_1bit(int ne00, int typeA, int typeB, std::array<mul_mat_t, IQK_MAX_NY>& funcs, mul_mat_t& func16) {
15401664

15411665
auto expected_typeB = GGML_TYPE_Q8_K128;
1666+
auto actual_typeB = ggml_type(typeB);
15421667

15431668
func16 = nullptr;
15441669

15451670
switch (typeA) {
15461671
case GGML_TYPE_IQ1_S:
15471672
if (ne00%QK_K != 0) return false;
1548-
IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_iq1_s_q8_K, funcs);
1673+
if (actual_typeB == GGML_TYPE_Q8_2_X4) {
1674+
IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_iq1_s_q8_2_x4, funcs);
1675+
expected_typeB = GGML_TYPE_Q8_2_X4;
1676+
} else {
1677+
IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_iq1_s_q8_K, funcs);
15491678
#ifdef HAVE_FANCY_SIMD
1550-
func16 = mul_mat_iq1_s_q8_K<16>;
1679+
func16 = mul_mat_iq1_s_q8_K<16>;
15511680
#endif
1552-
expected_typeB = GGML_TYPE_Q8_K;
1681+
expected_typeB = GGML_TYPE_Q8_K;
1682+
}
15531683
break;
15541684
case GGML_TYPE_IQ1_S_R4:
15551685
if (ne00%128 != 0) return false;
@@ -1585,8 +1715,17 @@ bool iqk_set_kernels_1bit(int ne00, int typeA, int typeB, std::array<mul_mat_t,
15851715
return false;
15861716
}
15871717

1588-
return ggml_type(typeB) == expected_typeB;
1718+
return actual_typeB == expected_typeB;
1719+
1720+
}
15891721

1722+
bool iqk_convert_1bit_q80_r8(int type, int n, const void * vx, size_t bx, void * vy, int nrc_x) {
1723+
if (n%QK_K != 0 || nrc_x%8 != 0) return false;
1724+
switch (ggml_type(type)) {
1725+
case GGML_TYPE_IQ1_S: iqk_convert_iq1_s_q8_0_r8(n, vx, bx, vy, nrc_x); break;
1726+
default: return false;
1727+
}
1728+
return true;
15901729
}
15911730

15921731
#else
@@ -2277,6 +2416,10 @@ bool iqk_set_kernels_1bit(int ne00, int typeA, int typeB, std::array<mul_mat_t,
22772416

22782417
}
22792418

2419+
bool iqk_convert_1bit_q80_r8([[maybe_unused]] int type, [[maybe_unused]] int n, [[maybe_unused]] const void * vx, [[maybe_unused]] size_t bx, [[maybe_unused]] void * vy, [[maybe_unused]] int nrc_x) {
2420+
return false;
2421+
}
2422+
22802423
#endif
22812424

22822425
#endif

ggml/src/iqk/iqk_gemm_1bit.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,4 +8,6 @@
88

99
bool iqk_set_kernels_1bit(int ne00, int typeA, int typeB, std::array<mul_mat_t, IQK_MAX_NY>& kernels, mul_mat_t& func16);
1010

11+
bool iqk_convert_1bit_q80_r8(int type, int n, const void * vx, size_t bx, void * vy, int nrc_x);
12+
1113
#endif

ggml/src/iqk/iqk_mul_mat.cpp

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -236,11 +236,12 @@ struct MulMat {
236236
static inline ggml_type is_dequant_better(ggml_type type, int nrc_y) {
237237
#ifdef __AVX2__
238238
switch (type) {
239-
case GGML_TYPE_IQ2_KT: return nrc_y >= 32 ? GGML_TYPE_F32 : type;
240-
case GGML_TYPE_IQ3_KT: return nrc_y >= 32 ? GGML_TYPE_F32 : type;
241-
case GGML_TYPE_IQ4_KT: return nrc_y >= 32 ? GGML_TYPE_F32 : type;
239+
case GGML_TYPE_IQ2_KT : return nrc_y >= 32 ? GGML_TYPE_F32 : type;
240+
case GGML_TYPE_IQ3_KT : return nrc_y >= 32 ? GGML_TYPE_F32 : type;
241+
case GGML_TYPE_IQ4_KT : return nrc_y >= 32 ? GGML_TYPE_F32 : type;
242242
case GGML_TYPE_IQ2_XXS: return nrc_y >= 32 ? GGML_TYPE_Q8_0_R8 : type;
243243
case GGML_TYPE_IQ3_XXS: return nrc_y >= 32 ? GGML_TYPE_Q8_0_R8 : type;
244+
case GGML_TYPE_IQ1_S : return nrc_y >= 32 ? GGML_TYPE_Q8_0_R8 : type;
244245
default: break;
245246
}
246247
#else
@@ -397,13 +398,13 @@ bool iqk_convert_repack(int typeA, int n, const void * vx, size_t bx, void * vy,
397398
//case GGML_TYPE_Q8_0_R8:
398399
//case GGML_TYPE_IQ4_NL_R4:
399400
// return iqk_set_kernels_legacy_quants(ne00, typeA, typeB, mm.funcs, mm.func16);
400-
//case GGML_TYPE_IQ1_S:
401+
case GGML_TYPE_IQ1_S:
401402
//case GGML_TYPE_IQ1_S_R4:
402403
//case GGML_TYPE_IQ1_M_R4:
403404
//case GGML_TYPE_IQ1_BN:
404405
//case GGML_TYPE_IQ2_BN:
405406
//case GGML_TYPE_IQ2_BN_R4:
406-
// return iqk_set_kernels_1bit(ne00, typeA, typeB, mm.funcs, mm.func16);
407+
return iqk_convert_1bit_q80_r8(typeA, n, vx, bx, vy, nrc_x);
407408

408409
default:
409410
return false;

0 commit comments

Comments
 (0)