Skip to content
1 change: 0 additions & 1 deletion ggml/src/ggml-cpu/arch-fallback.h
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,6 @@
#define ggml_vec_dot_iq3_s_q8_K_generic ggml_vec_dot_iq3_s_q8_K
#define ggml_vec_dot_iq1_s_q8_K_generic ggml_vec_dot_iq1_s_q8_K
#define ggml_vec_dot_iq1_m_q8_K_generic ggml_vec_dot_iq1_m_q8_K
#define ggml_vec_dot_mxfp4_q8_0_generic ggml_vec_dot_mxfp4_q8_0
// repack.cpp
#define ggml_quantize_mat_q8_0_4x4_generic ggml_quantize_mat_q8_0_4x4
#define ggml_quantize_mat_q8_0_4x8_generic ggml_quantize_mat_q8_0_4x8
Expand Down
96 changes: 96 additions & 0 deletions ggml/src/ggml-cpu/arch/s390/quants.c
Original file line number Diff line number Diff line change
Expand Up @@ -260,6 +260,102 @@ void ggml_vec_dot_q4_1_q8_1(int n, float * GGML_RESTRICT s, size_t bs, const voi
#endif
}

void ggml_vec_dot_mxfp4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
assert(nrc == 1);
UNUSED(nrc);
UNUSED(bx);
UNUSED(by);
UNUSED(bs);
assert(n % QK_MXFP4 == 0);
static_assert(QK_MXFP4 == QK8_0, "QK_MXFP4 and QK8_0 must be the same");

const int qk = QK_MXFP4;
const int nb = n / qk;

const block_mxfp4 * GGML_RESTRICT x = vx;
const block_q8_0 * GGML_RESTRICT y = vy;

int ib = 0;
float sumf = 0.0f;

#if defined(__VXE__) || defined(__VXE2__)
const int8x16_t v_k = vec_xl(0, kvalues_mxfp4);
const uint8x16_t v_m = vec_splats((const uint8_t)0x0F);

float32x4_t v_acc = vec_splats(0.0f);

#pragma GCC unroll 8
for (; ib + 1 < nb; ib += 2) {
const block_mxfp4 * GGML_RESTRICT x0 = &x[ib + 0];
const block_mxfp4 * GGML_RESTRICT x1 = &x[ib + 1];
const block_q8_0 * GGML_RESTRICT y0 = &y[ib + 0];
const block_q8_0 * GGML_RESTRICT y1 = &y[ib + 1];

const uint8x16_t v_x0 = vec_xl(0, x0->qs);
const uint8x16_t v_x1 = vec_xl(0, x1->qs);

int8x16_t v_x0l = (int8x16_t)vec_and(v_x0, v_m);
int8x16_t v_x0h = (int8x16_t)vec_sr(v_x0, 4);
int8x16_t v_x1l = (int8x16_t)vec_and(v_x1, v_m);
int8x16_t v_x1h = (int8x16_t)vec_sr(v_x1, 4);

v_x0l = vec_perm(v_k, v_k, (uchar8x16_t)v_x0l);
v_x0h = vec_perm(v_k, v_k, (uchar8x16_t)v_x0h);
v_x1l = vec_perm(v_k, v_k, (uchar8x16_t)v_x1l);
v_x1h = vec_perm(v_k, v_k, (uchar8x16_t)v_x1h);

const int8x16_t v_y0l = vec_xl(0, y0->qs);
const int8x16_t v_y0h = vec_xl(QK8_0/2, y0->qs);
const int8x16_t v_y1l = vec_xl(0, y1->qs);
const int8x16_t v_y1h = vec_xl(QK8_0/2, y1->qs);

const int32x4_t v_xy0 = ggml_vec_dot(ggml_vec_dot(vec_splats(0), v_x0l, v_y0l), v_x0h, v_y0h);
const int32x4_t v_xy1 = ggml_vec_dot(ggml_vec_dot(vec_splats(0), v_x1l, v_y1l), v_x1h, v_y1h);

const float32x4_t v_xy0f = vec_float(v_xy0);
const float32x4_t v_xy1f = vec_float(v_xy1);

const float32x4_t v_d0 = vec_splats(GGML_E8M0_TO_FP32_HALF(x0->e) * GGML_CPU_FP16_TO_FP32(y0->d));
const float32x4_t v_d1 = vec_splats(GGML_E8M0_TO_FP32_HALF(x1->e) * GGML_CPU_FP16_TO_FP32(y1->d));

v_acc = vec_madd(v_xy0f, v_d0, v_acc);
v_acc = vec_madd(v_xy1f, v_d1, v_acc);
}

#pragma GCC unroll 8
for (; ib < nb; ++ib) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This unroll seems unnecessary, since this loop should only have zero or one iterations.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good catch! Fixed in latest commit

const block_mxfp4 * GGML_RESTRICT x0 = &x[ib + 0];
const block_q8_0 * GGML_RESTRICT y0 = &y[ib + 0];

const uint8x16_t v_x = vec_xl(0, x0->qs);

int8x16_t v_xl = (int8x16_t)vec_and(v_x, v_m);
int8x16_t v_xh = (int8x16_t)vec_sr(v_x, 4);

v_xl = vec_perm(v_k, v_k, (uchar8x16_t)v_xl);
v_xh = vec_perm(v_k, v_k, (uchar8x16_t)v_xh);

const int8x16_t v_yl = vec_xl(0, y0->qs);
const int8x16_t v_yh = vec_xl(QK8_0/2, y0->qs);

const int32x4_t v_xy = ggml_vec_dot(ggml_vec_dot(vec_splats(0), v_xl, v_yl), v_xh, v_yh);
const float32x4_t v_xyf = vec_float(v_xy);

const float32x4_t v_d = vec_splats(GGML_E8M0_TO_FP32_HALF(x0->e) * GGML_CPU_FP16_TO_FP32(y0->d));
v_acc = vec_madd(v_xyf, v_d, v_acc);
}

sumf = vec_hsum_f32x4(v_acc);
*s = sumf;
#else
UNUSED(x);
UNUSED(y);
UNUSED(ib);
UNUSED(sumf);
ggml_vec_dot_mxfp4_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc);
#endif
}

void ggml_vec_dot_q5_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
const int qk = QK8_0;
const int nb = n / qk;
Expand Down
Loading