@@ -11642,79 +11642,191 @@ static void ggml_compute_forward_add_rel_pos(
1164211642    }
1164311643}
1164411644
11645- // ggml_compute_forward_rwkv_wkv  
11645+ // ggml_compute_forward_rwkv_wkv6  
1164611646
11647- static  void  ggml_compute_forward_rwkv_wkv_f32 (
11647+ static  void  ggml_compute_forward_rwkv_wkv6_f32 (
1164811648        const  struct  ggml_compute_params  *  params ,
1164911649        struct  ggml_tensor  *  dst ) {
11650-     const  size_t  T  =  dst -> src [1 ]-> ne [3 ];
11651-     const  size_t  C  =  dst -> ne [0 ];
11652-     const  size_t  H  =  dst -> src [1 ]-> ne [2 ];
11653-     const  size_t  n_seqs  =  dst -> src [5 ]-> ne [1 ];
11650+     const  int64_t  T  =  dst -> src [1 ]-> ne [3 ];
11651+     const  int64_t  C  =  dst -> ne [0 ];
11652+     const  int64_t  HEADS  =  dst -> src [1 ]-> ne [2 ];
11653+     const  int64_t  n_seqs  =  dst -> src [5 ]-> ne [1 ];
11654+     const  int64_t  head_size  =  C  / HEADS ;
1165411655
1165511656    float  *  dst_data  =  (float  * ) dst -> data ;
1165611657    float  *  state  =  ((float  * ) dst -> data ) +  C  *  T ;
1165711658
11658-     if  (params -> ith  !=  0 ) {
11659+     const  int  ith  =  params -> ith ;
11660+     const  int  nth  =  params -> nth ;
11661+ 
11662+     if  (ith  >= HEADS ) {
1165911663        return ;
1166011664    }
1166111665
11662-     memset (dst_data , 0 , T  *  C  *  sizeof (float ));
11666+     const  int  h_start  =  (HEADS  *  ith ) / nth ;
11667+     const  int  h_end  =  ((HEADS  *  (ith  +  1 )) / nth  <  HEADS ) ?
11668+                 (HEADS  *  (ith  +  1 )) / nth  : HEADS ;
1166311669
1166411670    float  *  k  =           (float  * ) dst -> src [0 ]-> data ;
1166511671    float  *  v  =           (float  * ) dst -> src [1 ]-> data ;
1166611672    float  *  r  =           (float  * ) dst -> src [2 ]-> data ;
1166711673    float  *  time_faaaa  =  (float  * ) dst -> src [3 ]-> data ;
1166811674    float  *  time_decay  =  (float  * ) dst -> src [4 ]-> data ;
1166911675
11670-     size_t  t_stride  =  H  *  ( C  /  H ); 
11676+     size_t  t_stride  =  HEADS  *  head_size ;  // Same to C 
1167111677
11672-     size_t  h_stride  =  C  / H ;
11673-     size_t  h_stride_2d  =  (C  / H ) *  (C  / H );
11678+     size_t  h_stride  =  C  / HEADS ;
11679+     GGML_ASSERT (C  % HEADS  ==  0 ); // C must be divisible by HEADS 
11680+     size_t  h_stride_2d  =  head_size  *  head_size ;
1167411681
11675-     // basically fused operations: 
11676-     // dst = r @ (time_faaaa * (k @ v) + state), 
11677-     // state = time_decay * state + (k @ v), 
11678-     // recursive through each token 
11679-     for  (size_t  t  =  0 ; t  <  T ; t ++ ) {
11680-         size_t  t_offset  =  t  *  t_stride ;
11681-         size_t  state_offset  =  (C  / H ) *  C  *  (t  / (T  / n_seqs ));
11682-         float  *  state_cur  =  state  +  state_offset ;
11683-         float  *  state_prev  =  t  % (T  / n_seqs ) ? state_cur  : (float * )dst -> src [5 ]-> data  +  state_offset ;
11682+     if  (ith  ==  0 ) {
11683+         memset (dst_data , 0 , T  *  C  *  sizeof (float ));
11684+     }
11685+     ggml_barrier (params -> threadpool );
1168411686
11685-         for  (size_t  h  =  0 ; h  <  H ; h ++ ) {
11686-             size_t  h_offset  =  h  *  h_stride ;
11687-             size_t  t_h_offset  =  t_offset  +  h_offset ;
11688-             size_t  h_2d_offset  =  h  *  h_stride_2d ;
1168911687
11690-             for  (size_t  i  =  0 ; i  <  C  / H ; i ++ ) {
11691-                 size_t  t_h_i_offset  =  t_h_offset  +  i ;
11692-                 size_t  h_i_offset  =  h_offset  +  i ;
11693-                 size_t  h_2d_i_offset  =  h_2d_offset  +  i  *  h_stride ;
11688+     #if  defined(__AVX__ ) &&  !defined(__AVX512F__ )
11689+         #define  GGML_F32X  GGML_F32x8
11690+         #define  GGML_F32X_SET1  GGML_F32x8_SET1
11691+         #define  GGML_F32X_LOAD  GGML_F32x8_LOAD
11692+         #define  GGML_F32X_STORE  GGML_F32x8_STORE
11693+         #define  GGML_F32X_MUL  GGML_F32x8_MUL
11694+         #define  GGML_F32X_FMA  GGML_F32x8_FMA
11695+         #define  WKV_VECTOR_SIZE  8
11696+     #elif  defined(__AVX512F__ )
11697+         #define  GGML_F32X  GGML_F32x16
11698+         #define  GGML_F32X_SET1  GGML_F32x16_SET1
11699+         #define  GGML_F32X_LOAD  GGML_F32x16_LOAD
11700+         #define  GGML_F32X_STORE  GGML_F32x16_STORE
11701+         #define  GGML_F32X_MUL  GGML_F32x16_MUL
11702+         #define  GGML_F32X_FMA  GGML_F32x16_FMA
11703+         #define  WKV_VECTOR_SIZE  16
11704+     #elif  defined(__ARM_NEON ) &&  defined(__aarch64__ )
11705+         #define  GGML_F32X  GGML_F32x4
11706+         #define  GGML_F32X_SET1  GGML_F32x4_SET1
11707+         #define  GGML_F32X_LOAD  GGML_F32x4_LOAD
11708+         #define  GGML_F32X_STORE  GGML_F32x4_STORE
11709+         #define  GGML_F32X_MUL  GGML_F32x4_MUL
11710+         #define  GGML_F32X_FMA  GGML_F32x4_FMA
11711+         #define  WKV_VECTOR_SIZE  4
11712+     #endif 
1169411713
11695-                 float  k_val  =  k [t_h_i_offset ];
11696-                 float  r_val  =  r [t_h_i_offset ];
11697-                 float  time_faaaa_val  =  time_faaaa [h_i_offset ];
11698-                 // RWKV v6: different time_decay for each token. 
11699-                 float  time_decay_val  =  time_decay [t_h_i_offset ];
11714+     #ifdef  WKV_VECTOR_SIZE 
11715+         const  int64_t  vec_count  =  head_size  / WKV_VECTOR_SIZE ;
11716+ 
11717+         for  (int64_t  t  =  0 ; t  <  T ; t ++ ) {
11718+             size_t  t_offset  =  t  *  t_stride ;
11719+             size_t  state_offset  =  head_size  *  C  *  (t  / (T  / n_seqs ));
11720+             float  *  state_cur  =  state  +  state_offset ;
11721+             float  *  state_prev  =  t  % (T  / n_seqs ) ? state_cur  : (float * )dst -> src [5 ]-> data  +  state_offset ;
11722+ 
11723+             for  (int64_t  h  =  h_start ; h  <  h_end ; h ++ ) {
11724+                 size_t  h_offset  =  h  *  h_stride ;
11725+                 size_t  t_h_offset  =  t_offset  +  h_offset ;
11726+                 size_t  h_2d_offset  =  h  *  h_stride_2d ;
11727+ 
11728+                 for  (int64_t  i  =  0 ; i  <  head_size ; i ++ ) {
11729+                     size_t  t_h_i_offset  =  t_h_offset  +  i ;
11730+                     size_t  h_i_offset  =  h_offset  +  i ;
11731+                     size_t  h_2d_i_offset  =  h_2d_offset  +  i  *  h_stride ;
11732+ 
11733+                     float  k_val  =  k [t_h_i_offset ];
11734+                     float  r_val  =  r [t_h_i_offset ];
11735+                     float  time_faaaa_val  =  time_faaaa [h_i_offset ];
11736+                     float  time_decay_val  =  time_decay [t_h_i_offset ];
11737+ 
11738+                     // Broadcast scalar values to vectors 
11739+                     GGML_F32X  k_vec  =  GGML_F32X_SET1 (k_val );
11740+                     GGML_F32X  r_vec  =  GGML_F32X_SET1 (r_val );
11741+                     GGML_F32X  time_faaaa_vec  =  GGML_F32X_SET1 (time_faaaa_val );
11742+                     GGML_F32X  time_decay_vec  =  GGML_F32X_SET1 (time_decay_val );
11743+ 
11744+                     for  (int64_t  j  =  0 ; j  <  vec_count ; j ++ ) {
11745+                         size_t  base_j  =  j  *  WKV_VECTOR_SIZE ;
11746+                         size_t  t_h_j_offset  =  t_h_offset  +  base_j ;
11747+                         size_t  h_2d_i_j_offset  =  h_2d_i_offset  +  base_j ;
11748+ 
11749+                         // Load WKV_VECTOR_SIZE elements at once 
11750+                         GGML_F32X  v_vec  =  GGML_F32X_LOAD (& v [t_h_j_offset ]);
11751+                         GGML_F32X  prev_state_vec  =  GGML_F32X_LOAD (& state_prev [h_2d_i_j_offset ]);
11752+                         GGML_F32X  dst_vec  =  GGML_F32X_LOAD (& dst_data [t_h_j_offset ]);
11753+ 
11754+                         // Compute kv = v * k 
11755+                         GGML_F32X  kv_vec  =  GGML_F32X_MUL (v_vec , k_vec );
11756+ 
11757+                         // Compute temp = kv * time_faaaa + prev_state 
11758+                         GGML_F32X  temp_vec  =  GGML_F32X_FMA (prev_state_vec , kv_vec , time_faaaa_vec );
11759+ 
11760+                         // Update dst: dst += temp * r 
11761+                         dst_vec  =  GGML_F32X_FMA (dst_vec , temp_vec , r_vec );
11762+                         GGML_F32X_STORE (& dst_data [t_h_j_offset ], dst_vec );
11763+ 
11764+                         // Update state: state = prev_state * time_decay + kv 
11765+                         GGML_F32X  new_state_vec  =  GGML_F32X_FMA (kv_vec , prev_state_vec , time_decay_vec );
11766+                         GGML_F32X_STORE (& state_cur [h_2d_i_j_offset ], new_state_vec );
11767+                     }
1170011768
11701-                 for  (size_t  j  =  0 ; j  <  C  / H ; j  ++ ) {
11702-                     size_t  t_h_j_offset  =  t_h_offset  +  j ;
11703-                     size_t  h_2d_i_j_offset  =  h_2d_i_offset  +  j ;
11769+                     // Scalar tail: handle the head_size % WKV_VECTOR_SIZE leftover elements (expected to be none in practice). 
11770+                     for  (int64_t  j  =  vec_count  *  WKV_VECTOR_SIZE ; j  <  head_size ; j ++ ) {
11771+                         size_t  t_h_j_offset  =  t_h_offset  +  j ;
11772+                         size_t  h_2d_i_j_offset  =  h_2d_i_offset  +  j ;
11773+                         float  v_val  =  v [t_h_j_offset ];
11774+                         float  kv_val  =  v_val  *  k_val ;
11775+                         float  prev_state_val  =  state_prev [h_2d_i_j_offset ];
11776+                         float  temp_val  =  kv_val  *  time_faaaa_val  +  prev_state_val ;
11777+                         dst_data [t_h_j_offset ] +=  temp_val  *  r_val ;
11778+                         state_cur [h_2d_i_j_offset ] =  prev_state_val  *  time_decay_val  +  kv_val ;
11779+                     }
11780+                 }
11781+             }
11782+         }
1170411783
11705-                     float  v_val  =  v [t_h_j_offset ];
11706-                     float  kv_val  =  v_val  *  k_val ;
11707-                     float  prev_state_val  =  state_prev [h_2d_i_j_offset ];
11708-                     float  temp_val  =  kv_val  *  time_faaaa_val  +  prev_state_val ;
11709-                     dst_data [t_h_j_offset ] +=  temp_val  *  r_val ;
11710-                     state_cur [h_2d_i_j_offset ] =  prev_state_val  *  time_decay_val  +  kv_val ;
11784+     #else 
11785+         // Fused operations, applied as a recurrence over the tokens: 
11786+         //   dst   = r @ (time_faaaa * (k @ v) + state) 
11787+         //   state = time_decay * state + (k @ v) 
11788+         // Each token reads the state written by the previous token of the same sequence. 
11789+         for  (int64_t  t  =  0 ; t  <  T ; t ++ ) {
11790+             size_t  t_offset  =  t  *  t_stride ;
11791+             size_t  state_offset  =  head_size  *  C  *  (t  / (T  / n_seqs ));
11792+             float  *  state_cur  =  state  +  state_offset ;
11793+             float  *  state_prev  =  t  % (T  / n_seqs ) ? state_cur  : (float * )dst -> src [5 ]-> data  +  state_offset ;
11794+ 
11795+             for  (int64_t  h  =  h_start ; h  <  h_end ; h ++ ) {
11796+                 size_t  h_offset  =  h  *  h_stride ;
11797+                 size_t  t_h_offset  =  t_offset  +  h_offset ;
11798+                 size_t  h_2d_offset  =  h  *  h_stride_2d ;
11799+ 
11800+                 for  (int64_t  i  =  0 ; i  <  head_size ; i ++ ) {
11801+                     size_t  t_h_i_offset  =  t_h_offset  +  i ;
11802+                     size_t  h_i_offset  =  h_offset  +  i ;
11803+                     size_t  h_2d_i_offset  =  h_2d_offset  +  i  *  h_stride ;
11804+ 
11805+                     float  k_val  =  k [t_h_i_offset ];
11806+                     float  r_val  =  r [t_h_i_offset ];
11807+                     float  time_faaaa_val  =  time_faaaa [h_i_offset ];
11808+                     // RWKV v6: different time_decay for each token. 
11809+                     float  time_decay_val  =  time_decay [t_h_i_offset ];
11810+ 
11811+                     for  (int64_t  j  =  0 ; j  <  head_size ; j ++ ) {
11812+                         size_t  t_h_j_offset  =  t_h_offset  +  j ;
11813+                         size_t  h_2d_i_j_offset  =  h_2d_i_offset  +  j ;
11814+ 
11815+                         float  v_val  =  v [t_h_j_offset ];
11816+                         float  kv_val  =  v_val  *  k_val ;
11817+                         float  prev_state_val  =  state_prev [h_2d_i_j_offset ];
11818+                         float  temp_val  =  kv_val  *  time_faaaa_val  +  prev_state_val ;
11819+                         dst_data [t_h_j_offset ] +=  temp_val  *  r_val ;
11820+                         state_cur [h_2d_i_j_offset ] =  prev_state_val  *  time_decay_val  +  kv_val ;
11821+                     }
1171111822                }
1171211823            }
1171311824        }
11714-     } 
11825+     #endif 
1171511826}
1171611827
11717- static  void  ggml_compute_forward_rwkv_wkv (
11828+ 
11829+ static  void  ggml_compute_forward_rwkv_wkv6 (
1171811830        const  struct  ggml_compute_params  *  params ,
1171911831        struct  ggml_tensor  *  dst ) {
1172011832
@@ -11723,7 +11835,7 @@ static void ggml_compute_forward_rwkv_wkv(
1172311835    switch  (src0 -> type ) {
1172411836        case  GGML_TYPE_F32 :
1172511837            {
11726-                 ggml_compute_forward_rwkv_wkv_f32 (params , dst );
11838+                 ggml_compute_forward_rwkv_wkv6_f32 (params , dst );
1172711839            } break ;
1172811840        default :
1172911841            {
@@ -12475,9 +12587,9 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
1247512587            {
1247612588                ggml_compute_forward_add_rel_pos (params , tensor );
1247712589            } break ;
12478-         case  GGML_OP_RWKV_WKV :
12590+         case  GGML_OP_RWKV_WKV6 :
1247912591            {
12480-                 ggml_compute_forward_rwkv_wkv (params , tensor );
12592+                 ggml_compute_forward_rwkv_wkv6 (params , tensor );
1248112593            } break ;
1248212594        case  GGML_OP_MAP_UNARY :
1248312595            {
@@ -12775,7 +12887,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
1277512887        case  GGML_OP_WIN_PART :
1277612888        case  GGML_OP_WIN_UNPART :
1277712889        case  GGML_OP_GET_REL_POS :
12778-         case  GGML_OP_RWKV_WKV :
12890+         case  GGML_OP_RWKV_WKV6 :
1277912891        case  GGML_OP_MAP_UNARY :
1278012892        case  GGML_OP_MAP_BINARY :
1278112893        case  GGML_OP_MAP_CUSTOM1_F32 :
0 commit comments