@@ -8409,83 +8409,126 @@ static void ggml_compute_forward_rwkv_wkv7_f32(
84098409    int64_t  h_stride_2d = head_size * head_size;
84108410
84118411    #if  defined(GGML_SIMD)
8412-         for  (int64_t  t = 0 ; t < T; t++) {
8413-             int64_t  t_offset = t * t_stride;
8414-             int64_t  state_offset = head_size * C * (t / (T / n_seqs));
8415-             float  * state_cur = state + state_offset;
8416-             float  * state_prev = t % (T / n_seqs) ? state_cur : (float *)dst->src [6 ]->data  + state_offset;
8417- 
8418-             for  (int64_t  h = h_start; h < h_end; h++) {
8419-                 int64_t  h_offset = h * h_stride;
8420-                 int64_t  t_h_offset = t_offset + h_offset;
8421-                 int64_t  h_2d_offset = h * h_stride_2d;
8422- 
8423-                 for  (int64_t  ii = 0 ; ii < head_size; ii++) {
8424-                     int64_t  t_h_i_offset = t_h_offset + ii;
8425-                     int64_t  h_2d_i_offset = h_2d_offset + ii * h_stride;
8426- 
8427-                     GGML_F32_VEC v_vec = GGML_F32_VEC_SET1 (v[t_h_i_offset]);
8412+         #if  defined(__ARM_FEATURE_SVE)
8413+             //  Route to the scalar implementation for now    // TODO: write a dedicated SVE path
8414+             for  (int64_t  t = 0 ; t < T; t++) {
8415+                 int64_t  t_offset = t * t_stride;
8416+                 int64_t  state_offset = head_size * C * (t / (T / n_seqs));
8417+                 float  * state_cur = state + state_offset;
8418+                 float  * state_prev = t % (T / n_seqs) ? state_cur : (float *)dst->src [6 ]->data  + state_offset;
8419+ 
8420+                 for  (int64_t  h = h_start; h < h_end; h++) {
8421+                     int64_t  h_offset = h * h_stride;
8422+                     int64_t  t_h_offset = t_offset + h_offset;
8423+                     int64_t  h_2d_offset = h * h_stride_2d;
8424+ 
8425+                     for  (int64_t  i = 0 ; i < head_size; i++) {
8426+                         int64_t  t_h_i_offset = t_h_offset + i;
8427+                         int64_t  h_2d_i_offset = h_2d_offset + i * h_stride;
8428+ 
8429+                         float  v_val = v[t_h_i_offset];
8430+ 
8431+                         float  sa = 0 , result = 0 ;
8432+                         for  (int64_t  j = 0 ; j < head_size; j++) {
8433+                             sa += a[t_h_offset + j] * state_prev[h_2d_i_offset + j];
8434+                         }
84288435
8429-                     float  sa = 0 ;
8430-                     {
8431-                         GGML_F32_VEC sum[GGML_F32_ARR] = { GGML_F32_VEC_ZERO };
8432-                         GGML_F32_VEC ax[GGML_F32_ARR];
8433-                         GGML_F32_VEC ay[GGML_F32_ARR];
8434-                         for  (int64_t  j = 0 ; j < head_size; j += GGML_F32_STEP) {
8435-                             for  (int64_t  kk = 0 ; kk < GGML_F32_ARR; kk++) {
8436-                                 ax[kk] = GGML_F32_VEC_LOAD (&a[t_h_offset + j + kk * GGML_F32_EPR]);
8437-                                 ay[kk] = GGML_F32_VEC_LOAD (&state_prev[h_2d_i_offset + j + kk * GGML_F32_EPR]);
8438-                                 sum[kk] = GGML_F32_VEC_FMA (sum[kk], ax[kk], ay[kk]);
8439-                             }
8436+                         for  (int64_t  j = 0 ; j < head_size; j++) {
8437+                             int64_t  t_h_j_offset = t_h_offset + j;
8438+                             int64_t  h_2d_i_j_offset = h_2d_i_offset + j;
8439+ 
8440+                             float  r_val = r[t_h_j_offset];
8441+                             float  w_val = w[t_h_j_offset];
8442+                             float  k_val = k[t_h_j_offset];
8443+                             float  b_val = b[t_h_j_offset];
8444+                             float  kv_val = v_val * k_val;
8445+                             float  prev_state_val = state_prev[h_2d_i_j_offset];
8446+                             state_cur[h_2d_i_j_offset] = prev_state_val * w_val + kv_val + sa * b_val;
8447+                             result += state_cur[h_2d_i_j_offset] * r_val;
84408448                        }
8441-                         GGML_F32_VEC_REDUCE (sa, sum) ;
8449+                         dst_data[t_h_i_offset] = result ;
84428450                    }
8451+                 }
8452+             }
8453+         #else 
8454+             for  (int64_t  t = 0 ; t < T; t++) {
8455+                 int64_t  t_offset = t * t_stride;
8456+                 int64_t  state_offset = head_size * C * (t / (T / n_seqs));
8457+                 float  * state_cur = state + state_offset;
8458+                 float  * state_prev = t % (T / n_seqs) ? state_cur : (float *)dst->src [6 ]->data  + state_offset;
8459+ 
8460+                 for  (int64_t  h = h_start; h < h_end; h++) {
8461+                     int64_t  h_offset = h * h_stride;
8462+                     int64_t  t_h_offset = t_offset + h_offset;
8463+                     int64_t  h_2d_offset = h * h_stride_2d;
8464+ 
8465+                     for  (int64_t  ii = 0 ; ii < head_size; ii++) {
8466+                         int64_t  t_h_i_offset = t_h_offset + ii;
8467+                         int64_t  h_2d_i_offset = h_2d_offset + ii * h_stride;
8468+ 
8469+                         GGML_F32_VEC v_vec = GGML_F32_VEC_SET1 (v[t_h_i_offset]);
8470+ 
8471+                         float  sa = 0 ;
8472+                         {
8473+                             GGML_F32_VEC sum[GGML_F32_ARR] = { GGML_F32_VEC_ZERO };
8474+                             GGML_F32_VEC ax[GGML_F32_ARR];
8475+                             GGML_F32_VEC ay[GGML_F32_ARR];
8476+                             for  (int64_t  j = 0 ; j < head_size; j += GGML_F32_STEP) {
8477+                                 for  (int64_t  kk = 0 ; kk < GGML_F32_ARR; kk++) {
8478+                                     ax[kk] = GGML_F32_VEC_LOAD (&a[t_h_offset + j + kk * GGML_F32_EPR]);
8479+                                     ay[kk] = GGML_F32_VEC_LOAD (&state_prev[h_2d_i_offset + j + kk * GGML_F32_EPR]);
8480+                                     sum[kk] = GGML_F32_VEC_FMA (sum[kk], ax[kk], ay[kk]);
8481+                                 }
8482+                             }
8483+                             GGML_F32_VEC_REDUCE (sa, sum);
8484+                         }
84438485
8444-                     GGML_F32_VEC sa_vec = GGML_F32_VEC_SET1 (sa);
8486+                          GGML_F32_VEC sa_vec = GGML_F32_VEC_SET1 (sa);
84458487
8446-                     int64_t  j = 0 ;
8447-                     GGML_F32_VEC result_vec[GGML_F32_ARR] = { GGML_F32_VEC_ZERO };
8448-                     for  (; j < head_size; j += GGML_F32_STEP) {
8449-                         for  (int64_t  kk = 0 ; kk < GGML_F32_ARR; kk++) {
8450-                             int64_t  t_h_j_offset = t_h_offset + j + kk * GGML_F32_EPR;
8451-                             int64_t  h_2d_i_j_offset = h_2d_i_offset + j + kk * GGML_F32_EPR;
8488+                          int64_t  j = 0 ;
8489+                          GGML_F32_VEC result_vec[GGML_F32_ARR] = { GGML_F32_VEC_ZERO };
8490+                          for  (; j < head_size; j += GGML_F32_STEP) {
8491+                              for  (int64_t  kk = 0 ; kk < GGML_F32_ARR; kk++) {
8492+                                  int64_t  t_h_j_offset = t_h_offset + j + kk * GGML_F32_EPR;
8493+                                  int64_t  h_2d_i_j_offset = h_2d_i_offset + j + kk * GGML_F32_EPR;
84528494
8453-                             GGML_F32_VEC r_vec = GGML_F32_VEC_LOAD (&r[t_h_j_offset]);
8454-                             GGML_F32_VEC w_vec = GGML_F32_VEC_LOAD (&w[t_h_j_offset]);
8455-                             GGML_F32_VEC k_vec = GGML_F32_VEC_LOAD (&k[t_h_j_offset]);
8456-                             GGML_F32_VEC b_vec = GGML_F32_VEC_LOAD (&b[t_h_j_offset]);
8495+                                  GGML_F32_VEC r_vec = GGML_F32_VEC_LOAD (&r[t_h_j_offset]);
8496+                                  GGML_F32_VEC w_vec = GGML_F32_VEC_LOAD (&w[t_h_j_offset]);
8497+                                  GGML_F32_VEC k_vec = GGML_F32_VEC_LOAD (&k[t_h_j_offset]);
8498+                                  GGML_F32_VEC b_vec = GGML_F32_VEC_LOAD (&b[t_h_j_offset]);
84578499
8458-                             k_vec = GGML_F32_VEC_MUL (v_vec, k_vec);
8500+                                  k_vec = GGML_F32_VEC_MUL (v_vec, k_vec);
84598501
8460-                             GGML_F32_VEC state_vec = GGML_F32_VEC_LOAD (&state_prev[h_2d_i_j_offset]);
8461-                             //  kv + s * decay + sa * b
8462-                             state_vec = GGML_F32_VEC_FMA (k_vec, state_vec, w_vec);
8463-                             state_vec = GGML_F32_VEC_FMA (state_vec, sa_vec, b_vec);
8464-                             GGML_F32_VEC_STORE (&state_cur[h_2d_i_j_offset], state_vec);
8502+                                  GGML_F32_VEC state_vec = GGML_F32_VEC_LOAD (&state_prev[h_2d_i_j_offset]);
8503+                                  //  kv + s * decay + sa * b
8504+                                  state_vec = GGML_F32_VEC_FMA (k_vec, state_vec, w_vec);
8505+                                  state_vec = GGML_F32_VEC_FMA (state_vec, sa_vec, b_vec);
8506+                                  GGML_F32_VEC_STORE (&state_cur[h_2d_i_j_offset], state_vec);
84658507
8466-                             result_vec[kk] = GGML_F32_VEC_FMA (result_vec[kk], state_vec, r_vec);
8508+                                 result_vec[kk] = GGML_F32_VEC_FMA (result_vec[kk], state_vec, r_vec);
8509+                             }
8510+                         }
8511+                         GGML_F32_VEC_REDUCE (dst_data[t_h_i_offset], result_vec);
8512+ 
8513+                         //  Handle any tail elements (head_size is expected to be a multiple of GGML_F32_STEP, so this loop normally does not run).
8514+                         for  (; j < head_size; j++) {
8515+                             int64_t  t_h_j_offset = t_h_offset + j;
8516+                             int64_t  h_2d_i_j_offset = h_2d_i_offset + j;
8517+ 
8518+                             float  r_val = r[t_h_j_offset];
8519+                             float  w_val = w[t_h_j_offset];
8520+                             float  k_val = k[t_h_j_offset];
8521+                             float  b_val = b[t_h_j_offset];
8522+                             float  kv_val = v[t_h_i_offset] * k_val;
8523+ 
8524+                             float  prev_state_val = state_prev[h_2d_i_j_offset];
8525+                             state_cur[h_2d_i_j_offset] = prev_state_val * w_val + kv_val + sa * b_val;
8526+                             dst_data[t_h_i_offset] += state_cur[h_2d_i_j_offset] * r_val;
84678527                        }
8468-                     }
8469-                     GGML_F32_VEC_REDUCE (dst_data[t_h_i_offset], result_vec);
8470- 
8471-                     //  There shouldn't be left-overs though.
8472-                     for  (; j < head_size; j++) {
8473-                         int64_t  t_h_j_offset = t_h_offset + j;
8474-                         int64_t  h_2d_i_j_offset = h_2d_i_offset + j;
8475- 
8476-                         float  r_val = r[t_h_j_offset];
8477-                         float  w_val = w[t_h_j_offset];
8478-                         float  k_val = k[t_h_j_offset];
8479-                         float  b_val = b[t_h_j_offset];
8480-                         float  kv_val = v[t_h_i_offset] * k_val;
8481- 
8482-                         float  prev_state_val = state_prev[h_2d_i_j_offset];
8483-                         state_cur[h_2d_i_j_offset] = prev_state_val * w_val + kv_val + sa * b_val;
8484-                         dst_data[t_h_i_offset] += state_cur[h_2d_i_j_offset] * r_val;
84858528                    }
84868529                }
84878530            }
8488-         } 
8531+         # endif 
84898532    #else 
84908533        for  (int64_t  t = 0 ; t < T; t++) {
84918534            int64_t  t_offset = t * t_stride;