11#include  "ggml-dsp.h" 
22
3- inline  static  void  ggmlhexagon_dsp_add_f32  (const  int  n , float  *  z , const  float  *  x , const  float  *  y ) {
3+ static  inline  void  l2fetch (const  void  *  p , uint32_t  stride ,
4+                            uint32_t  width , uint32_t  height ,
5+                            uint32_t  dir ) {
6+     uint64_t  control  =  HEXAGON_V64_CREATE_H (dir , stride , width , height );
7+     __asm__ __volatile__ (" l2fetch(%0,%1) "  : :"r" (p ),"r" (control ));
8+ }
9+ 
10+ static  inline  void  ggmlhexagon_dsp_add_f32 (const  int  n , float  *  GGML_RESTRICT  z , const  float  *  GGML_RESTRICT  x , const  float  *  GGML_RESTRICT  y ) {
411    HVX_Vector  *  va ;
512    HVX_Vector  *  vb ;
613    HVX_Vector  *  vc ;
714    HVX_Vector  qf32 ;
8-     const  int  FLOATS_PER_VECTOR  =  128  / sizeof (float );
9-     const  int  block   =  n  / FLOATS_PER_VECTOR ;
10-     const  int  left    =  n  % FLOATS_PER_VECTOR ;
11-     const  int  blocks  =  block  *  FLOATS_PER_VECTOR ;
15+     const  size_t  FLOATS_PER_VECTOR  =  128  / sizeof (float );
16+     const  size_t  block   =  n  / FLOATS_PER_VECTOR ;
17+     const  size_t  left    =  n  % FLOATS_PER_VECTOR ;
18+     const  size_t  blocks  =  block  *  FLOATS_PER_VECTOR ;
1219
1320    if  ((((uintptr_t )z  | (uintptr_t )x  | (uintptr_t )y ) % ALIGN_128_BYTE ) !=  0 ) {
1421        GGMLHEXAGON_LOG_DEBUG ("memaddress mismatch alignment 128 bytes z:%p x:%p y:%p" , z , x , y );
@@ -21,11 +28,13 @@ inline static void ggmlhexagon_dsp_add_f32 (const int n, float * z, const float
2128    va  =  (HVX_Vector  * )x ;
2229    vb  =  (HVX_Vector  * )y ;
2330    vc  =  (HVX_Vector  * )z ;
31+     // unrolling this loop would be faster, but it needs more careful checking across the various cases, and I think the DSP also doesn't like branch predication
2432    for  (size_t  i  =  0 ; i  <  block ; ++ i ) {
33+         l2fetch (va  +  VLEN , VLEN , VLEN , 1 , 0 );
34+         l2fetch (vb  +  VLEN , VLEN , VLEN , 1 , 0 );
2535        //*vc++ = Q6_Vsf_vadd_VsfVsf(*va++, *vb++); 
2636        qf32  =  Q6_Vqf32_vadd_VsfVsf (* va ++ , * vb ++ );
27-         * vc  =  Q6_Vsf_equals_Vqf32 (qf32 );
28-         vc ++ ;
37+         * vc ++  =  Q6_Vsf_equals_Vqf32 (qf32 );
2938    }
3039
3140    if  (left  >  0 ) {
@@ -49,6 +58,17 @@ static void ggml_compute_forward_add_f32(
4958
5059    GGML_ASSERT (ggml_can_repeat (src1 , src0 ) &&  ggml_are_same_shape (src0 , dst ));
5160
61+     const  int  rank  =  ggml_n_dims (src0 );
62+     if  (1  ==  rank ) {
63+         //element-wise addition with vector 
64+         const  size_t  len  =  src0 -> ne [0 ];
65+         float  *  dst_ptr   =  (float  * ) (dst -> data );
66+         float  *  src0_ptr  =  (float  * ) (src0 -> data );
67+         float  *  src1_ptr  =  (float  * ) (src1 -> data );
68+         ggmlhexagon_dsp_add_f32 (len , dst_ptr , src0_ptr , src1_ptr );
69+         return ;
70+     }
71+ 
5272    const  int  ith  =  0 ;
5373    const  int  nth  =  1 ;
5474
@@ -115,24 +135,9 @@ static void ggml_compute_forward_add_f32(
115135}
116136
117137//FIXME: why failed with test-backend-ops when disable ion rpc mempool 
118- int  ggmlop_dsp_add (remote_handle64  h , const  ggml_tensor  *  src0 , const  ggml_tensor  *  src1 , ggml_tensor  *  dst )
119- {
138+ int  ggmlop_dsp_add (remote_handle64  h , const  ggml_tensor  *  src0 , const  ggml_tensor  *  src1 , ggml_tensor  *  dst ) {
120139    GGMLHEXAGON_LOG_DEBUG ("enter %s\n" , __func__ );
121-     switch  (src0 -> type ) {
122-         case  GGML_TYPE_F32 :
123-         {
124-             if  (src1 -> type  ==  GGML_TYPE_F32 ) {
125-                 ggml_compute_forward_add_f32 (src0 , src1 , dst );
126-             } else  {
127-                 GGML_ABORT ("fatal error" );
128-             }
129-             break ;
130-         }
131-         default :
132-         {
133-             GGML_ABORT ("fatal error" );
134-         }
135-     }
140+     ggml_compute_forward_add_f32 (src0 , src1 , dst );
136141    GGMLHEXAGON_LOG_DEBUG ("leave %s\n" , __func__ );
137142    return  0 ;
138143}
0 commit comments