Skip to content

Commit 69af3f5

Browse files
ikawrakow (Iwan Kawrakow) authored
Much faster iq3_xxs GEMM via repacking to q8_0_r8 (AVX2) (#516)
Co-authored-by: Iwan Kawrakow <[email protected]>
1 parent e56061f commit 69af3f5

File tree

3 files changed

+76
-1
lines changed

3 files changed

+76
-1
lines changed

ggml/src/ggml.c

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1123,7 +1123,11 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
11231123
.from_float = quantize_row_iq3_xxs,
11241124
.from_float_ref = (ggml_from_float_t)quantize_row_iq3_xxs_ref,
11251125
.vec_dot = ggml_vec_dot_iq3_xxs_q8_K,
1126+
#ifdef __AVX2__
1127+
.vec_dot_type = GGML_TYPE_Q8_2_X4,
1128+
#else
11261129
.vec_dot_type = GGML_TYPE_Q8_K,
1130+
#endif
11271131
.nrows = 1,
11281132
.row_meta_size = 0,
11291133
},

ggml/src/iqk/iqk_gemm_iquants.cpp

Lines changed: 71 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -472,6 +472,15 @@ struct DequantizerIQ3XXS final : public BaseDequantizer<block_iq3_xxs> {
472472
auto scales16 = prepare_scales(i);
473473
scales[0] = MM256_SET_M128I(scales16, scales16);
474474
}
475+
inline void new_block_f(int i, __m256 * scales) {
476+
auto sc16 = prepare_scales(i);
477+
auto scf = _mm256_mul_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(sc16)));
478+
auto scf_l = _mm256_castps256_ps128(scf);
479+
auto scf_h = _mm256_extractf128_ps(scf, 1);
480+
scales[0] = _mm256_set_m128(scf_l, scf_l);
481+
scales[1] = _mm256_set_m128(scf_h, scf_h);
482+
scales[2] = _mm256_mul_ps(scf, _mm256_set1_ps(-minv));
483+
}
475484
inline float new_block(int i, __m256i * scales, __m256i& mins) {
476485
auto scales16 = prepare_scales(i);
477486
mins = scb.shuffle(scales16);
@@ -1771,6 +1780,58 @@ void iqk_convert_iq2_xxs_q8_0_r8(int n, const void * vx, size_t bx, void * vy, i
17711780
}
17721781
}
17731782

