@@ -284,11 +284,49 @@ void ggml_vec_dot_mxfp4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const vo
284284
285285 float32x4_t v_acc = vec_splats (0.0f );
286286
287+ for (; ib + 1 < nb ; ib += 2 ) {
288+ const block_mxfp4 * GGML_RESTRICT x0 = & x [ib + 0 ];
289+ const block_mxfp4 * GGML_RESTRICT x1 = & x [ib + 1 ];
290+ const block_q8_0 * GGML_RESTRICT y0 = & y [ib + 0 ];
291+ const block_q8_0 * GGML_RESTRICT y1 = & y [ib + 1 ];
292+
293+ const uint8x16_t v_x0 = vec_xl (0 , x0 -> qs );
294+ const uint8x16_t v_x1 = vec_xl (0 , x1 -> qs );
295+
296+ int8x16_t v_x0l = (int8x16_t )vec_and (v_x0 , v_m );
297+ int8x16_t v_x0h = (int8x16_t )vec_sr (v_x0 , 4 );
298+ int8x16_t v_x1l = (int8x16_t )vec_and (v_x1 , v_m );
299+ int8x16_t v_x1h = (int8x16_t )vec_sr (v_x1 , 4 );
300+
301+ v_x0l = vec_perm (v_k , v_k , (uchar8x16_t )v_x0l );
302+ v_x0h = vec_perm (v_k , v_k , (uchar8x16_t )v_x0h );
303+ v_x1l = vec_perm (v_k , v_k , (uchar8x16_t )v_x1l );
304+ v_x1h = vec_perm (v_k , v_k , (uchar8x16_t )v_x1h );
305+
306+ const int8x16_t v_y0l = vec_xl (0 , y0 -> qs );
307+ const int8x16_t v_y0h = vec_xl (QK8_0 /2 , y0 -> qs );
308+ const int8x16_t v_y1l = vec_xl (0 , y1 -> qs );
309+ const int8x16_t v_y1h = vec_xl (QK8_0 /2 , y1 -> qs );
310+
311+ const int32x4_t v_xy0 = ggml_vec_dot (ggml_vec_dot (vec_splats (0 ), v_x0l , v_y0l ), v_x0h , v_y0h );
312+ const int32x4_t v_xy1 = ggml_vec_dot (ggml_vec_dot (vec_splats (0 ), v_x1l , v_y1l ), v_x1h , v_y1h );
313+
314+ const float32x4_t v_xy0f = vec_float (v_xy0 );
315+ const float32x4_t v_xy1f = vec_float (v_xy1 );
316+
317+ const float32x4_t v_d0 = vec_splats (GGML_E8M0_TO_FP32_HALF (x0 -> e ) * GGML_CPU_FP16_TO_FP32 (y0 -> d ));
318+ const float32x4_t v_d1 = vec_splats (GGML_E8M0_TO_FP32_HALF (x1 -> e ) * GGML_CPU_FP16_TO_FP32 (y1 -> d ));
319+
320+ v_acc = vec_madd (v_xy0f , v_d0 , v_acc );
321+ v_acc = vec_madd (v_xy1f , v_d1 , v_acc );
322+ }
323+
287324 for (; ib < nb ; ++ ib ) {
288- const block_mxfp4 * GGML_RESTRICT x0 = & x [ib ];
289- const block_q8_0 * GGML_RESTRICT y0 = & y [ib ];
325+ const block_mxfp4 * GGML_RESTRICT x0 = & x [ib + 0 ];
326+ const block_q8_0 * GGML_RESTRICT y0 = & y [ib + 0 ];
290327
291328 const uint8x16_t v_x = vec_xl (0 , x0 -> qs );
329+
292330 int8x16_t v_xl = (int8x16_t )vec_and (v_x , v_m );
293331 int8x16_t v_xh = (int8x16_t )vec_sr (v_x , 4 );
294332
@@ -300,8 +338,8 @@ void ggml_vec_dot_mxfp4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const vo
300338
301339 const int32x4_t v_xy = ggml_vec_dot (ggml_vec_dot (vec_splats (0 ), v_xl , v_yl ), v_xh , v_yh );
302340 const float32x4_t v_xyf = vec_float (v_xy );
303- const float32x4_t v_d = vec_splats (GGML_E8M0_TO_FP32_HALF (x0 -> e ) * GGML_CPU_FP16_TO_FP32 (y0 -> d ));
304341
342+ const float32x4_t v_d = vec_splats (GGML_E8M0_TO_FP32_HALF (x0 -> e ) * GGML_CPU_FP16_TO_FP32 (y0 -> d ));
305343 v_acc = vec_madd (v_xyf , v_d , v_acc );
306344 }
307345
0 commit comments