@@ -278,6 +278,72 @@ void ggml_vec_dot_q4_1_q8_1(int n, float * GGML_RESTRICT s, size_t bs, const voi
 #endif
 }
 
+void ggml_vec_dot_mxfp4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
+    assert(nrc == 1);
+    UNUSED(nrc);
+    UNUSED(bx);
+    UNUSED(by);
+    UNUSED(bs);
+    assert(n % QK_MXFP4 == 0);
+    static_assert(QK_MXFP4 == QK8_0, "QK_MXFP4 and QK8_0 must be the same");
+
+    const block_mxfp4 * GGML_RESTRICT x = vx;
+    const block_q8_0 * GGML_RESTRICT y = vy;
+
+    const int nb = n / QK_MXFP4;
+
+    int ib = 0;
+    float sumf = 0;
+
+#if defined(__POWER9_VECTOR__)
+    const vector signed char lowMask = vec_splats((signed char)0xF);
+    const vector unsigned char vshift4 = vec_splats((unsigned char)4);
+    vector float vsumf0 = vec_splats(0.0f);
+
+    vector signed char kv = vec_xl(0, (const signed char *)kvalues_mxfp4);
+
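+    // One 32-element block per iteration: look up the FP4 codes in the codebook
+    // and take the int8 dot product against the matching Q8_0 block.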
+#pragma GCC unroll 8
+    for (; ib < nb; ++ib) {
+        __builtin_prefetch(x[ib].qs, 0, 1);
+        __builtin_prefetch(y[ib].qs, 0, 1);
+
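+        // Combined per-block scale: Q8_0's fp16 delta times the MXFP4 block's shared E8M0 exponent.
+        // The HALF variant folds in a factor of 1/2, matching the doubled entries of kvalues_mxfp4.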
+        vector float vyd = vec_splats(GGML_CPU_FP16_TO_FP32(y[ib].d) *
+                                      GGML_E8M0_TO_FP32_HALF(x[ib].e));
+
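+        // Load the 32 signed int8 values of the Q8_0 block as two 16-byte vectors.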
+        vector signed char q8y0 = vec_xl( 0, y[ib].qs);
+        vector signed char q8y1 = vec_xl(16, y[ib].qs);
+
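+        // Each byte of the MXFP4 block packs two 4-bit codes; split them into low and high nibbles.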
+        vector signed char qxs = (vector signed char)vec_xl(0, x[ib].qs);
+
+        vector unsigned char lo_nibbles = (vector unsigned char)vec_and(qxs, lowMask);
+        vector unsigned char hi_nibbles = (vector unsigned char)vec_sr(qxs, vshift4);
+
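+        // vec_perm doubles as a 16-entry table lookup, mapping each nibble to its signed codebook value.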
+        vector signed char q4x0 = vec_perm(kv, kv, lo_nibbles);
+        vector signed char q4x1 = vec_perm(kv, kv, hi_nibbles);
+
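+        // vec_mule/vec_mulo widen even/odd int8 lanes to int16 products; adding them
+        // yields pairwise sums of adjacent products, which cannot overflow int16 here.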
+        vector signed short qv0 = vec_add(vec_mule(q4x0, q8y0), vec_mulo(q4x0, q8y0));
+        vector signed short qv1 = vec_add(vec_mule(q4x1, q8y1), vec_mulo(q4x1, q8y1));
+
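+        // vec_sum4s adds each pair of adjacent int16 products into the corresponding
+        // 32-bit accumulator lane; vec_ctf then converts for the fused multiply-add.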
+        vector signed int vsumi0 = vec_splats((int32_t)0);
+        vsumi0 = vec_sum4s(qv0, vsumi0);
+        vsumi0 = vec_sum4s(qv1, vsumi0);
+
+        vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vyd, vsumf0);
+    }
+
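+    // Horizontal reduction: rotate-and-add via vec_sld folds the four float lanes of vsumf0 into lane 0.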
+    vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 4));
+    vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 8));
+    sumf = vec_extract(vsumf0, 0);
+    *s = sumf;
+#else
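+    // Builds without POWER9 vector support fall back to the generic scalar implementation.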
+    UNUSED(x);
+    UNUSED(y);
+    UNUSED(ib);
+    UNUSED(sumf);
+    ggml_vec_dot_mxfp4_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc);
+#endif
+}
+
 void ggml_vec_dot_q5_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
     const int qk = QK8_0;
     const int nb = n / qk;