@@ -8646,41 +8646,7 @@ static void ggml_compute_forward_ssm_scan_f32(
             const int ii = i1 + h*nr;
             const float x_dt = x[ii] * dt_soft_plus;
             float sumf = 0.0f;
-#if defined(GGML_SIMD)
-    #if defined(__ARM_FEATURE_SVE)
-            const int ggml_f32_epr = svcntw();
-            const int ggml_f32_step = 1 * ggml_f32_epr;
-
-            const int np = (nc & ~(ggml_f32_step - 1));
-
-            GGML_F32_VEC sum = GGML_F32_VEC_ZERO;
-
-            GGML_F32_VEC adA = GGML_F32_VEC_SET1(dA);
-            GGML_F32_VEC axdt = GGML_F32_VEC_SET1(x_dt);
-
-            for (int i = 0; i < np; i += ggml_f32_step) {
-                // TODO: maybe unroll more?
-                for (int j = 0; j < 1; j++) {
-                    GGML_F32_VEC t0 = GGML_F32_VEC_LOAD(s0 + i + j*ggml_f32_epr + ii*nc);
-                    GGML_F32_VEC t1 = GGML_F32_VEC_LOAD(B + i + j*ggml_f32_epr + g*nc);
-                    GGML_F32_VEC t2 = GGML_F32_VEC_LOAD(C + i + j*ggml_f32_epr + g*nc);
-
-                    t0 = GGML_F32_VEC_MUL(t0, adA);
-                    t1 = GGML_F32_VEC_MUL(t1, axdt);
-
-                    t0 = GGML_F32_VEC_ADD(t0, t1);
-
-                    sum = GGML_F32_VEC_FMA(sum, t0, t2);
-
-                    GGML_F32_VEC_STORE(s + i + j*ggml_f32_epr + ii*nc, t0);
-                }
-            }
-
-            sumf = GGML_F32xt_REDUCE_ONE(sum);
-    #elif defined(__riscv_v_intrinsic)
-            // todo: RVV implementation
-            const int np = 0;
-    #else
+#if defined(GGML_SIMD) && !defined(__riscv_v_intrinsic)
             const int np = (nc & ~(GGML_F32_STEP - 1));

             GGML_F32_VEC sum[GGML_F32_ARR] = { GGML_F32_VEC_ZERO };
@@ -8711,7 +8677,6 @@ static void ggml_compute_forward_ssm_scan_f32(

             // reduce sum0..sum3 to sum0
             GGML_F32_VEC_REDUCE(sumf, sum);
-    #endif
 #else
             const int np = 0;
 #endif
@@ -8741,30 +8706,6 @@ static void ggml_compute_forward_ssm_scan_f32(
         for (int i1 = 0; i1 < nr; ++i1) {
             const int ii = i1 + h*nr;
             const float x_dt = x[ii] * dt_soft_plus;
-#if defined(__ARM_FEATURE_SVE)
-            svfloat32_t vx_dt = GGML_F32_VEC_SET1(x_dt);
-            svfloat32_t vdt_soft_plus = GGML_F32_VEC_SET1(dt_soft_plus);
-            svfloat32_t r1_vector = GGML_F32_VEC_ZERO;
-
-            // d_state
-            // TODO: what happens when (d_state % svcntw()) != 0?
-            for (int64_t k = 0; k < nc; k += svcntw()) {
-                svfloat32_t vA = GGML_F32_VEC_LOAD(&A[h*nc + k]);
-                svfloat32_t vB = GGML_F32_VEC_LOAD(&B[k + g*nc]);
-                svfloat32_t vC = GGML_F32_VEC_LOAD(&C[k + g*nc]);
-                svfloat32_t vs0 = GGML_F32_VEC_LOAD(&s0[ii*nc + k]);
-
-                svfloat32_t t1 = GGML_F32_VEC_MUL(vdt_soft_plus, vA);
-                t1 = exp_ps_sve(svptrue_b32(), t1);
-                svfloat32_t t2 = GGML_F32_VEC_MUL(vx_dt, vB);
-
-                vs0 = GGML_F32_VEC_FMA(t2, vs0, t1);
-                r1_vector = GGML_F32_VEC_ADD(GGML_F32_VEC_MUL(vs0, vC), r1_vector);
-
-                GGML_F32_VEC_STORE(&s[ii*nc + k], vs0);
-            }
-            y[ii] = GGML_F32xt_REDUCE_ONE(r1_vector);
-#else
             float sumf = 0.0f;
             // NOTE: can't really use GGML_SIMD here because d_state is usually 16
             //       and also because expf is used within the loop.
@@ -8779,7 +8720,6 @@ static void ggml_compute_forward_ssm_scan_f32(
                 s[i] = state;
             }
             y[ii] = sumf;
-#endif
         }
     }
 }
@@ -9231,14 +9171,6 @@ static void ggml_compute_forward_rwkv_wkv6_f32(
     #define GGML_F32X_MUL GGML_F32x16_MUL
     #define GGML_F32X_FMA GGML_F32x16_FMA
     #define WKV_VECTOR_SIZE 16
-#elif defined(__ARM_FEATURE_SVE) && defined(__aarch64__)
-    #define GGML_F32X GGML_F32xt
-    #define GGML_F32X_SET1 GGML_F32xt_SET1
-    #define GGML_F32X_LOAD GGML_F32xt_LOAD
-    #define GGML_F32X_STORE GGML_F32xt_STORE
-    #define GGML_F32X_MUL GGML_F32xt_MUL
-    #define GGML_F32X_FMA GGML_F32xt_FMA
-    #define WKV_VECTOR_SIZE 8
 #elif defined(__ARM_NEON) && defined(__aarch64__)
     #define GGML_F32X GGML_F32x4
     #define GGML_F32X_SET1 GGML_F32x4_SET1
@@ -9251,11 +9183,7 @@ static void ggml_compute_forward_rwkv_wkv6_f32(

     #ifdef WKV_VECTOR_SIZE
         int wkv_vector_size;
-        #if defined(__ARM_FEATURE_SVE)
-            wkv_vector_size = svcntw();
-        #else
-            wkv_vector_size = WKV_VECTOR_SIZE;
-        #endif
+        wkv_vector_size = WKV_VECTOR_SIZE;
         const int64_t vec_count = head_size / wkv_vector_size;

         for (int64_t t = 0; t < T; t++) {
@@ -9447,14 +9375,6 @@ static void ggml_compute_forward_gla_f32(
     #define GGML_F32X_MUL GGML_F32x16_MUL
     #define GGML_F32X_FMA GGML_F32x16_FMA
     #define GLA_VECTOR_SIZE 16
-#elif defined(__ARM_FEATURE_SVE) && defined(__aarch64__)
-    #define GGML_F32X GGML_F32xt
-    #define GGML_F32X_SET1 GGML_F32xt_SET1
-    #define GGML_F32X_LOAD GGML_F32xt_LOAD
-    #define GGML_F32X_STORE GGML_F32xt_STORE
-    #define GGML_F32X_MUL GGML_F32xt_MUL
-    #define GGML_F32X_FMA GGML_F32xt_FMA
-    #define GLA_VECTOR_SIZE 8
 #elif defined(__ARM_NEON) && defined(__aarch64__)
     #define GGML_F32X GGML_F32x4
     #define GGML_F32X_SET1 GGML_F32x4_SET1
@@ -9467,11 +9387,7 @@ static void ggml_compute_forward_gla_f32(

     #ifdef GLA_VECTOR_SIZE
         int gla_vector_size;
-        #if defined(__ARM_FEATURE_SVE)
-            gla_vector_size = svcntw();
-        #else
-            gla_vector_size = GLA_VECTOR_SIZE;
-        #endif
+        gla_vector_size = GLA_VECTOR_SIZE;
         const int64_t vec_count = head_size / gla_vector_size;

         for (int64_t t = 0; t < T; t++) {
@@ -9631,127 +9547,84 @@ static void ggml_compute_forward_rwkv_wkv7_f32(
     GGML_ASSERT(C % HEADS == 0); // C must be divisible by HEADS
     int64_t h_stride_2d = head_size * head_size;

-#if defined(GGML_SIMD)
-    #if defined(__ARM_FEATURE_SVE) || defined(__riscv_v_intrinsic)
-        // scalar Route to scalar implementation //TODO: Write SVE code and RVV code
-        for (int64_t t = 0; t < T; t++) {
-            int64_t t_offset = t * t_stride;
-            int64_t state_offset = head_size * C * (t / (T / n_seqs));
-            float * state_cur = state + state_offset;
-            float * state_prev = t % (T / n_seqs) ? state_cur : (float *) dst->src[6]->data + state_offset;
-
-            for (int64_t h = h_start; h < h_end; h++) {
-                int64_t h_offset = h * h_stride;
-                int64_t t_h_offset = t_offset + h_offset;
-                int64_t h_2d_offset = h * h_stride_2d;
-
-                for (int64_t i = 0; i < head_size; i++) {
-                    int64_t t_h_i_offset = t_h_offset + i;
-                    int64_t h_2d_i_offset = h_2d_offset + i * h_stride;
-
-                    float v_val = v[t_h_i_offset];
-
-                    float sa = 0, result = 0;
-                    for (int64_t j = 0; j < head_size; j++) {
-                        sa += a[t_h_offset + j] * state_prev[h_2d_i_offset + j];
-                    }
+#if defined(GGML_SIMD) && !defined(__riscv_v_intrinsic)
+    for (int64_t t = 0; t < T; t++) {
+        int64_t t_offset = t * t_stride;
+        int64_t state_offset = head_size * C * (t / (T / n_seqs));
+        float * state_cur = state + state_offset;
+        float * state_prev = t % (T / n_seqs) ? state_cur : (float *) dst->src[6]->data + state_offset;

-                    for (int64_t j = 0; j < head_size; j++) {
-                        int64_t t_h_j_offset = t_h_offset + j;
-                        int64_t h_2d_i_j_offset = h_2d_i_offset + j;
-
-                        float r_val = r[t_h_j_offset];
-                        float w_val = w[t_h_j_offset];
-                        float k_val = k[t_h_j_offset];
-                        float b_val = b[t_h_j_offset];
-                        float kv_val = v_val * k_val;
-                        float prev_state_val = state_prev[h_2d_i_j_offset];
-                        state_cur[h_2d_i_j_offset] = prev_state_val * w_val + kv_val + sa * b_val;
-                        result += state_cur[h_2d_i_j_offset] * r_val;
-                    }
-                    dst_data[t_h_i_offset] = result;
-                }
-            }
-        }
-    #else
-        for (int64_t t = 0; t < T; t++) {
-            int64_t t_offset = t * t_stride;
-            int64_t state_offset = head_size * C * (t / (T / n_seqs));
-            float * state_cur = state + state_offset;
-            float * state_prev = t % (T / n_seqs) ? state_cur : (float *) dst->src[6]->data + state_offset;
-
-            for (int64_t h = h_start; h < h_end; h++) {
-                int64_t h_offset = h * h_stride;
-                int64_t t_h_offset = t_offset + h_offset;
-                int64_t h_2d_offset = h * h_stride_2d;
-
-                for (int64_t ii = 0; ii < head_size; ii++) {
-                    int64_t t_h_i_offset = t_h_offset + ii;
-                    int64_t h_2d_i_offset = h_2d_offset + ii * h_stride;
-
-                    GGML_F32_VEC v_vec = GGML_F32_VEC_SET1(v[t_h_i_offset]);
-
-                    float sa = 0;
-                    {
-                        GGML_F32_VEC sum[GGML_F32_ARR] = { GGML_F32_VEC_ZERO };
-                        GGML_F32_VEC ax[GGML_F32_ARR];
-                        GGML_F32_VEC ay[GGML_F32_ARR];
-                        for (int64_t j = 0; j < head_size; j += GGML_F32_STEP) {
-                            for (int64_t kk = 0; kk < GGML_F32_ARR; kk++) {
-                                ax[kk] = GGML_F32_VEC_LOAD(&a[t_h_offset + j + kk * GGML_F32_EPR]);
-                                ay[kk] = GGML_F32_VEC_LOAD(&state_prev[h_2d_i_offset + j + kk * GGML_F32_EPR]);
-                                sum[kk] = GGML_F32_VEC_FMA(sum[kk], ax[kk], ay[kk]);
-                            }
+        for (int64_t h = h_start; h < h_end; h++) {
+            int64_t h_offset = h * h_stride;
+            int64_t t_h_offset = t_offset + h_offset;
+            int64_t h_2d_offset = h * h_stride_2d;
+
+            for (int64_t ii = 0; ii < head_size; ii++) {
+                int64_t t_h_i_offset = t_h_offset + ii;
+                int64_t h_2d_i_offset = h_2d_offset + ii * h_stride;
+
+                GGML_F32_VEC v_vec = GGML_F32_VEC_SET1(v[t_h_i_offset]);
+
+                float sa = 0;
+                {
+                    GGML_F32_VEC sum[GGML_F32_ARR] = { GGML_F32_VEC_ZERO };
+                    GGML_F32_VEC ax[GGML_F32_ARR];
+                    GGML_F32_VEC ay[GGML_F32_ARR];
+                    for (int64_t j = 0; j < head_size; j += GGML_F32_STEP) {
+                        for (int64_t kk = 0; kk < GGML_F32_ARR; kk++) {
+                            ax[kk] = GGML_F32_VEC_LOAD(&a[t_h_offset + j + kk * GGML_F32_EPR]);
+                            ay[kk] = GGML_F32_VEC_LOAD(&state_prev[h_2d_i_offset + j + kk * GGML_F32_EPR]);
+                            sum[kk] = GGML_F32_VEC_FMA(sum[kk], ax[kk], ay[kk]);
                         }
-                        GGML_F32_VEC_REDUCE(sa, sum);
                     }
+                    GGML_F32_VEC_REDUCE(sa, sum);
+                }

-                    GGML_F32_VEC sa_vec = GGML_F32_VEC_SET1(sa);
+                GGML_F32_VEC sa_vec = GGML_F32_VEC_SET1(sa);

-                    int64_t j = 0;
-                    GGML_F32_VEC result_vec[GGML_F32_ARR] = { GGML_F32_VEC_ZERO };
-                    for (; j < head_size; j += GGML_F32_STEP) {
-                        for (int64_t kk = 0; kk < GGML_F32_ARR; kk++) {
-                            int64_t t_h_j_offset = t_h_offset + j + kk * GGML_F32_EPR;
-                            int64_t h_2d_i_j_offset = h_2d_i_offset + j + kk * GGML_F32_EPR;
+                int64_t j = 0;
+                GGML_F32_VEC result_vec[GGML_F32_ARR] = { GGML_F32_VEC_ZERO };
+                for (; j < head_size; j += GGML_F32_STEP) {
+                    for (int64_t kk = 0; kk < GGML_F32_ARR; kk++) {
+                        int64_t t_h_j_offset = t_h_offset + j + kk * GGML_F32_EPR;
+                        int64_t h_2d_i_j_offset = h_2d_i_offset + j + kk * GGML_F32_EPR;

-                            GGML_F32_VEC r_vec = GGML_F32_VEC_LOAD(&r[t_h_j_offset]);
-                            GGML_F32_VEC w_vec = GGML_F32_VEC_LOAD(&w[t_h_j_offset]);
-                            GGML_F32_VEC k_vec = GGML_F32_VEC_LOAD(&k[t_h_j_offset]);
-                            GGML_F32_VEC b_vec = GGML_F32_VEC_LOAD(&b[t_h_j_offset]);
+                        GGML_F32_VEC r_vec = GGML_F32_VEC_LOAD(&r[t_h_j_offset]);
+                        GGML_F32_VEC w_vec = GGML_F32_VEC_LOAD(&w[t_h_j_offset]);
+                        GGML_F32_VEC k_vec = GGML_F32_VEC_LOAD(&k[t_h_j_offset]);
+                        GGML_F32_VEC b_vec = GGML_F32_VEC_LOAD(&b[t_h_j_offset]);

-                            k_vec = GGML_F32_VEC_MUL(v_vec, k_vec);
+                        k_vec = GGML_F32_VEC_MUL(v_vec, k_vec);

-                            GGML_F32_VEC state_vec = GGML_F32_VEC_LOAD(&state_prev[h_2d_i_j_offset]);
-                            // kv + s * decay + sa * b
-                            state_vec = GGML_F32_VEC_FMA(k_vec, state_vec, w_vec);
-                            state_vec = GGML_F32_VEC_FMA(state_vec, sa_vec, b_vec);
-                            GGML_F32_VEC_STORE(&state_cur[h_2d_i_j_offset], state_vec);
+                        GGML_F32_VEC state_vec = GGML_F32_VEC_LOAD(&state_prev[h_2d_i_j_offset]);
+                        // kv + s * decay + sa * b
+                        state_vec = GGML_F32_VEC_FMA(k_vec, state_vec, w_vec);
+                        state_vec = GGML_F32_VEC_FMA(state_vec, sa_vec, b_vec);
+                        GGML_F32_VEC_STORE(&state_cur[h_2d_i_j_offset], state_vec);

-                            result_vec[kk] = GGML_F32_VEC_FMA(result_vec[kk], state_vec, r_vec);
-                        }
-                    }
-                    GGML_F32_VEC_REDUCE(dst_data[t_h_i_offset], result_vec);
-
-                    // There shouldn't be left-overs though.
-                    for (; j < head_size; j++) {
-                        int64_t t_h_j_offset = t_h_offset + j;
-                        int64_t h_2d_i_j_offset = h_2d_i_offset + j;
-
-                        float r_val = r[t_h_j_offset];
-                        float w_val = w[t_h_j_offset];
-                        float k_val = k[t_h_j_offset];
-                        float b_val = b[t_h_j_offset];
-                        float kv_val = v[t_h_i_offset] * k_val;
-
-                        float prev_state_val = state_prev[h_2d_i_j_offset];
-                        state_cur[h_2d_i_j_offset] = prev_state_val * w_val + kv_val + sa * b_val;
-                        dst_data[t_h_i_offset] += state_cur[h_2d_i_j_offset] * r_val;
+                        result_vec[kk] = GGML_F32_VEC_FMA(result_vec[kk], state_vec, r_vec);
                     }
                 }
+                GGML_F32_VEC_REDUCE(dst_data[t_h_i_offset], result_vec);
+
+                // There shouldn't be left-overs though.
+                for (; j < head_size; j++) {
+                    int64_t t_h_j_offset = t_h_offset + j;
+                    int64_t h_2d_i_j_offset = h_2d_i_offset + j;
+
+                    float r_val = r[t_h_j_offset];
+                    float w_val = w[t_h_j_offset];
+                    float k_val = k[t_h_j_offset];
+                    float b_val = b[t_h_j_offset];
+                    float kv_val = v[t_h_i_offset] * k_val;
+
+                    float prev_state_val = state_prev[h_2d_i_j_offset];
+                    state_cur[h_2d_i_j_offset] = prev_state_val * w_val + kv_val + sa * b_val;
+                    dst_data[t_h_i_offset] += state_cur[h_2d_i_j_offset] * r_val;
+                }
             }
         }
-    #endif
+    }
 #else
     for (int64_t t = 0; t < T; t++) {
         int64_t t_offset = t * t_stride;