Skip to content

Commit 4fc3cb4

Browse files
ikawrakow (Iwan Kawrakow) and a co-author authored
iq3_s: much faster GEMM via repacking to q8_0_r8 (#518)
Co-authored-by: Iwan Kawrakow <[email protected]>
1 parent 3f54b49 commit 4fc3cb4

File tree

3 files changed

+83
-3
lines changed

3 files changed

+83
-3
lines changed

ggml/src/ggml.c

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1153,7 +1153,11 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
11531153
.from_float = quantize_row_iq3_s,
11541154
.from_float_ref = (ggml_from_float_t)quantize_row_iq3_s_ref,
11551155
.vec_dot = ggml_vec_dot_iq3_s_q8_K,
1156+
#ifdef __AVX2__
1157+
.vec_dot_type = GGML_TYPE_Q8_2_X4,
1158+
#else
11561159
.vec_dot_type = GGML_TYPE_Q8_K,
1160+
#endif
11571161
.nrows = 1,
11581162
.row_meta_size = 0,
11591163
},

ggml/src/iqk/iqk_gemm_iquants.cpp

Lines changed: 78 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -115,8 +115,10 @@ struct SignHelper {
115115
return _mm256_sign_epi8(value, make_signs(sign_bits[0] | (sign_bits[1] << 16)));
116116
#endif
117117
}
118-
inline void sign_4_values(const uint16_t * sign_bits, __m256i * values) const {
119-
#ifdef HAVE_FANCY_SIMD
118+
IQK_ALWAYS_INLINE void sign_4_values(const uint16_t * sign_bits, __m256i * values) const {
119+
// Somehow the FANCY_SIMD version has become 50% slower for TG???
120+
#ifdef z_HAVE_FANCY_SIMD
121+
//__mmask32 mask[4]; std::memcpy(mask, sign_bits, 4*sizeof(__mmask32));
120122
const __mmask32 * mask = (const __mmask32 *)sign_bits;
121123
values[0] = _mm256_mask_sub_epi8(values[0], mask[0], _mm256_setzero_si256(), values[0]);
122124
values[1] = _mm256_mask_sub_epi8(values[1], mask[1], _mm256_setzero_si256(), values[1]);
@@ -534,7 +536,7 @@ struct DequantizerIQ3XXS final : public BaseDequantizer<block_iq3_xxs> {
534536

535537
};
536538

537-
#ifdef HAVE_FANCY_SIMD
539+
#ifdef z_HAVE_FANCY_SIMD
538540
// Strangely enough, the following implementation makes PP ~6% slower and TG ~6% faster
539541
// compared to the vanilla AVX2 version below.
540542
struct IndexHelperIQ3S {
@@ -597,6 +599,15 @@ struct DequantizerIQ3S final : public BaseDequantizer<block_iq3_s> {
597599
auto scales16 = make_scales(i, d);
598600
scales[0] = MM256_SET_M128I(scales16, scales16);
599601
}
602+
inline void new_block_f(int i, __m256 * scales) {
603+
auto sc16 = make_scales(i, d);
604+
auto scf = _mm256_mul_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(sc16)));
605+
auto scf_l = _mm256_castps256_ps128(scf);
606+
auto scf_h = _mm256_extractf128_ps(scf, 1);
607+
scales[0] = _mm256_set_m128(scf_l, scf_l);
608+
scales[1] = _mm256_set_m128(scf_h, scf_h);
609+
scales[2] = _mm256_mul_ps(scf, _mm256_set1_ps(-minv));
610+
}
600611
inline float new_block(int i, __m256i * scales, __m256i& mins) {
601612
auto scales16 = make_scales(i, d);
602613
mins = scb.shuffle(scales16);
@@ -1832,6 +1843,60 @@ void iqk_convert_iq3_xxs_q8_0_r8(int n, const void * vx, size_t bx, void * vy, i
18321843
}
18331844
}
18341845

1846+
void iqk_convert_iq3_s_q8_0_r8(int n, const void * vx, size_t bx, void * vy, int nrc_x) {
1847+
GGML_ASSERT(n%QK_K == 0);
1848+
GGML_ASSERT(nrc_x%8 == 0);
1849+
1850+
int nb = n/QK_K;
1851+
1852+
const block_iq3_s * x8[8];
1853+
1854+
block_q8_0_r8 * y = (block_q8_0_r8 *)vy;
1855+
1856+
ggml_half dh[8];
1857+
uint16_t all_ls[64];
1858+
SignHelper sh;
1859+
IndexHelperIQ3S helper;
1860+
1861+
uint32_t block[8];
1862+
__m256i values[8];
1863+
1864+
for (int ix = 0; ix < nrc_x; ix += 8) {
1865+
for (int k = 0; k < 8; ++k) x8[k] = (const block_iq3_s *)((const char *)vx + (ix + k)*bx);
1866+
for (int i = 0; i < nb; ++i) {
1867+
for (int k = 0; k < 8; ++k) {
1868+
dh[k] = x8[k][i].d;
1869+
auto qs = x8[k][i].qs;
1870+
auto qh = x8[k][i].qh;
1871+
auto signs = (const uint16_t *)x8[k][i].signs;
1872+
helper.make2(qs+ 0, qh+0, values+0);
1873+
helper.make2(qs+16, qh+2, values+2);
1874+
sh.sign_4_values(signs+0, values+0);
1875+
helper.make2(qs+32, qh+4, values+4);
1876+
helper.make2(qs+48, qh+6, values+6);
1877+
sh.sign_4_values(signs+8, values+4);
1878+
for (int ib32 = 0; ib32 < 8; ++ib32) {
1879+
all_ls[8*ib32 + k] = (2*((x8[k][i].scales[ib32/2] >> 4*(ib32%2)) & 0xf) + 1);
1880+
_mm256_storeu_si256((__m256i *)block, values[ib32]);
1881+
auto q8 = (uint32_t *)y[ib32].qs;
1882+
for (int l = 0; l < 4; ++l) {
1883+
q8[8*l + k + 0] = block[l + 0];
1884+
q8[8*l + k + 32] = block[l + 4];
1885+
}
1886+
}
1887+
}
1888+
auto vd = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)dh));
1889+
for (int ib32 = 0; ib32 < QK_K/32; ++ib32) {
1890+
auto iscales16 = _mm_loadu_si128((const __m128i *)all_ls + ib32);
1891+
auto iscales32 = _mm256_cvtepi16_epi32(iscales16);
1892+
auto scales = _mm256_mul_ps(vd, _mm256_cvtepi32_ps(iscales32));
1893+
_mm_storeu_si128((__m128i *)y[ib32].d, _mm256_cvtps_ph(scales, _MM_FROUND_TO_NEAREST_INT));
1894+
}
1895+
y += QK_K/32;
1896+
}
1897+
}
1898+
}
1899+
18351900
template <typename Dequantizer> void set_functions(std::array<mul_mat_t, IQK_MAX_NY>& funcs) {
18361901
funcs[0] = mul_mat_qX_K_q8_K_IQ<Dequantizer, 1>;
18371902
funcs[1] = mul_mat_qX_K_q8_K_IQ<Dequantizer, 2>;
@@ -1867,6 +1932,15 @@ bool iqk_set_kernels_iquants(int ne00, int typeA, int typeB, std::array<mul_mat_
18671932
return false;
18681933
}
18691934

1935+
if (ggml_type(typeA) == GGML_TYPE_IQ3_S) {
1936+
if (ggml_type(typeB) == GGML_TYPE_Q8_2_X4) {
1937+
IQK_SET_MUL_MAT_FUNCTIONS_T(mul_mat_qX_K_q8_2_IQ_N, DequantizerIQ3S, kernels);
1938+
func16 = nullptr;
1939+
return true;
1940+
}
1941+
return false;
1942+
}
1943+
18701944
if (ggml_type(typeB) != GGML_TYPE_Q8_K) {
18711945
return false;
18721946
}
@@ -1926,6 +2000,7 @@ bool iqk_convert_iquants_q80_r8(int type, int n, const void * vx, size_t bx, voi
19262000
switch (ggml_type(type)) {
19272001
case GGML_TYPE_IQ2_XXS: iqk_convert_iq2_xxs_q8_0_r8(n, vx, bx, vy, nrc_x); break;
19282002
case GGML_TYPE_IQ3_XXS: iqk_convert_iq3_xxs_q8_0_r8(n, vx, bx, vy, nrc_x); break;
2003+
case GGML_TYPE_IQ3_S : iqk_convert_iq3_s_q8_0_r8 (n, vx, bx, vy, nrc_x); break;
19292004
default: return false;
19302005
}
19312006
return true;

ggml/src/iqk/iqk_mul_mat.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -241,6 +241,7 @@ struct MulMat {
241241
case GGML_TYPE_IQ4_KT : return nrc_y >= 32 ? GGML_TYPE_F32 : type;
242242
case GGML_TYPE_IQ2_XXS: return nrc_y >= 32 ? GGML_TYPE_Q8_0_R8 : type;
243243
case GGML_TYPE_IQ3_XXS: return nrc_y >= 32 ? GGML_TYPE_Q8_0_R8 : type;
244+
case GGML_TYPE_IQ3_S : return nrc_y >= 32 ? GGML_TYPE_Q8_0_R8 : type;
244245
case GGML_TYPE_IQ1_S : return nrc_y >= 32 ? GGML_TYPE_Q8_0_R8 : type;
245246
default: break;
246247
}

0 commit comments

Comments (0)