1783+
// Repack groups of 8 iq3_xxs rows into block_q8_0_r8 (8-row-interleaved q8_0),
// so the much faster q8_0_r8 GEMM kernels can be used instead of the native
// iq3_xxs dot product.
//   n     - number of elements per row (must be a multiple of QK_K)
//   vx    - source rows (block_iq3_xxs), bx bytes apart
//   vy    - destination (block_q8_0_r8), nrc_x/8 * n/32 blocks
//   nrc_x - number of rows (must be a multiple of 8)
void iqk_convert_iq3_xxs_q8_0_r8(int n, const void * vx, size_t bx, void * vy, int nrc_x) {
    GGML_ASSERT(n%QK_K == 0);
    GGML_ASSERT(nrc_x%8 == 0);

    int nb = n/QK_K;   // super-blocks per row

    const block_iq3_xxs * x8[8];   // pointers to the 8 rows being interleaved

    block_q8_0_r8 * y = (block_q8_0_r8 *)vy;

    ggml_half dh[8];        // per-row super-block scales d
    uint16_t all_ls[64];    // (2*s+1) integer scales, 8 blocks x 8 rows
    EvenSignHelper esh;

    uint32_t block[8];      // staging buffer for one 32-quant vector
    uint32_t aux32;         // packed sign bits (low 28) + 4-bit scale (high 4)

    for (int ix = 0; ix < nrc_x; ix += 8) {
        for (int k = 0; k < 8; ++k) x8[k] = (const block_iq3_xxs *)((const char *)vx + (ix + k)*bx);
        for (int i = 0; i < nb; ++i) {
            // TODO: simdify
            for (int k = 0; k < 8; ++k) {
                dh[k] = x8[k][i].d;
                auto qs = x8[k][i].qs;
                auto sas = qs + QK_K/4;   // sign/scale words follow the QK_K/4 grid indices
                for (int ib32 = 0; ib32 < 8; ++ib32) {
                    std::memcpy(&aux32, sas + 4*ib32, sizeof(uint32_t));
                    // Effective block scale is d * (2*(aux32>>28) + 1) / 4; the
                    // integer part is stored here, the d/4 factor is applied below.
                    all_ls[8*ib32 + k] = (2*(aux32 >> 28) + 1);
                    // Look up the 32 quant values (8 groups of 4 bytes) from the grid.
                    auto value = _mm256_set_epi32(iq3xxs_grid[qs[7]], iq3xxs_grid[qs[6]], iq3xxs_grid[qs[5]], iq3xxs_grid[qs[4]],
                                                  iq3xxs_grid[qs[3]], iq3xxs_grid[qs[2]], iq3xxs_grid[qs[1]], iq3xxs_grid[qs[0]]);
                    // Apply the sign bits packed in aux32 to the grid values.
                    esh.sign_value(aux32, value);
                    _mm256_storeu_si256((__m256i *)block, value);
                    auto q8 = (uint32_t *)y[ib32].qs;
                    // Interleave rows in 4-byte groups: group l of row k lands at
                    // q8[8*l + k]; the low 16 quants fill the first 128 bytes of
                    // the r8 block, the high 16 the second 128 bytes.
                    for (int l = 0; l < 4; ++l) {
                        q8[8*l + k + 0] = block[l + 0];
                        q8[8*l + k + 32] = block[l + 4];
                    }
                    qs += 8;
                }
            }
            // Finalize the 8 per-row scales of each q8_0_r8 block:
            // d * (2*s+1) / 4, converted back to fp16.
            auto vd = _mm256_mul_ps(_mm256_set1_ps(0.25f), _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)dh)));
            for (int ib32 = 0; ib32 < QK_K/32; ++ib32) {
                auto iscales16 = _mm_loadu_si128((const __m128i *)all_ls + ib32);
                auto iscales32 = _mm256_cvtepi16_epi32(iscales16);
                auto scales = _mm256_mul_ps(vd, _mm256_cvtepi32_ps(iscales32));
                _mm_storeu_si128((__m128i *)y[ib32].d, _mm256_cvtps_ph(scales, _MM_FROUND_TO_NEAREST_INT));
            }
            y += QK_K/32;
        }
    }
}
1834+
17741835
template <typename Dequantizer> void set_functions(std::array<mul_mat_t, IQK_MAX_NY>& funcs) {
17751836
funcs[0] = mul_mat_qX_K_q8_K_IQ<Dequantizer, 1>;
17761837
funcs[1] = mul_mat_qX_K_q8_K_IQ<Dequantizer, 2>;
@@ -1791,7 +1852,15 @@ bool iqk_set_kernels_iquants(int ne00, int typeA, int typeB, std::array<mul_mat_
17911852
if (ggml_type(typeA) == GGML_TYPE_IQ2_XXS) {
17921853
if (ggml_type(typeB) == GGML_TYPE_Q8_2_X4) {
17931854
IQK_SET_MUL_MAT_FUNCTIONS_T(mul_mat_qX_K_q8_2_IQ_N, DequantizerIQ2XXS, kernels);
1794-
//IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_iq2_xxs_q8_2_IQ_N, kernels);
1855+
func16 = nullptr;
1856+
return true;
1857+
}
1858+
return false;
1859+
}
1860+
1861+
if (ggml_type(typeA) == GGML_TYPE_IQ3_XXS) {
1862+
if (ggml_type(typeB) == GGML_TYPE_Q8_2_X4) {
1863+
IQK_SET_MUL_MAT_FUNCTIONS_T(mul_mat_qX_K_q8_2_IQ_N, DequantizerIQ3XXS, kernels);
17951864
func16 = nullptr;
17961865
return true;
17971866
}
@@ -1856,6 +1925,7 @@ bool iqk_convert_iquants_q80_r8(int type, int n, const void * vx, size_t bx, voi
18561925
if (n%QK_K != 0 || nrc_x%8 != 0) return false;
18571926
switch (ggml_type(type)) {
18581927
case GGML_TYPE_IQ2_XXS: iqk_convert_iq2_xxs_q8_0_r8(n, vx, bx, vy, nrc_x); break;
1928+
case GGML_TYPE_IQ3_XXS: iqk_convert_iq3_xxs_q8_0_r8(n, vx, bx, vy, nrc_x); break;
18591929
default: return false;
18601930
}
18611931
return true;

ggml/src/iqk/iqk_mul_mat.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -240,6 +240,7 @@ struct MulMat {
240240
case GGML_TYPE_IQ3_KT: return nrc_y >= 32 ? GGML_TYPE_F32 : type;
241241
case GGML_TYPE_IQ4_KT: return nrc_y >= 32 ? GGML_TYPE_F32 : type;
242242
case GGML_TYPE_IQ2_XXS: return nrc_y >= 32 ? GGML_TYPE_Q8_0_R8 : type;
243+
case GGML_TYPE_IQ3_XXS: return nrc_y >= 32 ? GGML_TYPE_Q8_0_R8 : type;
243244
default: break;
244245
}
245246
#else

0 commit comments

Comments
 (0)