32 changes: 13 additions & 19 deletions ggml/src/iqk/iqk_gemm_legacy_quants.cpp
@@ -188,20 +188,14 @@ struct ScaleHelperQ8_2 {
     inline __m256 prepare4(__m256 other_scales, const Q * y) {
         return _mm256_mul_ps(other_scales, prepare4<Q>(y));
     }
-    template <typename Q> inline std::pair<float, float> prepare1(const Q * y) const {
-        float d = GGML_BF16_TO_FP32(y->d);
+    template <typename Q> static inline std::pair<float, float> prepare1(const Q * y) {
+        float d = GGML_BF16_TO_FP32(ggml_bf16_t{y->d});
         int16_t m = *(const int16_t *)&y->s;
         return std::make_pair(d, d*m);
     }
-    template <typename Q> inline std::pair<float, float> prepare1(const std::pair<float, float>& dm, const Q * y) const {
-        float d = GGML_BF16_TO_FP32(y->d);
-        int16_t m = *(const int16_t *)&y->s;
-        return std::make_pair(dm.first*d, dm.second*d*m);
-    }
-    std::pair<float, float> inline prepare1(const std::pair<float, float>& dm, const block_q8_2 * y) const {
-        ggml_bf16_t dy; dy.bits = y->d; int16_t s = *(const int16_t *)&y->s;
-        float d = GGML_BF16_TO_FP32(dy);
-        return std::make_pair(dm.first*d, dm.second*d*s);
+    static inline std::pair<float, float> prepare1(const std::pair<float, float>& dm, const block_q8_2 * y) {
+        auto d = prepare1(y);
+        return std::make_pair(dm.first*d.first, dm.second*d.second);
     }
 };
 
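The net effect of this hunk: both reworked overloads are now static and share a single bf16 decode path. The one-argument prepare1 returns the per-block pair (d, d*m), and the block_q8_2 overload multiplies that pair componentwise by the supplied dm pair instead of re-reading the raw bits itself. Below is a minimal standalone sketch of the same arithmetic, for reference only; block_q8_2_sketch and bf16_to_fp32_sketch are illustrative stand-ins for the real block_q8_2 and GGML_BF16_TO_FP32, which this sketch does not reproduce exactly.

// Standalone sketch of the arithmetic in the reworked prepare1 overloads.
// block_q8_2_sketch and bf16_to_fp32_sketch are illustrative stand-ins for
// block_q8_2 and GGML_BF16_TO_FP32; only the scale/sum handling is reproduced.
#include <cstdint>
#include <cstring>
#include <utility>

struct block_q8_2_sketch {
    uint16_t d;   // per-block scale, stored as bf16 bits
    uint16_t s;   // per-block value that prepare1 reads back as an int16
};

static float bf16_to_fp32_sketch(uint16_t bits) {
    // bf16 occupies the upper 16 bits of an IEEE-754 binary32 value
    uint32_t u = uint32_t(bits) << 16;
    float f;
    std::memcpy(&f, &u, sizeof f);
    return f;
}

// Mirrors the single-argument overload: returns (d, d*m) for one block.
static std::pair<float, float> prepare1_sketch(const block_q8_2_sketch* y) {
    float d = bf16_to_fp32_sketch(y->d);
    int16_t m;
    std::memcpy(&m, &y->s, sizeof m);
    return std::make_pair(d, d * m);
}

// Mirrors the two-argument overload: scales the per-block pair componentwise
// by dm, delegating the bf16 decode to the overload above.
static std::pair<float, float> prepare1_sketch(const std::pair<float, float>& dm,
                                               const block_q8_2_sketch* y) {
    auto d = prepare1_sketch(y);
    return std::make_pair(dm.first * d.first, dm.second * d.second);
}

Keeping the bit-level decode in one overload means a future change to the q8_2 scale layout only needs to touch the single-argument version; the call sites in the hunks below pick up the same behavior through structured bindings.
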
@@ -1484,14 +1478,14 @@ static void mul_mat_q8_0_r8_q8_2(int n, const void * vx, size_t bx, const DataIn
             }
         }
         if (4*(nb/4) < nb) {
-            auto qy = (const block_q8_1 *)q8.y[0];
+            auto qy = (const block_q8_2 *)q8.y[0];
             for (int ib = 4*(nb/4); ib < nb; ++ib) {
                 auto scales = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)iq8[ib].d));
                 auto sumi = q8_0_r8_dot_product((const uint8_t *)iq8[ib].qs, qy[ib].qs, qx);
-                ggml_bf16_t d, s; d.bits = qy[ib].d; s.bits = qy[ib].s;
-                auto d4d8 = _mm256_mul_ps(scales, _mm256_set1_ps(GGML_BF16_TO_FP32(d)));
+                auto [d8, m8] = ScaleHelperQ8_2::prepare1(qy + ib);
+                auto d4d8 = _mm256_mul_ps(scales, _mm256_set1_ps(d8));
                 acc[0] = _mm256_fmadd_ps(d4d8, _mm256_cvtepi32_ps(sumi), acc[0]);
-                acc[1] = _mm256_fmadd_ps(scales, _mm256_set1_ps(GGML_BF16_TO_FP32(s)), acc[1]);
+                acc[1] = _mm256_fmadd_ps(scales, _mm256_set1_ps(m8), acc[1]);
             }
         }
         info.store(ix, 0, _mm256_fmadd_ps(_mm256_set1_ps(-127.f), acc[1], acc[0]));
@@ -1535,12 +1529,12 @@ static void mul_mat_q8_0_r8_q8_2(int n, const void * vx, size_t bx, const DataIn
                 qx[j] = _mm512_add_epi8(qx[j], _mm512_set1_epi8(127));
             }
             for (int iy = 0; iy < nrc_y; ++iy) {
-                auto qy = (const block_q8_1 *)q8.y[iy];
+                auto qy = (const block_q8_2 *)q8.y[iy];
                 auto sumi = qx_r8_q8_dot_product(qx, qy[ib].qs);
-                ggml_bf16_t d, s; d.bits = qy[ib].d; s.bits = qy[ib].s;
-                auto dy = _mm512_set1_ps(GGML_BF16_TO_FP32(d));
+                auto [d8, m8] = ScaleHelperQ8_2::prepare1(qy + ib);
+                auto dy = _mm512_set1_ps(d8);
                 acc[2*iy+0] = _mm512_fmadd_ps(_mm512_mul_ps(scales, dy), _mm512_cvtepi32_ps(sumi), acc[2*iy+0]);
-                acc[2*iy+1] = _mm512_fmadd_ps(scales, _mm512_set1_ps(GGML_BF16_TO_FP32(s)), acc[2*iy+1]);
+                acc[2*iy+1] = _mm512_fmadd_ps(scales, _mm512_set1_ps(m8), acc[2*iy+1]);
             }
         }
         for (int iy = 0; iy < nrc_y; ++iy) {
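
To make the accumulation at the updated call sites easier to follow, here is a scalar sketch of the tail-block math from the AVX2 hunk above, with the intrinsics removed and a single row and column. It is an illustration, not the kernel: the block size of 32, and the reading of m8 as d8 times the block's quant sum, are assumptions made for this sketch only.

// Scalar sketch (not the actual kernel) of the tail-loop accumulation shown
// in the AVX2 hunk above.  Assumptions for illustration only: a block size of
// 32, and that m8 returned by prepare1 equals d8 * sum(y), i.e. that the int16
// stored in the q8_2 block is the sum of its quants.
#include <cstdint>
#include <utility>

constexpr int kBlockSize = 32;   // assumed block size

// d4: per-row scale of the q8_0_r8 data (the role of 'scales' in the kernel)
// dm8: the (d8, m8) pair returned by ScaleHelperQ8_2::prepare1 for one block
// x_biased: left-hand quants after the +127 bias the kernel applies
//           (cf. the _mm512_add_epi8(..., 127) line above)
// y: right-hand q8_2 quants
float tail_block_sketch(float d4, std::pair<float, float> dm8,
                        const uint8_t* x_biased, const int8_t* y) {
    auto [d8, m8] = dm8;
    int32_t sumi = 0;
    for (int k = 0; k < kBlockSize; ++k) sumi += int32_t(x_biased[k]) * int32_t(y[k]);
    float acc0 = d4 * d8 * float(sumi);  // corresponds to acc[0] in the kernel
    float acc1 = d4 * m8;                // corresponds to acc[1] in the kernel
    // Subtracting 127*acc1 removes the +127 bias, because
    // sum((x_biased - 127) * y) == sumi - 127 * sum(y) and m8 == d8 * sum(y)
    // under the stated assumption.
    return acc0 - 127.f * acc1;
}

Under that assumption the final subtraction of 127*acc1 matches the info.store() line in the diff; the real code does the same correction for eight interleaved rows at once.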