@@ -260,6 +260,88 @@ void ggml_vec_dot_q4_1_q8_1(int n, float * GGML_RESTRICT s, size_t bs, const voi
260260#endif
261261}
262262
263+ void ggml_vec_dot_mxfp4_q8_0 (int n , float * GGML_RESTRICT s , size_t bs , const void * GGML_RESTRICT vx , size_t bx , const void * GGML_RESTRICT vy , size_t by , int nrc ) {
264+ assert (nrc == 1 );
265+ UNUSED (nrc );
266+ UNUSED (bx );
267+ UNUSED (by );
268+ UNUSED (bs );
269+ assert (n % QK_MXFP4 == 0 );
270+ static_assert (QK_MXFP4 == QK8_0 , "QK_MXFP4 and QK8_0 must be the same" );
271+
272+ const int qk = QK_MXFP4 ;
273+ const int nb = n / qk ;
274+
275+ const block_mxfp4 * GGML_RESTRICT x = vx ;
276+ const block_q8_0 * GGML_RESTRICT y = vy ;
277+
278+ int ib = 0 ;
279+ float sumf = 0.0f ;
280+
281+ #if defined(__VXE__ ) || defined(__VXE2__ )
282+ const uint8x16_t v_k = vec_xl (0 , kvalues_mxfp4 );
283+ const uint8x16_t v_m = vec_splats ((uint8_t )0x0F );
284+
285+ for (; ib + 1 < nb ; ib += 2 ) {
286+ const block_mxfp4 * GGML_RESTRICT x0 = & x [ib + 0 ];
287+ const block_mxfp4 * GGML_RESTRICT x1 = & x [ib + 1 ];
288+ const block_q8_0 * GGML_RESTRICT y0 = & y [ib + 0 ];
289+ const block_q8_0 * GGML_RESTRICT y1 = & y [ib + 1 ];
290+
291+ const uint8x16_t v_x0 = vec_xl (0 , x0 -> qs );
292+ const uint8x16_t v_x1 = vec_xl (0 , x1 -> qs );
293+
294+ int8x16_t v_x0l = (int8x16_t )vec_and (v_x0 , v_m );
295+ int8x16_t v_x0h = (int8x16_t )vec_sr (v_x0 , 4 );
296+ int8x16_t v_x1l = (int8x16_t )vec_and (v_x1 , v_m );
297+ int8x16_t v_x1h = (int8x16_t )vec_sr (v_x1 , 4 );
298+
299+ v_x0l = vec_perm (v_k , v_k , (uchar8x16_t )v_x0l );
300+ v_x0h = vec_perm (v_k , v_k , (uchar8x16_t )v_x0h );
301+ v_x1l = vec_perm (v_k , v_k , (uchar8x16_t )v_x1l );
302+ v_x1h = vec_perm (v_k , v_k , (uchar8x16_t )v_x1h );
303+
304+ const int8x16_t v_y0l = vec_xl (0 , y0 -> qs );
305+ const int8x16_t v_y0h = vec_xl (QK8_0 /2 , y0 -> qs );
306+ const int8x16_t v_y1l = vec_xl (0 , y1 -> qs );
307+ const int8x16_t v_y1h = vec_xl (QK8_0 /2 , y1 -> qs );
308+
309+ const int32x4_t v_xy0 = ggml_vec_dot (ggml_vec_dot (vec_splats (0 ), v_x0l , v_y0l ), v_x0h , v_y0h );
310+ const int32x4_t v_xy1 = ggml_vec_dot (ggml_vec_dot (vec_splats (0 ), v_x1l , v_y1l ), v_x1h , v_y1h );
311+
312+ sumf +=
313+ GGML_E8M0_TO_FP32 (x0 -> e ) * GGML_CPU_FP16_TO_FP32 (y0 -> d ) * vec_hsum_i32x4 (v_xy0 ) +
314+ GGML_E8M0_TO_FP32 (x1 -> e ) * GGML_CPU_FP16_TO_FP32 (y1 -> d ) * vec_hsum_i32x4 (v_xy1 );
315+ }
316+
317+ for (; ib < nb ; ++ ib ) {
318+ const block_mxfp4 * GGML_RESTRICT x0 = & x [ib + 0 ];
319+ const block_q8_0 * GGML_RESTRICT y0 = & y [ib + 0 ];
320+
321+ const uint8x16_t v_x = vec_xl (0 , x0 -> qs );
322+
323+ int8x16_t v_xl = (int8x16_t )vec_and (v_x , v_m );
324+ int8x16_t v_xh = (int8x16_t )vec_sr (v_x , 4 );
325+
326+ v_xl = vec_perm (v_k , v_k , (uchar8x16_t )v_xl );
327+ v_xh = vec_perm (v_k , v_k , (uchar8x16_t )v_xh );
328+
329+ const int8x16_t v_yl = vec_xl (0 , y0 -> qs );
330+ const int8x16_t v_yh = vec_xl (QK8_0 /2 , y0 -> qs );
331+
332+ const int32x4_t v_xy = ggml_vec_dot (ggml_vec_dot (vec_splats (0 ), v_xl , v_yl ), v_xh , v_yh );
333+
334+ sumf += GGML_E8M0_TO_FP32 (x0 -> e ) * GGML_CPU_FP16_TO_FP32 (y0 -> d ) * vec_hsum_i32x4 (v_xy );
335+ }
336+ #else
337+ UNUSED (x );
338+ UNUSED (y );
339+ UNUSED (ib );
340+ UNUSED (sumf );
341+ ggml_vec_dot_mxfp4_q8_0_generic (n , s , bs , vx , bx , vy , by , nrc );
342+ #endif
343+ }
344+
263345void ggml_vec_dot_q5_0_q8_0 (int n , float * GGML_RESTRICT s , size_t bs , const void * GGML_RESTRICT vx , size_t bx , const void * GGML_RESTRICT vy , size_t by , int nrc ) {
264346 const int qk = QK8_0 ;
265347 const int nb = n / qk ;
0 commit comments