Skip to content

Commit 3538930

Browse files
committed
ggml-cpu: rework mxfp4
Signed-off-by: Aaron Teo <[email protected]>
1 parent 377d0fc commit 3538930

File tree

1 file changed

+5
-37
lines changed

1 file changed

+5
-37
lines changed

ggml/src/ggml-cpu/arch/s390/quants.c

Lines changed: 5 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -280,46 +280,13 @@ void ggml_vec_dot_mxfp4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const vo
280280

281281
#if defined(__VXE__) || defined(__VXE2__)
282282
const int8x16_t v_k = vec_xl(0, kvalues_mxfp4);
283-
const uint8x16_t v_m = vec_splats((uint8_t)0x0F);
284-
285-
for (; ib + 1 < nb; ib += 2) {
286-
const block_mxfp4 * GGML_RESTRICT x0 = &x[ib + 0];
287-
const block_mxfp4 * GGML_RESTRICT x1 = &x[ib + 1];
288-
const block_q8_0 * GGML_RESTRICT y0 = &y[ib + 0];
289-
const block_q8_0 * GGML_RESTRICT y1 = &y[ib + 1];
290-
291-
const uint8x16_t v_x0 = vec_xl(0, x0->qs);
292-
const uint8x16_t v_x1 = vec_xl(0, x1->qs);
293-
294-
int8x16_t v_x0l = (int8x16_t)vec_and(v_x0, v_m);
295-
int8x16_t v_x0h = (int8x16_t)vec_sr(v_x0, 4);
296-
int8x16_t v_x1l = (int8x16_t)vec_and(v_x1, v_m);
297-
int8x16_t v_x1h = (int8x16_t)vec_sr(v_x1, 4);
298-
299-
v_x0l = vec_perm(v_k, v_k, (uchar8x16_t)v_x0l);
300-
v_x0h = vec_perm(v_k, v_k, (uchar8x16_t)v_x0h);
301-
v_x1l = vec_perm(v_k, v_k, (uchar8x16_t)v_x1l);
302-
v_x1h = vec_perm(v_k, v_k, (uchar8x16_t)v_x1h);
303-
304-
const int8x16_t v_y0l = vec_xl(0, y0->qs);
305-
const int8x16_t v_y0h = vec_xl(QK8_0/2, y0->qs);
306-
const int8x16_t v_y1l = vec_xl(0, y1->qs);
307-
const int8x16_t v_y1h = vec_xl(QK8_0/2, y1->qs);
308-
309-
const int32x4_t v_xy0 = ggml_vec_dot(ggml_vec_dot(vec_splats(0), v_x0l, v_y0l), v_x0h, v_y0h);
310-
const int32x4_t v_xy1 = ggml_vec_dot(ggml_vec_dot(vec_splats(0), v_x1l, v_y1l), v_x1h, v_y1h);
311-
312-
sumf +=
313-
GGML_E8M0_TO_FP32(x0->e) * GGML_CPU_FP16_TO_FP32(y0->d) * vec_hsum_i32x4(v_xy0) +
314-
GGML_E8M0_TO_FP32(x1->e) * GGML_CPU_FP16_TO_FP32(y1->d) * vec_hsum_i32x4(v_xy1);
315-
}
283+
const uint8x16_t v_m = vec_splats((const uint8_t)0x0F);
316284

317285
for (; ib < nb; ++ib) {
318-
const block_mxfp4 * GGML_RESTRICT x0 = &x[ib + 0];
319-
const block_q8_0 * GGML_RESTRICT y0 = &y[ib + 0];
286+
const block_mxfp4 * GGML_RESTRICT x0 = &x[ib];
287+
const block_q8_0 * GGML_RESTRICT y0 = &y[ib];
320288

321289
const uint8x16_t v_x = vec_xl(0, x0->qs);
322-
323290
int8x16_t v_xl = (int8x16_t)vec_and(v_x, v_m);
324291
int8x16_t v_xh = (int8x16_t)vec_sr(v_x, 4);
325292

@@ -331,7 +298,8 @@ void ggml_vec_dot_mxfp4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const vo
331298

332299
const int32x4_t v_xy = ggml_vec_dot(ggml_vec_dot(vec_splats(0), v_xl, v_yl), v_xh, v_yh);
333300

334-
sumf += GGML_E8M0_TO_FP32(x0->e) * GGML_CPU_FP16_TO_FP32(y0->d) * vec_hsum_i32x4(v_xy);
301+
const float scale = GGML_E8M0_TO_FP32(x0->e) * GGML_CPU_FP16_TO_FP32(y0->d);
302+
sumf += scale * vec_hsum_i32x4(v_xy);
335303
}
336304

337305
*s = sumf;

0 commit comments

Comments
 (0)