@@ -397,119 +397,118 @@ inline static void ggml_vec_mad_f32(const int n, float * GGML_RESTRICT y, const
 }
 
 inline static void ggml_vec_mad_f16(const int n, ggml_fp16_t * GGML_RESTRICT y, const ggml_fp16_t * GGML_RESTRICT x, const float v) {
-#if defined(GGML_SIMD)
-    #if defined(__ARM_FEATURE_SVE)
-        const int sve_register_length = svcntb() * 8;
-        const int ggml_f16_epr = sve_register_length / 16;
-        const int ggml_f16_step = 8 * ggml_f16_epr;
+#if defined(GGML_SIMD) && defined(__ARM_FEATURE_SVE)
+    const int sve_register_length = svcntb() * 8;
+    const int ggml_f16_epr = sve_register_length / 16;
+    const int ggml_f16_step = 8 * ggml_f16_epr;
 
-        GGML_F16x_VEC vx = GGML_F16x_VEC_SET1(v);
+    GGML_F16x_VEC vx = GGML_F16x_VEC_SET1(v);
 
-        const int np = (n & ~(ggml_f16_step - 1));
+    int np = (n & ~(ggml_f16_step - 1));
 
-        svfloat16_t ax1, ax2, ax3, ax4, ax5, ax6, ax7, ax8;
-        svfloat16_t ay1, ay2, ay3, ay4, ay5, ay6, ay7, ay8;
-        for (int i = 0; i < np; i += ggml_f16_step) {
-            ax1 = GGML_F16x_VEC_LOAD(x + i + 0 * ggml_f16_epr, 0);
-            ay1 = GGML_F16x_VEC_LOAD(y + i + 0 * ggml_f16_epr, 0);
-            ay1 = GGML_F16x_VEC_FMA(ay1, ax1, vx);
+    svfloat16_t ax1, ax2, ax3, ax4, ax5, ax6, ax7, ax8;
+    svfloat16_t ay1, ay2, ay3, ay4, ay5, ay6, ay7, ay8;
+    for (int i = 0; i < np; i += ggml_f16_step) {
+        ax1 = GGML_F16x_VEC_LOAD(x + i + 0 * ggml_f16_epr, 0);
+        ay1 = GGML_F16x_VEC_LOAD(y + i + 0 * ggml_f16_epr, 0);
+        ay1 = GGML_F16x_VEC_FMA(ay1, ax1, vx);
 
-            GGML_F16x_VEC_STORE(y + i + 0 * ggml_f16_epr, ay1, 0);
+        GGML_F16x_VEC_STORE(y + i + 0 * ggml_f16_epr, ay1, 0);
 
-            ax2 = GGML_F16x_VEC_LOAD(x + i + 1 * ggml_f16_epr, 1);
-            ay2 = GGML_F16x_VEC_LOAD(y + i + 1 * ggml_f16_epr, 1);
-            ay2 = GGML_F16x_VEC_FMA(ay2, ax2, vx);
+        ax2 = GGML_F16x_VEC_LOAD(x + i + 1 * ggml_f16_epr, 1);
+        ay2 = GGML_F16x_VEC_LOAD(y + i + 1 * ggml_f16_epr, 1);
+        ay2 = GGML_F16x_VEC_FMA(ay2, ax2, vx);
 
-            GGML_F16x_VEC_STORE(y + i + 1 * ggml_f16_epr, ay2, 1);
+        GGML_F16x_VEC_STORE(y + i + 1 * ggml_f16_epr, ay2, 1);
 
-            ax3 = GGML_F16x_VEC_LOAD(x + i + 2 * ggml_f16_epr, 2);
-            ay3 = GGML_F16x_VEC_LOAD(y + i + 2 * ggml_f16_epr, 2);
-            ay3 = GGML_F16x_VEC_FMA(ay3, ax3, vx);
+        ax3 = GGML_F16x_VEC_LOAD(x + i + 2 * ggml_f16_epr, 2);
+        ay3 = GGML_F16x_VEC_LOAD(y + i + 2 * ggml_f16_epr, 2);
+        ay3 = GGML_F16x_VEC_FMA(ay3, ax3, vx);
 
-            GGML_F16x_VEC_STORE(y + i + 2 * ggml_f16_epr, ay3, 2);
+        GGML_F16x_VEC_STORE(y + i + 2 * ggml_f16_epr, ay3, 2);
 
-            ax4 = GGML_F16x_VEC_LOAD(x + i + 3 * ggml_f16_epr, 3);
-            ay4 = GGML_F16x_VEC_LOAD(y + i + 3 * ggml_f16_epr, 3);
-            ay4 = GGML_F16x_VEC_FMA(ay4, ax4, vx);
+        ax4 = GGML_F16x_VEC_LOAD(x + i + 3 * ggml_f16_epr, 3);
+        ay4 = GGML_F16x_VEC_LOAD(y + i + 3 * ggml_f16_epr, 3);
+        ay4 = GGML_F16x_VEC_FMA(ay4, ax4, vx);
 
-            GGML_F16x_VEC_STORE(y + i + 3 * ggml_f16_epr, ay4, 3);
+        GGML_F16x_VEC_STORE(y + i + 3 * ggml_f16_epr, ay4, 3);
 
-            ax5 = GGML_F16x_VEC_LOAD(x + i + 4 * ggml_f16_epr, 4);
-            ay5 = GGML_F16x_VEC_LOAD(y + i + 4 * ggml_f16_epr, 4);
-            ay5 = GGML_F16x_VEC_FMA(ay5, ax5, vx);
+        ax5 = GGML_F16x_VEC_LOAD(x + i + 4 * ggml_f16_epr, 4);
+        ay5 = GGML_F16x_VEC_LOAD(y + i + 4 * ggml_f16_epr, 4);
+        ay5 = GGML_F16x_VEC_FMA(ay5, ax5, vx);
 
-            GGML_F16x_VEC_STORE(y + i + 4 * ggml_f16_epr, ay5, 4);
+        GGML_F16x_VEC_STORE(y + i + 4 * ggml_f16_epr, ay5, 4);
 
-            ax6 = GGML_F16x_VEC_LOAD(x + i + 5 * ggml_f16_epr, 5);
-            ay6 = GGML_F16x_VEC_LOAD(y + i + 5 * ggml_f16_epr, 5);
-            ay6 = GGML_F16x_VEC_FMA(ay6, ax6, vx);
+        ax6 = GGML_F16x_VEC_LOAD(x + i + 5 * ggml_f16_epr, 5);
+        ay6 = GGML_F16x_VEC_LOAD(y + i + 5 * ggml_f16_epr, 5);
+        ay6 = GGML_F16x_VEC_FMA(ay6, ax6, vx);
 
-            GGML_F16x_VEC_STORE(y + i + 5 * ggml_f16_epr, ay6, 5);
+        GGML_F16x_VEC_STORE(y + i + 5 * ggml_f16_epr, ay6, 5);
 
-            ax7 = GGML_F16x_VEC_LOAD(x + i + 6 * ggml_f16_epr, 6);
-            ay7 = GGML_F16x_VEC_LOAD(y + i + 6 * ggml_f16_epr, 6);
-            ay7 = GGML_F16x_VEC_FMA(ay7, ax7, vx);
+        ax7 = GGML_F16x_VEC_LOAD(x + i + 6 * ggml_f16_epr, 6);
+        ay7 = GGML_F16x_VEC_LOAD(y + i + 6 * ggml_f16_epr, 6);
+        ay7 = GGML_F16x_VEC_FMA(ay7, ax7, vx);
 
-            GGML_F16x_VEC_STORE(y + i + 6 * ggml_f16_epr, ay7, 6);
+        GGML_F16x_VEC_STORE(y + i + 6 * ggml_f16_epr, ay7, 6);
 
-            ax8 = GGML_F16x_VEC_LOAD(x + i + 7 * ggml_f16_epr, 7);
-            ay8 = GGML_F16x_VEC_LOAD(y + i + 7 * ggml_f16_epr, 7);
-            ay8 = GGML_F16x_VEC_FMA(ay8, ax8, vx);
+        ax8 = GGML_F16x_VEC_LOAD(x + i + 7 * ggml_f16_epr, 7);
+        ay8 = GGML_F16x_VEC_LOAD(y + i + 7 * ggml_f16_epr, 7);
+        ay8 = GGML_F16x_VEC_FMA(ay8, ax8, vx);
 
-            GGML_F16x_VEC_STORE(y + i + 7 * ggml_f16_epr, ay8, 7);
-        }
-        const int np2 = (n & ~(ggml_f16_epr - 1));
-        for (int k = np; k < np2; k += ggml_f16_epr) {
-            svfloat16_t rx = GGML_F16x_VEC_LOAD(x + k, 0);
-            svfloat16_t ry = GGML_F16x_VEC_LOAD(y + k, 0);
-            ry = GGML_F16x_VEC_FMA(ry, rx, vx);
-
-            GGML_F16x_VEC_STORE(y + k, ry, 0);
-        }
-
-        if (np2 < n) {
-            svbool_t pg = svwhilelt_b16(np2, n);
-            svfloat16_t hx = svld1_f16(pg, (const __fp16 *)(x + np2));
-            svfloat16_t hy = svld1_f16(pg, (const __fp16 *)(y + np2));
-            hy = svmad_f16_x(pg, hx, vx, hy);
-            svst1_f16(pg, (__fp16 *)(y + np2), hy);
-        }
+        GGML_F16x_VEC_STORE(y + i + 7 * ggml_f16_epr, ay8, 7);
+    }
+    const int np2 = (n & ~(ggml_f16_epr - 1));
+    for (int k = np; k < np2; k += ggml_f16_epr) {
+        svfloat16_t rx = GGML_F16x_VEC_LOAD(x + k, 0);
+        svfloat16_t ry = GGML_F16x_VEC_LOAD(y + k, 0);
+        ry = GGML_F16x_VEC_FMA(ry, rx, vx);
 
-    #elif defined(__riscv_v_intrinsic)
-        // todo: RVV impl
-        // scalar
-        for (int i = 0; i < n; ++i) {
-            y[i] = GGML_CPU_FP32_TO_FP16(GGML_CPU_FP16_TO_FP32(y[i]) + GGML_CPU_FP16_TO_FP32(x[i])*v);
-        }
-    #else
-        const int np = (n & ~(GGML_F16_STEP - 1));
+        GGML_F16x_VEC_STORE(y + k, ry, 0);
+    }
 
-        GGML_F16_VEC vx = GGML_F16_VEC_SET1(v);
+    if (np2 < n) {
+        svbool_t pg = svwhilelt_b16(np2, n);
+        svfloat16_t hx = svld1_f16(pg, (const __fp16 *)(x + np2));
+        svfloat16_t hy = svld1_f16(pg, (const __fp16 *)(y + np2));
+        hy = svmad_f16_x(pg, hx, vx, hy);
+        svst1_f16(pg, (__fp16 *)(y + np2), hy);
+    }
+    np = n;
+#elif defined(__riscv_zvfh) // implies __riscv_v_intrinsic
+    const int np = n;
+    _Float16 hv = (_Float16)v;
+    for (int i = 0, avl; i < n; i += avl) {
+        avl = __riscv_vsetvl_e16m8(n - i);
+        vfloat16m8_t ax = __riscv_vle16_v_f16m8((const _Float16 *)&x[i], avl);
+        vfloat16m8_t ay = __riscv_vle16_v_f16m8((_Float16 *)&y[i], avl);
+        vfloat16m8_t ny = __riscv_vfmadd_vf_f16m8(ax, hv, ay, avl);
+        __riscv_vse16_v_f16m8((_Float16 *)&y[i], ny, avl);
+    }
+#elif defined(GGML_SIMD)
+    const int np = (n & ~(GGML_F16_STEP - 1));
 
-        GGML_F16_VEC ax[GGML_F16_ARR];
-        GGML_F16_VEC ay[GGML_F16_ARR];
+    GGML_F16_VEC vx = GGML_F16_VEC_SET1(v);
 
-        for (int i = 0; i < np; i += GGML_F16_STEP) {
-            for (int j = 0; j < GGML_F16_ARR; j++) {
-                ax[j] = GGML_F16_VEC_LOAD(x + i + j*GGML_F16_EPR, j);
-                ay[j] = GGML_F16_VEC_LOAD(y + i + j*GGML_F16_EPR, j);
-                ay[j] = GGML_F16_VEC_FMA(ay[j], ax[j], vx);
+    GGML_F16_VEC ax[GGML_F16_ARR];
+    GGML_F16_VEC ay[GGML_F16_ARR];
 
-                GGML_F16_VEC_STORE(y + i + j*GGML_F16_EPR, ay, j);
-            }
-        }
+    for (int i = 0; i < np; i += GGML_F16_STEP) {
+        for (int j = 0; j < GGML_F16_ARR; j++) {
+            ax[j] = GGML_F16_VEC_LOAD(x + i + j*GGML_F16_EPR, j);
+            ay[j] = GGML_F16_VEC_LOAD(y + i + j*GGML_F16_EPR, j);
+            ay[j] = GGML_F16_VEC_FMA(ay[j], ax[j], vx);
 
-        // leftovers
-        for (int i = np; i < n; ++i) {
-            y[i] = GGML_CPU_FP32_TO_FP16(GGML_CPU_FP16_TO_FP32(y[i]) + GGML_CPU_FP16_TO_FP32(x[i])*v);
+            GGML_F16_VEC_STORE(y + i + j*GGML_F16_EPR, ay, j);
         }
-    #endif
+    }
 #else
-    // scalar
-    for (int i = 0; i < n; ++i) {
+    const int np = 0;
+#endif
+
+    // leftovers
+    for (int i = np; i < n; ++i) {
         y[i] = GGML_CPU_FP32_TO_FP16(GGML_CPU_FP16_TO_FP32(y[i]) + GGML_CPU_FP16_TO_FP32(x[i])*v);
     }
-#endif
 }
 
 // xs and vs are byte strides of x and v
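Note on the new RISC-V branch: instead of a fixed unroll plus a scalar tail, it uses the RVV strip-mining idiom. `__riscv_vsetvl_e16m8(n - i)` returns how many 16-bit elements the hardware will process this pass (with LMUL=8 register grouping), clamped to the remaining count, so the final partial pass is handled by the same loop body and the branch can declare `np = n` to skip the shared leftover loop. A minimal standalone sketch of the same pattern, assuming a toolchain with the Zvfh extension enabled; the function name `axpy_f16` is illustrative and not part of the patch:

#include <riscv_vector.h>

// y[i] += x[i] * v over n fp16 elements, strip-mined: vsetvl clamps the
// active vector length to the remaining element count, so the last
// (partial) pass reuses the same loop body instead of a scalar tail.
static void axpy_f16(int n, _Float16 * y, const _Float16 * x, _Float16 v) {
    for (int i = 0, avl; i < n; i += avl) {
        avl = __riscv_vsetvl_e16m8(n - i);                   // elements this pass
        vfloat16m8_t ax = __riscv_vle16_v_f16m8(&x[i], avl); // load x chunk
        vfloat16m8_t ay = __riscv_vle16_v_f16m8(&y[i], avl); // load y chunk
        ay = __riscv_vfmadd_vf_f16m8(ax, v, ay, avl);        // ay = ax * v + ay
        __riscv_vse16_v_f16m8(&y[i], ay, avl);               // store back
    }
}

Grouping eight vector registers (m8) trades register pressure for fewer loop iterations, which suits a streaming kernel like this one.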
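The refactor also collapses the three former scalar fallbacks into the single `// leftovers` loop: each branch leaves `np` equal to the number of elements it already processed (the SVE and RVV paths handle everything, so `np = n`; the generic GGML_SIMD path leaves the sub-step remainder; the no-SIMD build leaves `np = 0`). The SVE path finishes without scalar work by using a predicated tail; a reduced sketch of that `svwhilelt` idiom, assuming an SVE-enabled toolchain (`axpy_tail_f16` is illustrative, not part of the patch):

#include <arm_sve.h>

// Finish y[i] += x[i] * v for the remainder [k, n): svwhilelt_b16 builds
// a predicate that is true only for the first n-k lanes, so the loads,
// the fused multiply-add, and the store all ignore lanes past the end.
static void axpy_tail_f16(int k, int n, __fp16 * y, const __fp16 * x, __fp16 v) {
    svbool_t    pg = svwhilelt_b16(k, n);   // lane mask for the tail
    svfloat16_t vx = svdup_f16(v);          // broadcast the scalar
    svfloat16_t hx = svld1_f16(pg, x + k);  // predicated load of x
    svfloat16_t hy = svld1_f16(pg, y + k);  // predicated load of y
    hy = svmad_f16_x(pg, hx, vx, hy);       // hy = hx * vx + hy
    svst1_f16(pg, y + k, hy);               // predicated store
}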