Skip to content

Commit 618ef46

Browse files
committed
ggml-cpu: impl mxfp4 s390x
Signed-off-by: Aaron Teo <[email protected]>
1 parent 3ecb2f6 commit 618ef46

File tree

2 files changed

+82
-1
lines changed

2 files changed

+82
-1
lines changed

ggml/src/ggml-cpu/arch-fallback.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -160,7 +160,6 @@
160160
#define ggml_vec_dot_iq3_s_q8_K_generic ggml_vec_dot_iq3_s_q8_K
161161
#define ggml_vec_dot_iq1_s_q8_K_generic ggml_vec_dot_iq1_s_q8_K
162162
#define ggml_vec_dot_iq1_m_q8_K_generic ggml_vec_dot_iq1_m_q8_K
163-
#define ggml_vec_dot_mxfp4_q8_0_generic ggml_vec_dot_mxfp4_q8_0
164163
// repack.cpp
165164
#define ggml_quantize_mat_q8_0_4x4_generic ggml_quantize_mat_q8_0_4x4
166165
#define ggml_quantize_mat_q8_0_4x8_generic ggml_quantize_mat_q8_0_4x8

ggml/src/ggml-cpu/arch/s390/quants.c

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -260,6 +260,88 @@ void ggml_vec_dot_q4_1_q8_1(int n, float * GGML_RESTRICT s, size_t bs, const voi
260260
#endif
261261
}
262262

263+
void ggml_vec_dot_mxfp4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
264+
assert(nrc == 1);
265+
UNUSED(nrc);
266+
UNUSED(bx);
267+
UNUSED(by);
268+
UNUSED(bs);
269+
assert(n % QK_MXFP4 == 0);
270+
static_assert(QK_MXFP4 == QK8_0, "QK_MXFP4 and QK8_0 must be the same");
271+
272+
const int qk = QK_MXFP4;
273+
const int nb = n / qk;
274+
275+
const block_mxfp4 * GGML_RESTRICT x = vx;
276+
const block_q8_0 * GGML_RESTRICT y = vy;
277+
278+
int ib = 0;
279+
float sumf = 0.0f;
280+
281+
#if defined(__VXE__) || defined(__VXE2__)
282+
const uint8x16_t v_k = vec_xl(0, kvalues_mxfp4);
283+
const uint8x16_t v_m = vec_splats((uint8_t)0x0F);
284+
285+
for (; ib + 1 < nb; ib += 2) {
286+
const block_mxfp4 * GGML_RESTRICT x0 = &x[ib + 0];
287+
const block_mxfp4 * GGML_RESTRICT x1 = &x[ib + 1];
288+
const block_q8_0 * GGML_RESTRICT y0 = &y[ib + 0];
289+
const block_q8_0 * GGML_RESTRICT y1 = &y[ib + 1];
290+
291+
const uint8x16_t v_x0 = vec_xl(0, x0->qs);
292+
const uint8x16_t v_x1 = vec_xl(0, x1->qs);
293+
294+
int8x16_t v_x0l = (int8x16_t)vec_and(v_x0, v_m);
295+
int8x16_t v_x0h = (int8x16_t)vec_sr(v_x0, 4);
296+
int8x16_t v_x1l = (int8x16_t)vec_and(v_x1, v_m);
297+
int8x16_t v_x1h = (int8x16_t)vec_sr(v_x1, 4);
298+
299+
v_x0l = vec_perm(v_k, v_k, (uchar8x16_t)v_x0l);
300+
v_x0h = vec_perm(v_k, v_k, (uchar8x16_t)v_x0h);
301+
v_x1l = vec_perm(v_k, v_k, (uchar8x16_t)v_x1l);
302+
v_x1h = vec_perm(v_k, v_k, (uchar8x16_t)v_x1h);
303+
304+
const int8x16_t v_y0l = vec_xl(0, y0->qs);
305+
const int8x16_t v_y0h = vec_xl(QK8_0/2, y0->qs);
306+
const int8x16_t v_y1l = vec_xl(0, y1->qs);
307+
const int8x16_t v_y1h = vec_xl(QK8_0/2, y1->qs);
308+
309+
const int32x4_t v_xy0 = ggml_vec_dot(ggml_vec_dot(vec_splats(0), v_x0l, v_y0l), v_x0h, v_y0h);
310+
const int32x4_t v_xy1 = ggml_vec_dot(ggml_vec_dot(vec_splats(0), v_x1l, v_y1l), v_x1h, v_y1h);
311+
312+
sumf +=
313+
GGML_E8M0_TO_FP32(x0->e) * GGML_CPU_FP16_TO_FP32(y0->d) * vec_hsum_i32x4(v_xy0) +
314+
GGML_E8M0_TO_FP32(x1->e) * GGML_CPU_FP16_TO_FP32(y1->d) * vec_hsum_i32x4(v_xy1);
315+
}
316+
317+
for (; ib < nb; ++ib) {
318+
const block_mxfp4 * GGML_RESTRICT x0 = &x[ib + 0];
319+
const block_q8_0 * GGML_RESTRICT y0 = &y[ib + 0];
320+
321+
const uint8x16_t v_x = vec_xl(0, x0->qs);
322+
323+
int8x16_t v_xl = (int8x16_t)vec_and(v_x, v_m);
324+
int8x16_t v_xh = (int8x16_t)vec_sr(v_x, 4);
325+
326+
v_xl = vec_perm(v_k, v_k, (uchar8x16_t)v_xl);
327+
v_xh = vec_perm(v_k, v_k, (uchar8x16_t)v_xh);
328+
329+
const int8x16_t v_yl = vec_xl(0, y0->qs);
330+
const int8x16_t v_yh = vec_xl(QK8_0/2, y0->qs);
331+
332+
const int32x4_t v_xy = ggml_vec_dot(ggml_vec_dot(vec_splats(0), v_xl, v_yl), v_xh, v_yh);
333+
334+
sumf += GGML_E8M0_TO_FP32(x0->e) * GGML_CPU_FP16_TO_FP32(y0->d) * vec_hsum_i32x4(v_xy);
335+
}
336+
#else
337+
UNUSED(x);
338+
UNUSED(y);
339+
UNUSED(ib);
340+
UNUSED(sumf);
341+
ggml_vec_dot_mxfp4_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc);
342+
#endif
343+
}
344+
263345
void ggml_vec_dot_q5_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
264346
const int qk = QK8_0;
265347
const int nb = n / qk;

0 commit comments

Comments
 (0)