@@ -7641,8 +7641,8 @@ static void ggml_compute_forward_ssm_scan_f32(
             const float * A = (const float *) ((const char *) src3->data + ir0*(src3->nb[1])); // {d_state, d_inner}
             const float * B = (const float *) ((const char *) src4->data + i2*(src4->nb[1]) + i3*(src4->nb[2])); // {d_state, n_t, n_s}
             const float * C = (const float *) ((const char *) src5->data + i2*(src5->nb[1]) + i3*(src5->nb[2])); // {d_state, n_t, n_s}
-            float * y = (float *) ((char *) dst->data + ir0*(src1->nb[0]) + i2*(src1->nb[1]) + i3*(src1->nb[2])); // {d_inner, n_t, n_s}
-            float * s = (float *) ((char *) dst->data + ir0*(src0->nb[1]) + i3*(src0->nb[2]) + src1->nb[3]); // {d_state, d_inner, n_s}
+            float * y = (float *) ((char *) dst->data + ir0*(src1->nb[0]) + i2*(src1->nb[1]) + i3*(src1->nb[2])); // {d_inner, n_t, n_s}
+            float * s = (float *) ((char *) dst->data + ir0*(src0->nb[1]) + i3*(src0->nb[2]) + src1->nb[3]); // {d_state, d_inner, n_s}

             // use the output as the source for the next token-wise iterations
             if (i2 > 0) { s0 = s; }
@@ -8070,6 +8070,14 @@ static void ggml_compute_forward_rwkv_wkv6_f32(
         #define GGML_F32X_MUL GGML_F32x16_MUL
         #define GGML_F32X_FMA GGML_F32x16_FMA
         #define WKV_VECTOR_SIZE 16
+    #elif defined(__ARM_FEATURE_SVE) && defined(__aarch64__)
+        #define GGML_F32X GGML_F32xt
+        #define GGML_F32X_SET1 GGML_F32xt_SET1
+        #define GGML_F32X_LOAD GGML_F32xt_LOAD
+        #define GGML_F32X_STORE GGML_F32xt_STORE
+        #define GGML_F32X_MUL GGML_F32xt_MUL
+        #define GGML_F32X_FMA GGML_F32xt_FMA
+        #define WKV_VECTOR_SIZE 8
     #elif defined(__ARM_NEON) && defined(__aarch64__)
         #define GGML_F32X GGML_F32x4
         #define GGML_F32X_SET1 GGML_F32x4_SET1
@@ -8080,8 +8088,14 @@ static void ggml_compute_forward_rwkv_wkv6_f32(
         #define WKV_VECTOR_SIZE 4
     #endif

+    int wkv_vector_size;
     #ifdef WKV_VECTOR_SIZE
-        const int64_t vec_count = head_size / WKV_VECTOR_SIZE;
+        #if defined(__ARM_FEATURE_SVE)
+            wkv_vector_size = svcntw();
+        #else
+            wkv_vector_size = WKV_VECTOR_SIZE;
+        #endif
+        const int64_t vec_count = head_size / wkv_vector_size;

         for (int64_t t = 0; t < T; t++) {
             size_t t_offset = t * t_stride;
@@ -8111,7 +8125,7 @@ static void ggml_compute_forward_rwkv_wkv6_f32(
                     GGML_F32X time_decay_vec = GGML_F32X_SET1(time_decay_val);

                     for (int64_t j = 0; j < vec_count; j++) {
-                        size_t base_j = j * WKV_VECTOR_SIZE;
+                        size_t base_j = j * wkv_vector_size;
                         size_t t_h_j_offset = t_h_offset + base_j;
                         size_t h_2d_i_j_offset = h_2d_i_offset + base_j;

@@ -8136,7 +8150,7 @@ static void ggml_compute_forward_rwkv_wkv6_f32(
                     }

                     // Handle remaining elements, this will not be used.
-                    for (int64_t j = vec_count * WKV_VECTOR_SIZE; j < head_size; j++) {
+                    for (int64_t j = vec_count * wkv_vector_size; j < head_size; j++) {
                         size_t t_h_j_offset = t_h_offset + j;
                         size_t h_2d_i_j_offset = h_2d_i_offset + j;
                         float v_val = v[t_h_j_offset];
@@ -8272,6 +8286,14 @@ static void ggml_compute_forward_gla_f32(
         #define GGML_F32X_MUL GGML_F32x16_MUL
         #define GGML_F32X_FMA GGML_F32x16_FMA
         #define GLA_VECTOR_SIZE 16
+    #elif defined(__ARM_FEATURE_SVE) && defined(__aarch64__)
+        #define GGML_F32X GGML_F32xt
+        #define GGML_F32X_SET1 GGML_F32xt_SET1
+        #define GGML_F32X_LOAD GGML_F32xt_LOAD
+        #define GGML_F32X_STORE GGML_F32xt_STORE
+        #define GGML_F32X_MUL GGML_F32xt_MUL
+        #define GGML_F32X_FMA GGML_F32xt_FMA
+        #define GLA_VECTOR_SIZE 8
     #elif defined(__ARM_NEON) && defined(__aarch64__)
         #define GGML_F32X GGML_F32x4
         #define GGML_F32X_SET1 GGML_F32x4_SET1
@@ -8282,8 +8304,14 @@ static void ggml_compute_forward_gla_f32(
         #define GLA_VECTOR_SIZE 4
     #endif

+    int gla_vector_size;
     #ifdef GLA_VECTOR_SIZE
-        const int64_t vec_count = head_size / GLA_VECTOR_SIZE;
+        #if defined(__ARM_FEATURE_SVE)
+            gla_vector_size = svcntw();
+        #else
+            gla_vector_size = GLA_VECTOR_SIZE;
+        #endif
+        const int64_t vec_count = head_size / gla_vector_size;

         for (int64_t t = 0; t < T; t++) {
             size_t t_offset = t * t_stride;
@@ -8310,7 +8338,7 @@ static void ggml_compute_forward_gla_f32(
                     GGML_F32X g_vec = GGML_F32X_SET1(g_val);

                     for (int64_t j = 0; j < vec_count; j++) {
-                        size_t base_j = j * GLA_VECTOR_SIZE;
+                        size_t base_j = j * gla_vector_size;
                         size_t t_h_j_offset = t_h_offset + base_j;
                         size_t h_2d_i_j_offset = h_2d_i_offset + base_j;

@@ -8334,7 +8362,7 @@ static void ggml_compute_forward_gla_f32(
                     }

                     // Handle remaining elements, this will not be used.
-                    for (int64_t j = vec_count * GLA_VECTOR_SIZE; j < head_size; j++) {
+                    for (int64_t j = vec_count * gla_vector_size; j < head_size; j++) {
                         size_t t_h_j_offset = t_h_offset + j;
                         size_t h_2d_i_j_offset = h_2d_i_offset + j;
                         float v_val = v[t_h_j_offset];
@@ -8443,83 +8471,126 @@ static void ggml_compute_forward_rwkv_wkv7_f32(
     int64_t h_stride_2d = head_size * head_size;

     #if defined(GGML_SIMD)
-        for (int64_t t = 0; t < T; t++) {
-            int64_t t_offset = t * t_stride;
-            int64_t state_offset = head_size * C * (t / (T / n_seqs));
-            float * state_cur = state + state_offset;
-            float * state_prev = t % (T / n_seqs) ? state_cur : (float *) dst->src[6]->data + state_offset;
-
-            for (int64_t h = h_start; h < h_end; h++) {
-                int64_t h_offset = h * h_stride;
-                int64_t t_h_offset = t_offset + h_offset;
-                int64_t h_2d_offset = h * h_stride_2d;
-
-                for (int64_t ii = 0; ii < head_size; ii++) {
-                    int64_t t_h_i_offset = t_h_offset + ii;
-                    int64_t h_2d_i_offset = h_2d_offset + ii * h_stride;
-
-                    GGML_F32_VEC v_vec = GGML_F32_VEC_SET1(v[t_h_i_offset]);
+        #if defined(__ARM_FEATURE_SVE)
+            // Route to scalar implementation // TODO: Write SVE code
+            for (int64_t t = 0; t < T; t++) {
+                int64_t t_offset = t * t_stride;
+                int64_t state_offset = head_size * C * (t / (T / n_seqs));
+                float * state_cur = state + state_offset;
+                float * state_prev = t % (T / n_seqs) ? state_cur : (float *) dst->src[6]->data + state_offset;
+
+                for (int64_t h = h_start; h < h_end; h++) {
+                    int64_t h_offset = h * h_stride;
+                    int64_t t_h_offset = t_offset + h_offset;
+                    int64_t h_2d_offset = h * h_stride_2d;
+
+                    for (int64_t i = 0; i < head_size; i++) {
+                        int64_t t_h_i_offset = t_h_offset + i;
+                        int64_t h_2d_i_offset = h_2d_offset + i * h_stride;
+
+                        float v_val = v[t_h_i_offset];
+
+                        float sa = 0, result = 0;
+                        for (int64_t j = 0; j < head_size; j++) {
+                            sa += a[t_h_offset + j] * state_prev[h_2d_i_offset + j];
+                        }

-                    float sa = 0;
-                    {
-                        GGML_F32_VEC sum[GGML_F32_ARR] = { GGML_F32_VEC_ZERO };
-                        GGML_F32_VEC ax[GGML_F32_ARR];
-                        GGML_F32_VEC ay[GGML_F32_ARR];
-                        for (int64_t j = 0; j < head_size; j += GGML_F32_STEP) {
-                            for (int64_t kk = 0; kk < GGML_F32_ARR; kk++) {
-                                ax[kk] = GGML_F32_VEC_LOAD(&a[t_h_offset + j + kk * GGML_F32_EPR]);
-                                ay[kk] = GGML_F32_VEC_LOAD(&state_prev[h_2d_i_offset + j + kk * GGML_F32_EPR]);
-                                sum[kk] = GGML_F32_VEC_FMA(sum[kk], ax[kk], ay[kk]);
-                            }
+                        for (int64_t j = 0; j < head_size; j++) {
+                            int64_t t_h_j_offset = t_h_offset + j;
+                            int64_t h_2d_i_j_offset = h_2d_i_offset + j;
+
+                            float r_val = r[t_h_j_offset];
+                            float w_val = w[t_h_j_offset];
+                            float k_val = k[t_h_j_offset];
+                            float b_val = b[t_h_j_offset];
+                            float kv_val = v_val * k_val;
+                            float prev_state_val = state_prev[h_2d_i_j_offset];
+                            state_cur[h_2d_i_j_offset] = prev_state_val * w_val + kv_val + sa * b_val;
+                            result += state_cur[h_2d_i_j_offset] * r_val;
                         }
-                        GGML_F32_VEC_REDUCE(sa, sum);
+                        dst_data[t_h_i_offset] = result;
                     }
+                }
+            }
+        #else
+            for (int64_t t = 0; t < T; t++) {
+                int64_t t_offset = t * t_stride;
+                int64_t state_offset = head_size * C * (t / (T / n_seqs));
+                float * state_cur = state + state_offset;
+                float * state_prev = t % (T / n_seqs) ? state_cur : (float *) dst->src[6]->data + state_offset;
+
+                for (int64_t h = h_start; h < h_end; h++) {
+                    int64_t h_offset = h * h_stride;
+                    int64_t t_h_offset = t_offset + h_offset;
+                    int64_t h_2d_offset = h * h_stride_2d;
+
+                    for (int64_t ii = 0; ii < head_size; ii++) {
+                        int64_t t_h_i_offset = t_h_offset + ii;
+                        int64_t h_2d_i_offset = h_2d_offset + ii * h_stride;
+
+                        GGML_F32_VEC v_vec = GGML_F32_VEC_SET1(v[t_h_i_offset]);
+
+                        float sa = 0;
+                        {
+                            GGML_F32_VEC sum[GGML_F32_ARR] = { GGML_F32_VEC_ZERO };
+                            GGML_F32_VEC ax[GGML_F32_ARR];
+                            GGML_F32_VEC ay[GGML_F32_ARR];
+                            for (int64_t j = 0; j < head_size; j += GGML_F32_STEP) {
+                                for (int64_t kk = 0; kk < GGML_F32_ARR; kk++) {
+                                    ax[kk] = GGML_F32_VEC_LOAD(&a[t_h_offset + j + kk * GGML_F32_EPR]);
+                                    ay[kk] = GGML_F32_VEC_LOAD(&state_prev[h_2d_i_offset + j + kk * GGML_F32_EPR]);
+                                    sum[kk] = GGML_F32_VEC_FMA(sum[kk], ax[kk], ay[kk]);
+                                }
+                            }
+                            GGML_F32_VEC_REDUCE(sa, sum);
+                        }

-                    GGML_F32_VEC sa_vec = GGML_F32_VEC_SET1(sa);
+                        GGML_F32_VEC sa_vec = GGML_F32_VEC_SET1(sa);

-                    int64_t j = 0;
-                    GGML_F32_VEC result_vec[GGML_F32_ARR] = { GGML_F32_VEC_ZERO };
-                    for (; j < head_size; j += GGML_F32_STEP) {
-                        for (int64_t kk = 0; kk < GGML_F32_ARR; kk++) {
-                            int64_t t_h_j_offset = t_h_offset + j + kk * GGML_F32_EPR;
-                            int64_t h_2d_i_j_offset = h_2d_i_offset + j + kk * GGML_F32_EPR;
+                        int64_t j = 0;
+                        GGML_F32_VEC result_vec[GGML_F32_ARR] = { GGML_F32_VEC_ZERO };
+                        for (; j < head_size; j += GGML_F32_STEP) {
+                            for (int64_t kk = 0; kk < GGML_F32_ARR; kk++) {
+                                int64_t t_h_j_offset = t_h_offset + j + kk * GGML_F32_EPR;
+                                int64_t h_2d_i_j_offset = h_2d_i_offset + j + kk * GGML_F32_EPR;

-                            GGML_F32_VEC r_vec = GGML_F32_VEC_LOAD(&r[t_h_j_offset]);
-                            GGML_F32_VEC w_vec = GGML_F32_VEC_LOAD(&w[t_h_j_offset]);
-                            GGML_F32_VEC k_vec = GGML_F32_VEC_LOAD(&k[t_h_j_offset]);
-                            GGML_F32_VEC b_vec = GGML_F32_VEC_LOAD(&b[t_h_j_offset]);
+                                GGML_F32_VEC r_vec = GGML_F32_VEC_LOAD(&r[t_h_j_offset]);
+                                GGML_F32_VEC w_vec = GGML_F32_VEC_LOAD(&w[t_h_j_offset]);
+                                GGML_F32_VEC k_vec = GGML_F32_VEC_LOAD(&k[t_h_j_offset]);
+                                GGML_F32_VEC b_vec = GGML_F32_VEC_LOAD(&b[t_h_j_offset]);

-                            k_vec = GGML_F32_VEC_MUL(v_vec, k_vec);
+                                k_vec = GGML_F32_VEC_MUL(v_vec, k_vec);

-                            GGML_F32_VEC state_vec = GGML_F32_VEC_LOAD(&state_prev[h_2d_i_j_offset]);
-                            // kv + s * decay + sa * b
-                            state_vec = GGML_F32_VEC_FMA(k_vec, state_vec, w_vec);
-                            state_vec = GGML_F32_VEC_FMA(state_vec, sa_vec, b_vec);
-                            GGML_F32_VEC_STORE(&state_cur[h_2d_i_j_offset], state_vec);
+                                GGML_F32_VEC state_vec = GGML_F32_VEC_LOAD(&state_prev[h_2d_i_j_offset]);
+                                // kv + s * decay + sa * b
+                                state_vec = GGML_F32_VEC_FMA(k_vec, state_vec, w_vec);
+                                state_vec = GGML_F32_VEC_FMA(state_vec, sa_vec, b_vec);
+                                GGML_F32_VEC_STORE(&state_cur[h_2d_i_j_offset], state_vec);

-                            result_vec[kk] = GGML_F32_VEC_FMA(result_vec[kk], state_vec, r_vec);
+                                result_vec[kk] = GGML_F32_VEC_FMA(result_vec[kk], state_vec, r_vec);
+                            }
+                        }
+                        GGML_F32_VEC_REDUCE(dst_data[t_h_i_offset], result_vec);
+
+                        // There shouldn't be left-overs though.
+                        for (; j < head_size; j++) {
+                            int64_t t_h_j_offset = t_h_offset + j;
+                            int64_t h_2d_i_j_offset = h_2d_i_offset + j;
+
+                            float r_val = r[t_h_j_offset];
+                            float w_val = w[t_h_j_offset];
+                            float k_val = k[t_h_j_offset];
+                            float b_val = b[t_h_j_offset];
+                            float kv_val = v[t_h_i_offset] * k_val;
+
+                            float prev_state_val = state_prev[h_2d_i_j_offset];
+                            state_cur[h_2d_i_j_offset] = prev_state_val * w_val + kv_val + sa * b_val;
+                            dst_data[t_h_i_offset] += state_cur[h_2d_i_j_offset] * r_val;
                         }
-                    }
-                    GGML_F32_VEC_REDUCE(dst_data[t_h_i_offset], result_vec);
-
-                    // There shouldn't be left-overs though.
-                    for (; j < head_size; j++) {
-                        int64_t t_h_j_offset = t_h_offset + j;
-                        int64_t h_2d_i_j_offset = h_2d_i_offset + j;
-
-                        float r_val = r[t_h_j_offset];
-                        float w_val = w[t_h_j_offset];
-                        float k_val = k[t_h_j_offset];
-                        float b_val = b[t_h_j_offset];
-                        float kv_val = v[t_h_i_offset] * k_val;
-
-                        float prev_state_val = state_prev[h_2d_i_j_offset];
-                        state_cur[h_2d_i_j_offset] = prev_state_val * w_val + kv_val + sa * b_val;
-                        dst_data[t_h_i_offset] += state_cur[h_2d_i_j_offset] * r_val;
                     }
                 }
             }
-        }
+        #endif
     #else
         for (int64_t t = 0; t < T; t++) {
            int64_t t_offset = t * t_stride;
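A note on the SVE branches introduced above: unlike the AVX512/AVX2/NEON paths, where the f32 lane count is a compile-time constant, the SVE register length is implementation-defined, so the patch queries it at runtime with svcntw() (the number of 32-bit lanes per SVE vector). The literal in the added WKV_VECTOR_SIZE 8 / GLA_VECTOR_SIZE 8 definitions mainly enables the #ifdef block and serves as the non-SVE fallback; on SVE hardware the effective width comes from svcntw(). Below is a minimal, self-contained sketch of that dispatch idea, not part of the patch; the helper name f32_lanes is made up for illustration.

// Sketch only: pick the number of f32 lanes per vector for the current target.
// On SVE the width is a runtime property of the CPU, everywhere else it is fixed.
#include <stdint.h>
#if defined(__ARM_FEATURE_SVE)
#include <arm_sve.h>   // svcntw(): 32-bit lanes per SVE vector (4 on 128-bit HW, 8 on 256-bit, ...)
#endif

static inline int64_t f32_lanes(void) {
#if defined(__ARM_FEATURE_SVE)
    return (int64_t) svcntw();  // runtime value, depends on the implemented SVE register length
#elif defined(__AVX512F__)
    return 16;                  // 512 bits / 32-bit floats
#elif defined(__AVX2__)
    return 8;                   // 256 bits / 32-bit floats
#elif defined(__ARM_NEON)
    return 4;                   // 128 bits / 32-bit floats
#else
    return 1;                   // scalar fallback
#endif
}

This also matches the in-code comments above ("this will not be used", "There shouldn't be left-overs though"): head_size is expected to be a multiple of the lane count, so the scalar tail loops after the vectorized loops are effectively never entered.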