Skip to content

Commit e0570ca

Browse files
committed
IKL PR 536 and 540
Credit : @louiehelm for 536, @ikawrakow for 540.
1 parent a8753a6 commit e0570ca

File tree

2 files changed

+23
-29
lines changed

2 files changed

+23
-29
lines changed

ggml/src/iqk/iqk_gemm_ktquants.cpp

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,11 @@
1414
namespace {
1515

1616
inline uint32_t trellis_next(uint32_t& val) {
17-
constexpr uint32_t ka = 3417055213;
18-
constexpr uint32_t kb = 0;
17+
constexpr uint32_t ka = 89226354;
18+
constexpr uint32_t kb = 64248484;
1919
constexpr uint32_t kmask = 0x8fff8fff;
2020
constexpr uint32_t km32 = 0x3b603b60;
21-
val = val*ka;
21+
val = val*ka + kb;
2222
return (val & kmask) ^ km32;
2323
}
2424

@@ -31,8 +31,8 @@ inline float trellis_gen(uint32_t& val, uint32_t* s) {
3131
struct Trellis1 {
3232
constexpr static uint32_t kmask = 0x8fff8fff;
3333
constexpr static uint32_t km32 = 0x3b603b60;
34-
constexpr static uint32_t ka = 3417055213;
35-
constexpr static uint32_t kb = 0;
34+
constexpr static uint32_t ka = 89226354;
35+
constexpr static uint32_t kb = 64248484;
3636
constexpr static uint32_t ka1 = ka*ka;
3737
constexpr static uint32_t kb1 = kb*ka+kb;
3838
constexpr static uint32_t ka2 = ka1*ka;
@@ -76,8 +76,8 @@ inline __m256 trellis_gen8(__m256i i8) {
7676
struct Trellis2 {
7777
constexpr static uint32_t kmask = 0x8fff8fff;
7878
constexpr static uint32_t km32 = 0x3b603b60;
79-
constexpr static uint32_t ka = 3417055213;
80-
constexpr static uint32_t kb = 0;
79+
constexpr static uint32_t ka = 89226354;
80+
constexpr static uint32_t kb = 64248484;
8181
constexpr static uint32_t ka1 = ka*ka;
8282
constexpr static uint32_t kb1 = kb*ka+kb;
8383
constexpr static uint32_t ka2 = ka1*ka;
@@ -1080,8 +1080,8 @@ namespace {
10801080
struct Trellis1 {
10811081
constexpr static uint32_t kmask = 0x8fff8fff;
10821082
constexpr static uint32_t km32 = 0x3b603b60;
1083-
constexpr static uint32_t ka = 3417055213;
1084-
constexpr static uint32_t kb = 0;
1083+
constexpr static uint32_t ka = 89226354;
1084+
constexpr static uint32_t kb = 64248484;
10851085
constexpr static uint32_t ka1 = ka*ka;
10861086
constexpr static uint32_t kb1 = kb*ka+kb;
10871087
constexpr static uint32_t ka2 = ka1*ka;
@@ -1586,7 +1586,7 @@ void mul_mat_iq4_kt_F32_T(int n, const void * vx, size_t bx, const DataInfo& inf
15861586
}
15871587

15881588
struct Trellis3 {
1589-
constexpr static uint32_t ka = ;0xCBAC1FED;
1589+
constexpr static uint32_t ka = 0xCBAC1FED;
15901590
constexpr static uint32_t ka1 = ka*ka;
15911591
constexpr static uint32_t ka2 = ka1*ka;
15921592
constexpr static uint32_t ka3 = ka2*ka;

ggml/src/iqk/iqk_gemm_legacy_quants.cpp

Lines changed: 13 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -188,20 +188,14 @@ struct ScaleHelperQ8_2 {
188188
inline __m256 prepare4(__m256 other_scales, const Q * y) {
189189
return _mm256_mul_ps(other_scales, prepare4<Q>(y));
190190
}
191-
template <typename Q> inline std::pair<float, float> prepare1(const Q * y) const {
192-
float d = GGML_BF16_TO_FP32(y->d);
191+
template <typename Q> static inline std::pair<float, float> prepare1(const Q * y) {
192+
float d = GGML_BF16_TO_FP32(ggml_bf16_t{y->d});
193193
int16_t m = *(const int16_t *)&y->s;
194194
return std::make_pair(d, d*m);
195195
}
196-
template <typename Q> inline std::pair<float, float> prepare1(const std::pair<float, float>& dm, const Q * y) const {
197-
float d = GGML_BF16_TO_FP32(y->d);
198-
int16_t m = *(const int16_t *)&y->s;
199-
return std::make_pair(dm.first*d, dm.second*d*m);
200-
}
201-
std::pair<float, float> inline prepare1(const std::pair<float, float>& dm, const block_q8_2 * y) const {
202-
ggml_bf16_t dy; dy.bits = y->d; int16_t s = *(const int16_t *)&y->s;
203-
float d = GGML_BF16_TO_FP32(dy);
204-
return std::make_pair(dm.first*d, dm.second*d*s);
196+
static inline std::pair<float, float> prepare1(const std::pair<float, float>& dm, const block_q8_2 * y) {
197+
auto d = prepare1(y);
198+
return std::make_pair(dm.first*d.first, dm.second*d.second);
205199
}
206200
};
207201

@@ -1484,14 +1478,14 @@ static void mul_mat_q8_0_r8_q8_2(int n, const void * vx, size_t bx, const DataIn
14841478
}
14851479
}
14861480
if (4*(nb/4) < nb) {
1487-
auto qy = (const block_q8_1 *)q8.y[0];
1481+
auto qy = (const block_q8_2 *)q8.y[0];
14881482
for (int ib = 4*(nb/4); ib < nb; ++ib) {
14891483
auto scales = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)iq8[ib].d));
14901484
auto sumi = q8_0_r8_dot_product((const uint8_t *)iq8[ib].qs, qy[ib].qs, qx);
1491-
ggml_bf16_t d, s; d.bits = qy[ib].d; s.bits = qy[ib].s;
1492-
auto d4d8 = _mm256_mul_ps(scales, _mm256_set1_ps(GGML_BF16_TO_FP32(d)));
1485+
auto [d8, m8] = ScaleHelperQ8_2::prepare1(qy + ib);
1486+
auto d4d8 = _mm256_mul_ps(scales, _mm256_set1_ps(d8));
14931487
acc[0] = _mm256_fmadd_ps(d4d8, _mm256_cvtepi32_ps(sumi), acc[0]);
1494-
acc[1] = _mm256_fmadd_ps(scales, _mm256_set1_ps(GGML_BF16_TO_FP32(s)), acc[1]);
1488+
acc[1] = _mm256_fmadd_ps(scales, _mm256_set1_ps(m8), acc[1]);
14951489
}
14961490
}
14971491
info.store(ix, 0, _mm256_fmadd_ps(_mm256_set1_ps(-127.f), acc[1], acc[0]));
@@ -1535,12 +1529,12 @@ static void mul_mat_q8_0_r8_q8_2(int n, const void * vx, size_t bx, const DataIn
15351529
qx[j] = _mm512_add_epi8(qx[j], _mm512_set1_epi8(127));
15361530
}
15371531
for (int iy = 0; iy < nrc_y; ++iy) {
1538-
auto qy = (const block_q8_1 *)q8.y[iy];
1532+
auto qy = (const block_q8_2 *)q8.y[iy];
15391533
auto sumi = qx_r8_q8_dot_product(qx, qy[ib].qs);
1540-
ggml_bf16_t d, s; d.bits = qy[ib].d; s.bits = qy[ib].s;
1541-
auto dy = _mm512_set1_ps(GGML_BF16_TO_FP32(d));
1534+
auto [d8, m8] = ScaleHelperQ8_2::prepare1(qy + ib);
1535+
auto dy = _mm512_set1_ps(d8);
15421536
acc[2*iy+0] = _mm512_fmadd_ps(_mm512_mul_ps(scales, dy), _mm512_cvtepi32_ps(sumi), acc[2*iy+0]);
1543-
acc[2*iy+1] = _mm512_fmadd_ps(scales, _mm512_set1_ps(GGML_BF16_TO_FP32(s)), acc[2*iy+1]);
1537+
acc[2*iy+1] = _mm512_fmadd_ps(scales, _mm512_set1_ps(m8), acc[2*iy+1]);
15441538
}
15451539
}
15461540
for (int iy = 0; iy < nrc_y; ++iy) {

0 commit comments

Comments
 (0)