@@ -355,27 +355,33 @@ inline static void ggml_vec_mad1_f32(const int n, float * y, const float s, cons
355355#if  defined(GGML_USE_ACCELERATE )
356356    vDSP_vsmsa (y , 1 , & s , & b , y , 1 , n );
357357#elif  defined(GGML_SIMD )
358-     // TODO: #if defined(__ARM_FEATURE_SVE) 
359-     const  int  np  =  (n  &  ~(GGML_F32_STEP  -  1 ));
358+     #if  defined(__ARM_FEATURE_SVE )
359+         // scalar ; TODO: Write SVE code 
360+         for  (int  i  =  0 ; i  <  n ; ++ i ) {
361+             y [i ] =  y [i ]* s  +  b ;
362+         }
363+     #else 
364+         const  int  np  =  (n  &  ~(GGML_F32_STEP  -  1 ));
360365
361-     GGML_F32_VEC  vs  =  GGML_F32_VEC_SET1 (s );
362-     GGML_F32_VEC  vb  =  GGML_F32_VEC_SET1 (b );
366+          GGML_F32_VEC  vs  =  GGML_F32_VEC_SET1 (s );
367+          GGML_F32_VEC  vb  =  GGML_F32_VEC_SET1 (b );
363368
364-     GGML_F32_VEC  ay [GGML_F32_ARR ];
369+          GGML_F32_VEC  ay [GGML_F32_ARR ];
365370
366-     for  (int  i  =  0 ; i  <  np ; i  +=  GGML_F32_STEP ) {
367-         for  (int  j  =  0 ; j  <  GGML_F32_ARR ; j ++ ) {
368-             ay [j ] =  GGML_F32_VEC_LOAD (y  +  i  +  j * GGML_F32_EPR );
369-             ay [j ] =  GGML_F32_VEC_FMA (ay [j ], vs , vb );
371+          for  (int  i  =  0 ; i  <  np ; i  +=  GGML_F32_STEP ) {
372+              for  (int  j  =  0 ; j  <  GGML_F32_ARR ; j ++ ) {
373+                  ay [j ] =  GGML_F32_VEC_LOAD (y  +  i  +  j * GGML_F32_EPR );
374+                  ay [j ] =  GGML_F32_VEC_FMA (ay [j ], vs , vb );
370375
371-             GGML_F32_VEC_STORE (y  +  i  +  j * GGML_F32_EPR , ay [j ]);
376+                 GGML_F32_VEC_STORE (y  +  i  +  j * GGML_F32_EPR , ay [j ]);
377+             }
372378        }
373-     }
374379
375-     // leftovers 
376-     for  (int  i  =  np ; i  <  n ; ++ i ) {
377-         y [i ] =  y [i ]* s  +  b ;
378-     }
380+         // leftovers 
381+         for  (int  i  =  np ; i  <  n ; ++ i ) {
382+             y [i ] =  y [i ]* s  +  b ;
383+         }
384+     #endif 
379385#else 
380386    // scalar 
381387    for  (int  i  =  0 ; i  <  n ; ++ i ) {
0 commit comments