@@ -280,46 +280,13 @@ void ggml_vec_dot_mxfp4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const vo
280280
281281#if defined(__VXE__ ) || defined(__VXE2__ )
282282 const int8x16_t v_k = vec_xl (0 , kvalues_mxfp4 );
283- const uint8x16_t v_m = vec_splats ((uint8_t )0x0F );
284-
285- for (; ib + 1 < nb ; ib += 2 ) {
286- const block_mxfp4 * GGML_RESTRICT x0 = & x [ib + 0 ];
287- const block_mxfp4 * GGML_RESTRICT x1 = & x [ib + 1 ];
288- const block_q8_0 * GGML_RESTRICT y0 = & y [ib + 0 ];
289- const block_q8_0 * GGML_RESTRICT y1 = & y [ib + 1 ];
290-
291- const uint8x16_t v_x0 = vec_xl (0 , x0 -> qs );
292- const uint8x16_t v_x1 = vec_xl (0 , x1 -> qs );
293-
294- int8x16_t v_x0l = (int8x16_t )vec_and (v_x0 , v_m );
295- int8x16_t v_x0h = (int8x16_t )vec_sr (v_x0 , 4 );
296- int8x16_t v_x1l = (int8x16_t )vec_and (v_x1 , v_m );
297- int8x16_t v_x1h = (int8x16_t )vec_sr (v_x1 , 4 );
298-
299- v_x0l = vec_perm (v_k , v_k , (uchar8x16_t )v_x0l );
300- v_x0h = vec_perm (v_k , v_k , (uchar8x16_t )v_x0h );
301- v_x1l = vec_perm (v_k , v_k , (uchar8x16_t )v_x1l );
302- v_x1h = vec_perm (v_k , v_k , (uchar8x16_t )v_x1h );
303-
304- const int8x16_t v_y0l = vec_xl (0 , y0 -> qs );
305- const int8x16_t v_y0h = vec_xl (QK8_0 /2 , y0 -> qs );
306- const int8x16_t v_y1l = vec_xl (0 , y1 -> qs );
307- const int8x16_t v_y1h = vec_xl (QK8_0 /2 , y1 -> qs );
308-
309- const int32x4_t v_xy0 = ggml_vec_dot (ggml_vec_dot (vec_splats (0 ), v_x0l , v_y0l ), v_x0h , v_y0h );
310- const int32x4_t v_xy1 = ggml_vec_dot (ggml_vec_dot (vec_splats (0 ), v_x1l , v_y1l ), v_x1h , v_y1h );
311-
312- sumf +=
313- GGML_E8M0_TO_FP32 (x0 -> e ) * GGML_CPU_FP16_TO_FP32 (y0 -> d ) * vec_hsum_i32x4 (v_xy0 ) +
314- GGML_E8M0_TO_FP32 (x1 -> e ) * GGML_CPU_FP16_TO_FP32 (y1 -> d ) * vec_hsum_i32x4 (v_xy1 );
315- }
283+ const uint8x16_t v_m = vec_splats ((const uint8_t )0x0F );
316284
317285 for (; ib < nb ; ++ ib ) {
318- const block_mxfp4 * GGML_RESTRICT x0 = & x [ib + 0 ];
319- const block_q8_0 * GGML_RESTRICT y0 = & y [ib + 0 ];
286+ const block_mxfp4 * GGML_RESTRICT x0 = & x [ib ];
287+ const block_q8_0 * GGML_RESTRICT y0 = & y [ib ];
320288
321289 const uint8x16_t v_x = vec_xl (0 , x0 -> qs );
322-
323290 int8x16_t v_xl = (int8x16_t )vec_and (v_x , v_m );
324291 int8x16_t v_xh = (int8x16_t )vec_sr (v_x , 4 );
325292
@@ -331,7 +298,8 @@ void ggml_vec_dot_mxfp4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const vo
331298
332299 const int32x4_t v_xy = ggml_vec_dot (ggml_vec_dot (vec_splats (0 ), v_xl , v_yl ), v_xh , v_yh );
333300
334- sumf += GGML_E8M0_TO_FP32 (x0 -> e ) * GGML_CPU_FP16_TO_FP32 (y0 -> d ) * vec_hsum_i32x4 (v_xy );
301+ const float scale = GGML_E8M0_TO_FP32 (x0 -> e ) * GGML_CPU_FP16_TO_FP32 (y0 -> d );
302+ sumf += scale * vec_hsum_i32x4 (v_xy );
335303 }
336304
337305 * s = sumf ;
0 commit comments