Skip to content

Commit 5fb1bb9

Browse files
committed
ggml-cpu: expand to 2 blocks per loop
Signed-off-by: Aaron Teo <[email protected]>
1 parent f7e7539 commit 5fb1bb9

File tree

1 file changed

+41
-3
lines changed

1 file changed

+41
-3
lines changed

ggml/src/ggml-cpu/arch/s390/quants.c

Lines changed: 41 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -284,11 +284,49 @@ void ggml_vec_dot_mxfp4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const vo
284284

285285
float32x4_t v_acc = vec_splats(0.0f);
286286

287+
for (; ib + 1 < nb; ib += 2) {
288+
const block_mxfp4 * GGML_RESTRICT x0 = &x[ib + 0];
289+
const block_mxfp4 * GGML_RESTRICT x1 = &x[ib + 1];
290+
const block_q8_0 * GGML_RESTRICT y0 = &y[ib + 0];
291+
const block_q8_0 * GGML_RESTRICT y1 = &y[ib + 1];
292+
293+
const uint8x16_t v_x0 = vec_xl(0, x0->qs);
294+
const uint8x16_t v_x1 = vec_xl(0, x1->qs);
295+
296+
int8x16_t v_x0l = (int8x16_t)vec_and(v_x0, v_m);
297+
int8x16_t v_x0h = (int8x16_t)vec_sr(v_x0, 4);
298+
int8x16_t v_x1l = (int8x16_t)vec_and(v_x1, v_m);
299+
int8x16_t v_x1h = (int8x16_t)vec_sr(v_x1, 4);
300+
301+
v_x0l = vec_perm(v_k, v_k, (uchar8x16_t)v_x0l);
302+
v_x0h = vec_perm(v_k, v_k, (uchar8x16_t)v_x0h);
303+
v_x1l = vec_perm(v_k, v_k, (uchar8x16_t)v_x1l);
304+
v_x1h = vec_perm(v_k, v_k, (uchar8x16_t)v_x1h);
305+
306+
const int8x16_t v_y0l = vec_xl(0, y0->qs);
307+
const int8x16_t v_y0h = vec_xl(QK8_0/2, y0->qs);
308+
const int8x16_t v_y1l = vec_xl(0, y1->qs);
309+
const int8x16_t v_y1h = vec_xl(QK8_0/2, y1->qs);
310+
311+
const int32x4_t v_xy0 = ggml_vec_dot(ggml_vec_dot(vec_splats(0), v_x0l, v_y0l), v_x0h, v_y0h);
312+
const int32x4_t v_xy1 = ggml_vec_dot(ggml_vec_dot(vec_splats(0), v_x1l, v_y1l), v_x1h, v_y1h);
313+
314+
const float32x4_t v_xy0f = vec_float(v_xy0);
315+
const float32x4_t v_xy1f = vec_float(v_xy1);
316+
317+
const float32x4_t v_d0 = vec_splats(GGML_E8M0_TO_FP32_HALF(x0->e) * GGML_CPU_FP16_TO_FP32(y0->d));
318+
const float32x4_t v_d1 = vec_splats(GGML_E8M0_TO_FP32_HALF(x1->e) * GGML_CPU_FP16_TO_FP32(y1->d));
319+
320+
v_acc = vec_madd(v_xy0f, v_d0, v_acc);
321+
v_acc = vec_madd(v_xy1f, v_d1, v_acc);
322+
}
323+
287324
for (; ib < nb; ++ib) {
288-
const block_mxfp4 * GGML_RESTRICT x0 = &x[ib];
289-
const block_q8_0 * GGML_RESTRICT y0 = &y[ib];
325+
const block_mxfp4 * GGML_RESTRICT x0 = &x[ib + 0];
326+
const block_q8_0 * GGML_RESTRICT y0 = &y[ib + 0];
290327

291328
const uint8x16_t v_x = vec_xl(0, x0->qs);
329+
292330
int8x16_t v_xl = (int8x16_t)vec_and(v_x, v_m);
293331
int8x16_t v_xh = (int8x16_t)vec_sr(v_x, 4);
294332

@@ -300,8 +338,8 @@ void ggml_vec_dot_mxfp4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const vo
300338

301339
const int32x4_t v_xy = ggml_vec_dot(ggml_vec_dot(vec_splats(0), v_xl, v_yl), v_xh, v_yh);
302340
const float32x4_t v_xyf = vec_float(v_xy);
303-
const float32x4_t v_d = vec_splats(GGML_E8M0_TO_FP32_HALF(x0->e) * GGML_CPU_FP16_TO_FP32(y0->d));
304341

342+
const float32x4_t v_d = vec_splats(GGML_E8M0_TO_FP32_HALF(x0->e) * GGML_CPU_FP16_TO_FP32(y0->d));
305343
v_acc = vec_madd(v_xyf, v_d, v_acc);
306344
}
307345

0 commit comments

Comments
 (0)