@@ -7963,6 +7963,14 @@ static void ggml_compute_forward_rwkv_wkv6_f32(
             #define GGML_F32X_MUL GGML_F32x16_MUL
             #define GGML_F32X_FMA GGML_F32x16_FMA
             #define WKV_VECTOR_SIZE 16
+        #elif defined(__ARM_FEATURE_SVE) && defined(__aarch64__)
+            #define GGML_F32X GGML_F32xt
+            #define GGML_F32X_SET1 GGML_F32xt_SET1
+            #define GGML_F32X_LOAD GGML_F32xt_LOAD
+            #define GGML_F32X_STORE GGML_F32xt_STORE
+            #define GGML_F32X_MUL GGML_F32xt_MUL
+            #define GGML_F32X_FMA GGML_F32xt_FMA
+            #define WKV_VECTOR_SIZE 8
         #elif defined(__ARM_NEON) && defined(__aarch64__)
             #define GGML_F32X GGML_F32x4
             #define GGML_F32X_SET1 GGML_F32x4_SET1
@@ -7973,8 +7981,14 @@ static void ggml_compute_forward_rwkv_wkv6_f32(
             #define WKV_VECTOR_SIZE 4
         #endif

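+        // On SVE the vector length is implementation-defined, so the number of 32-bit
+        // lanes is queried at run time with svcntw() below; under SVE the compile-time
+        // WKV_VECTOR_SIZE above only enables this vectorized path.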
+        int wkv_vector_size;
         #ifdef WKV_VECTOR_SIZE
-            const int64_t vec_count = head_size / WKV_VECTOR_SIZE;
+            #if defined(__ARM_FEATURE_SVE)
+                wkv_vector_size = svcntw();
+            #else
+                wkv_vector_size = WKV_VECTOR_SIZE;
+            #endif
+            const int64_t vec_count = head_size / wkv_vector_size;

             for (int64_t t = 0; t < T; t++) {
                 size_t t_offset = t * t_stride;
@@ -8004,7 +8018,7 @@ static void ggml_compute_forward_rwkv_wkv6_f32(
                         GGML_F32X time_decay_vec = GGML_F32X_SET1(time_decay_val);

                         for (int64_t j = 0; j < vec_count; j++) {
-                            size_t base_j = j * WKV_VECTOR_SIZE;
+                            size_t base_j = j * wkv_vector_size;
                             size_t t_h_j_offset = t_h_offset + base_j;
                             size_t h_2d_i_j_offset = h_2d_i_offset + base_j;

@@ -8029,7 +8043,7 @@ static void ggml_compute_forward_rwkv_wkv6_f32(
                         }

                         // Handle remaining elements, this will not be used.
-                        for (int64_t j = vec_count * WKV_VECTOR_SIZE; j < head_size; j++) {
+                        for (int64_t j = vec_count * wkv_vector_size; j < head_size; j++) {
                             size_t t_h_j_offset = t_h_offset + j;
                             size_t h_2d_i_j_offset = h_2d_i_offset + j;
                             float v_val = v[t_h_j_offset];
@@ -8165,6 +8179,14 @@ static void ggml_compute_forward_gla_f32(
             #define GGML_F32X_MUL GGML_F32x16_MUL
             #define GGML_F32X_FMA GGML_F32x16_FMA
             #define GLA_VECTOR_SIZE 16
+        #elif defined(__ARM_FEATURE_SVE) && defined(__aarch64__)
+            #define GGML_F32X GGML_F32xt
+            #define GGML_F32X_SET1 GGML_F32xt_SET1
+            #define GGML_F32X_LOAD GGML_F32xt_LOAD
+            #define GGML_F32X_STORE GGML_F32xt_STORE
+            #define GGML_F32X_MUL GGML_F32xt_MUL
+            #define GGML_F32X_FMA GGML_F32xt_FMA
+            #define GLA_VECTOR_SIZE 8
         #elif defined(__ARM_NEON) && defined(__aarch64__)
             #define GGML_F32X GGML_F32x4
             #define GGML_F32X_SET1 GGML_F32x4_SET1
@@ -8175,8 +8197,14 @@ static void ggml_compute_forward_gla_f32(
             #define GLA_VECTOR_SIZE 4
         #endif

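+        // As in wkv6: with SVE the number of 32-bit lanes is taken from svcntw() at run time.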
+        int gla_vector_size;
         #ifdef GLA_VECTOR_SIZE
-            const int64_t vec_count = head_size / GLA_VECTOR_SIZE;
+            #if defined(__ARM_FEATURE_SVE)
+                gla_vector_size = svcntw();
+            #else
+                gla_vector_size = GLA_VECTOR_SIZE;
+            #endif
+            const int64_t vec_count = head_size / gla_vector_size;

             for (int64_t t = 0; t < T; t++) {
                 size_t t_offset = t * t_stride;
@@ -8203,7 +8231,7 @@ static void ggml_compute_forward_gla_f32(
                         GGML_F32X g_vec = GGML_F32X_SET1(g_val);

                         for (int64_t j = 0; j < vec_count; j++) {
-                            size_t base_j = j * GLA_VECTOR_SIZE;
+                            size_t base_j = j * gla_vector_size;
                             size_t t_h_j_offset = t_h_offset + base_j;
                             size_t h_2d_i_j_offset = h_2d_i_offset + base_j;

@@ -8227,7 +8255,7 @@ static void ggml_compute_forward_gla_f32(
                         }

                         // Handle remaining elements, this will not be used.
-                        for (int64_t j = vec_count * GLA_VECTOR_SIZE; j < head_size; j++) {
+                        for (int64_t j = vec_count * gla_vector_size; j < head_size; j++) {
                             size_t t_h_j_offset = t_h_offset + j;
                             size_t h_2d_i_j_offset = h_2d_i_offset + j;
                             float v_val = v[t_h_j_offset];
@@ -8336,83 +8364,126 @@ static void ggml_compute_forward_rwkv_wkv7_f32(
     int64_t h_stride_2d = head_size * head_size;

     #if defined(GGML_SIMD)
-        for (int64_t t = 0; t < T; t++) {
-            int64_t t_offset = t * t_stride;
-            int64_t state_offset = head_size * C * (t / (T / n_seqs));
-            float * state_cur = state + state_offset;
-            float * state_prev = t % (T / n_seqs) ? state_cur : (float *)dst->src[6]->data + state_offset;
-
-            for (int64_t h = h_start; h < h_end; h++) {
-                int64_t h_offset = h * h_stride;
-                int64_t t_h_offset = t_offset + h_offset;
-                int64_t h_2d_offset = h * h_stride_2d;
-
-                for (int64_t ii = 0; ii < head_size; ii++) {
-                    int64_t t_h_i_offset = t_h_offset + ii;
-                    int64_t h_2d_i_offset = h_2d_offset + ii * h_stride;
-
-                    GGML_F32_VEC v_vec = GGML_F32_VEC_SET1(v[t_h_i_offset]);
+        #if defined(__ARM_FEATURE_SVE)
+            // Route to the scalar implementation for now  //TODO: Write SVE code
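+            // Note: the #else branch below relies on the fixed-width GGML_F32_STEP /
+            // GGML_F32_EPR macros; a dedicated SVE kernel would have to handle the
+            // run-time vector length instead.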
+            for (int64_t t = 0; t < T; t++) {
+                int64_t t_offset = t * t_stride;
+                int64_t state_offset = head_size * C * (t / (T / n_seqs));
+                float * state_cur = state + state_offset;
+                float * state_prev = t % (T / n_seqs) ? state_cur : (float *)dst->src[6]->data + state_offset;
+
+                for (int64_t h = h_start; h < h_end; h++) {
+                    int64_t h_offset = h * h_stride;
+                    int64_t t_h_offset = t_offset + h_offset;
+                    int64_t h_2d_offset = h * h_stride_2d;
+
+                    for (int64_t i = 0; i < head_size; i++) {
+                        int64_t t_h_i_offset = t_h_offset + i;
+                        int64_t h_2d_i_offset = h_2d_offset + i * h_stride;
+
+                        float v_val = v[t_h_i_offset];
+
+                        float sa = 0, result = 0;
+                        for (int64_t j = 0; j < head_size; j++) {
+                            sa += a[t_h_offset + j] * state_prev[h_2d_i_offset + j];
+                        }

-                    float sa = 0;
-                    {
-                        GGML_F32_VEC sum[GGML_F32_ARR] = { GGML_F32_VEC_ZERO };
-                        GGML_F32_VEC ax[GGML_F32_ARR];
-                        GGML_F32_VEC ay[GGML_F32_ARR];
-                        for (int64_t j = 0; j < head_size; j += GGML_F32_STEP) {
-                            for (int64_t kk = 0; kk < GGML_F32_ARR; kk++) {
-                                ax[kk] = GGML_F32_VEC_LOAD(&a[t_h_offset + j + kk * GGML_F32_EPR]);
-                                ay[kk] = GGML_F32_VEC_LOAD(&state_prev[h_2d_i_offset + j + kk * GGML_F32_EPR]);
-                                sum[kk] = GGML_F32_VEC_FMA(sum[kk], ax[kk], ay[kk]);
-                            }
+                        for (int64_t j = 0; j < head_size; j++) {
+                            int64_t t_h_j_offset = t_h_offset + j;
+                            int64_t h_2d_i_j_offset = h_2d_i_offset + j;
+
+                            float r_val = r[t_h_j_offset];
+                            float w_val = w[t_h_j_offset];
+                            float k_val = k[t_h_j_offset];
+                            float b_val = b[t_h_j_offset];
+                            float kv_val = v_val * k_val;
+                            float prev_state_val = state_prev[h_2d_i_j_offset];
+                            state_cur[h_2d_i_j_offset] = prev_state_val * w_val + kv_val + sa * b_val;
+                            result += state_cur[h_2d_i_j_offset] * r_val;
                         }
-                        GGML_F32_VEC_REDUCE(sa, sum);
+                        dst_data[t_h_i_offset] = result;
                     }
+                }
+            }
+        #else
+            for (int64_t t = 0; t < T; t++) {
+                int64_t t_offset = t * t_stride;
+                int64_t state_offset = head_size * C * (t / (T / n_seqs));
+                float * state_cur = state + state_offset;
+                float * state_prev = t % (T / n_seqs) ? state_cur : (float *)dst->src[6]->data + state_offset;
+
+                for (int64_t h = h_start; h < h_end; h++) {
+                    int64_t h_offset = h * h_stride;
+                    int64_t t_h_offset = t_offset + h_offset;
+                    int64_t h_2d_offset = h * h_stride_2d;
+
+                    for (int64_t ii = 0; ii < head_size; ii++) {
+                        int64_t t_h_i_offset = t_h_offset + ii;
+                        int64_t h_2d_i_offset = h_2d_offset + ii * h_stride;
+
+                        GGML_F32_VEC v_vec = GGML_F32_VEC_SET1(v[t_h_i_offset]);
+
+                        float sa = 0;
+                        {
+                            GGML_F32_VEC sum[GGML_F32_ARR] = { GGML_F32_VEC_ZERO };
+                            GGML_F32_VEC ax[GGML_F32_ARR];
+                            GGML_F32_VEC ay[GGML_F32_ARR];
+                            for (int64_t j = 0; j < head_size; j += GGML_F32_STEP) {
+                                for (int64_t kk = 0; kk < GGML_F32_ARR; kk++) {
+                                    ax[kk] = GGML_F32_VEC_LOAD(&a[t_h_offset + j + kk * GGML_F32_EPR]);
+                                    ay[kk] = GGML_F32_VEC_LOAD(&state_prev[h_2d_i_offset + j + kk * GGML_F32_EPR]);
+                                    sum[kk] = GGML_F32_VEC_FMA(sum[kk], ax[kk], ay[kk]);
+                                }
+                            }
+                            GGML_F32_VEC_REDUCE(sa, sum);
+                        }

-                    GGML_F32_VEC sa_vec = GGML_F32_VEC_SET1(sa);
+                        GGML_F32_VEC sa_vec = GGML_F32_VEC_SET1(sa);

-                    int64_t j = 0;
-                    GGML_F32_VEC result_vec[GGML_F32_ARR] = { GGML_F32_VEC_ZERO };
-                    for (; j < head_size; j += GGML_F32_STEP) {
-                        for (int64_t kk = 0; kk < GGML_F32_ARR; kk++) {
-                            int64_t t_h_j_offset = t_h_offset + j + kk * GGML_F32_EPR;
-                            int64_t h_2d_i_j_offset = h_2d_i_offset + j + kk * GGML_F32_EPR;
+                        int64_t j = 0;
+                        GGML_F32_VEC result_vec[GGML_F32_ARR] = { GGML_F32_VEC_ZERO };
+                        for (; j < head_size; j += GGML_F32_STEP) {
+                            for (int64_t kk = 0; kk < GGML_F32_ARR; kk++) {
+                                int64_t t_h_j_offset = t_h_offset + j + kk * GGML_F32_EPR;
+                                int64_t h_2d_i_j_offset = h_2d_i_offset + j + kk * GGML_F32_EPR;

-                            GGML_F32_VEC r_vec = GGML_F32_VEC_LOAD(&r[t_h_j_offset]);
-                            GGML_F32_VEC w_vec = GGML_F32_VEC_LOAD(&w[t_h_j_offset]);
-                            GGML_F32_VEC k_vec = GGML_F32_VEC_LOAD(&k[t_h_j_offset]);
-                            GGML_F32_VEC b_vec = GGML_F32_VEC_LOAD(&b[t_h_j_offset]);
+                                GGML_F32_VEC r_vec = GGML_F32_VEC_LOAD(&r[t_h_j_offset]);
+                                GGML_F32_VEC w_vec = GGML_F32_VEC_LOAD(&w[t_h_j_offset]);
+                                GGML_F32_VEC k_vec = GGML_F32_VEC_LOAD(&k[t_h_j_offset]);
+                                GGML_F32_VEC b_vec = GGML_F32_VEC_LOAD(&b[t_h_j_offset]);

-                            k_vec = GGML_F32_VEC_MUL(v_vec, k_vec);
+                                k_vec = GGML_F32_VEC_MUL(v_vec, k_vec);

-                            GGML_F32_VEC state_vec = GGML_F32_VEC_LOAD(&state_prev[h_2d_i_j_offset]);
-                            // kv + s * decay + sa * b
-                            state_vec = GGML_F32_VEC_FMA(k_vec, state_vec, w_vec);
-                            state_vec = GGML_F32_VEC_FMA(state_vec, sa_vec, b_vec);
-                            GGML_F32_VEC_STORE(&state_cur[h_2d_i_j_offset], state_vec);
+                                GGML_F32_VEC state_vec = GGML_F32_VEC_LOAD(&state_prev[h_2d_i_j_offset]);
+                                // kv + s * decay + sa * b
+                                state_vec = GGML_F32_VEC_FMA(k_vec, state_vec, w_vec);
+                                state_vec = GGML_F32_VEC_FMA(state_vec, sa_vec, b_vec);
+                                GGML_F32_VEC_STORE(&state_cur[h_2d_i_j_offset], state_vec);

-                            result_vec[kk] = GGML_F32_VEC_FMA(result_vec[kk], state_vec, r_vec);
+                                result_vec[kk] = GGML_F32_VEC_FMA(result_vec[kk], state_vec, r_vec);
+                            }
+                        }
+                        GGML_F32_VEC_REDUCE(dst_data[t_h_i_offset], result_vec);
+
+                        // There shouldn't be left-overs though.
+                        for (; j < head_size; j++) {
+                            int64_t t_h_j_offset = t_h_offset + j;
+                            int64_t h_2d_i_j_offset = h_2d_i_offset + j;
+
+                            float r_val = r[t_h_j_offset];
+                            float w_val = w[t_h_j_offset];
+                            float k_val = k[t_h_j_offset];
+                            float b_val = b[t_h_j_offset];
+                            float kv_val = v[t_h_i_offset] * k_val;
+
+                            float prev_state_val = state_prev[h_2d_i_j_offset];
+                            state_cur[h_2d_i_j_offset] = prev_state_val * w_val + kv_val + sa * b_val;
+                            dst_data[t_h_i_offset] += state_cur[h_2d_i_j_offset] * r_val;
                         }
-                    }
-                    GGML_F32_VEC_REDUCE(dst_data[t_h_i_offset], result_vec);
-
-                    // There shouldn't be left-overs though.
-                    for (; j < head_size; j++) {
-                        int64_t t_h_j_offset = t_h_offset + j;
-                        int64_t h_2d_i_j_offset = h_2d_i_offset + j;
-
-                        float r_val = r[t_h_j_offset];
-                        float w_val = w[t_h_j_offset];
-                        float k_val = k[t_h_j_offset];
-                        float b_val = b[t_h_j_offset];
-                        float kv_val = v[t_h_i_offset] * k_val;
-
-                        float prev_state_val = state_prev[h_2d_i_j_offset];
-                        state_cur[h_2d_i_j_offset] = prev_state_val * w_val + kv_val + sa * b_val;
-                        dst_data[t_h_i_offset] += state_cur[h_2d_i_j_offset] * r_val;
                     }
                 }
             }
-        }
+        #endif
     #else
         for (int64_t t = 0; t < T; t++) {
             int64_t t_offset = t * t_stride;