@@ -8647,40 +8647,6 @@ static void ggml_compute_forward_ssm_scan_f32(
                 const float x_dt = x[ii] * dt_soft_plus;
                 float sumf = 0.0f;
 #if defined(GGML_SIMD)
-    #if defined(__ARM_FEATURE_SVE)
-                const int ggml_f32_epr = svcntw();
-                const int ggml_f32_step = 1 * ggml_f32_epr;
-
-                const int np = (nc & ~(ggml_f32_step - 1));
-
-                GGML_F32_VEC sum = GGML_F32_VEC_ZERO;
-
-                GGML_F32_VEC adA = GGML_F32_VEC_SET1(dA);
-                GGML_F32_VEC axdt = GGML_F32_VEC_SET1(x_dt);
-
-                for (int i = 0; i < np; i += ggml_f32_step) {
-                    // TODO: maybe unroll more?
-                    for (int j = 0; j < 1; j++) {
-                        GGML_F32_VEC t0 = GGML_F32_VEC_LOAD(s0 + i + j*ggml_f32_epr + ii*nc);
-                        GGML_F32_VEC t1 = GGML_F32_VEC_LOAD(B + i + j*ggml_f32_epr + g*nc);
-                        GGML_F32_VEC t2 = GGML_F32_VEC_LOAD(C + i + j*ggml_f32_epr + g*nc);
-
-                        t0 = GGML_F32_VEC_MUL(t0, adA);
-                        t1 = GGML_F32_VEC_MUL(t1, axdt);
-
-                        t0 = GGML_F32_VEC_ADD(t0, t1);
-
-                        sum = GGML_F32_VEC_FMA(sum, t0, t2);
-
-                        GGML_F32_VEC_STORE(s + i + j*ggml_f32_epr + ii*nc, t0);
-                    }
-                }
-
-                sumf = GGML_F32xt_REDUCE_ONE(sum);
-    #elif defined(__riscv_v_intrinsic)
-                // todo: RVV implementation
-                const int np = 0;
-    #else
                 const int np = (nc & ~(GGML_F32_STEP - 1));

                 GGML_F32_VEC sum[GGML_F32_ARR] = { GGML_F32_VEC_ZERO };
@@ -8711,7 +8677,6 @@ static void ggml_compute_forward_ssm_scan_f32(

                 // reduce sum0..sum3 to sum0
                 GGML_F32_VEC_REDUCE(sumf, sum);
-    #endif
 #else
                 const int np = 0;
 #endif
@@ -8741,30 +8706,6 @@ static void ggml_compute_forward_ssm_scan_f32(
             for (int i1 = 0; i1 < nr; ++i1) {
                 const int ii = i1 + h*nr;
                 const float x_dt = x[ii] * dt_soft_plus;
-    #if defined(__ARM_FEATURE_SVE)
-                svfloat32_t vx_dt = GGML_F32_VEC_SET1(x_dt);
-                svfloat32_t vdt_soft_plus = GGML_F32_VEC_SET1(dt_soft_plus);
-                svfloat32_t r1_vector = GGML_F32_VEC_ZERO;
-
-                // d_state
-                // TODO: what happens when (d_state % svcntw()) != 0?
-                for (int64_t k = 0; k < nc; k += svcntw()) {
-                    svfloat32_t vA = GGML_F32_VEC_LOAD(&A[h*nc + k]);
-                    svfloat32_t vB = GGML_F32_VEC_LOAD(&B[k + g*nc]);
-                    svfloat32_t vC = GGML_F32_VEC_LOAD(&C[k + g*nc]);
-                    svfloat32_t vs0 = GGML_F32_VEC_LOAD(&s0[ii*nc + k]);
-
-                    svfloat32_t t1 = GGML_F32_VEC_MUL(vdt_soft_plus, vA);
-                    t1 = exp_ps_sve(svptrue_b32(), t1);
-                    svfloat32_t t2 = GGML_F32_VEC_MUL(vx_dt, vB);
-
-                    vs0 = GGML_F32_VEC_FMA(t2, vs0, t1);
-                    r1_vector = GGML_F32_VEC_ADD(GGML_F32_VEC_MUL(vs0, vC), r1_vector);
-
-                    GGML_F32_VEC_STORE(&s[ii*nc + k], vs0);
-                }
-                y[ii] = GGML_F32xt_REDUCE_ONE(r1_vector);
-    #else
                 float sumf = 0.0f;
                 // NOTE: can't really use GGML_SIMD here because d_state is usually 16
                 // and also because expf is used within the loop.
@@ -8779,7 +8720,6 @@ static void ggml_compute_forward_ssm_scan_f32(
                     s[i] = state;
                 }
                 y[ii] = sumf;
-    #endif
             }
         }
     }
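
Note: the SVE branch removed above and the scalar fallback that remains compute the same per-row state update (state = prev_state * exp(dt * A) + B * x * dt, accumulated against C). A minimal scalar sketch of that recurrence follows; the helper name and parameter list are hypothetical and only mirror the kernel's variables, they are not part of the patch.

#include <math.h>

// nc corresponds to d_state; s/s0 are the updated/previous state rows for one (head, row).
// Returns the y[ii] contribution; matches the scalar loop kept in the diff above.
static float ssm_scan_row_sketch(float * s, const float * s0, const float * A,
                                 const float * B, const float * C,
                                 float dt_soft_plus, float x_dt, int nc) {
    float sumf = 0.0f;
    for (int i = 0; i < nc; ++i) {
        // state = prev_state * exp(dt * A) + B * (x * dt)
        const float state = (s0[i] * expf(dt_soft_plus * A[i])) + (B[i] * x_dt);
        sumf += state * C[i]; // accumulated into y[ii] by the kernel
        s[i] = state;         // write back the updated state
    }
    return sumf;
}
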
@@ -9632,126 +9572,83 @@ static void ggml_compute_forward_rwkv_wkv7_f32(
     int64_t h_stride_2d = head_size * head_size;

 #if defined(GGML_SIMD)
-    #if defined(__ARM_FEATURE_SVE) || defined(__riscv_v_intrinsic)
-        // scalar Route to scalar implementation //TODO: Write SVE code and RVV code
-        for (int64_t t = 0; t < T; t++) {
-            int64_t t_offset = t * t_stride;
-            int64_t state_offset = head_size * C * (t / (T / n_seqs));
-            float * state_cur = state + state_offset;
-            float * state_prev = t % (T / n_seqs) ? state_cur : (float *)dst->src[6]->data + state_offset;
-
-            for (int64_t h = h_start; h < h_end; h++) {
-                int64_t h_offset = h * h_stride;
-                int64_t t_h_offset = t_offset + h_offset;
-                int64_t h_2d_offset = h * h_stride_2d;
-
-                for (int64_t i = 0; i < head_size; i++) {
-                    int64_t t_h_i_offset = t_h_offset + i;
-                    int64_t h_2d_i_offset = h_2d_offset + i * h_stride;
-
-                    float v_val = v[t_h_i_offset];
-
-                    float sa = 0, result = 0;
-                    for (int64_t j = 0; j < head_size; j++) {
-                        sa += a[t_h_offset + j] * state_prev[h_2d_i_offset + j];
-                    }
+    for (int64_t t = 0; t < T; t++) {
+        int64_t t_offset = t * t_stride;
+        int64_t state_offset = head_size * C * (t / (T / n_seqs));
+        float * state_cur = state + state_offset;
+        float * state_prev = t % (T / n_seqs) ? state_cur : (float *)dst->src[6]->data + state_offset;

-                    for (int64_t j = 0; j < head_size; j++) {
-                        int64_t t_h_j_offset = t_h_offset + j;
-                        int64_t h_2d_i_j_offset = h_2d_i_offset + j;
-
-                        float r_val = r[t_h_j_offset];
-                        float w_val = w[t_h_j_offset];
-                        float k_val = k[t_h_j_offset];
-                        float b_val = b[t_h_j_offset];
-                        float kv_val = v_val * k_val;
-                        float prev_state_val = state_prev[h_2d_i_j_offset];
-                        state_cur[h_2d_i_j_offset] = prev_state_val * w_val + kv_val + sa * b_val;
-                        result += state_cur[h_2d_i_j_offset] * r_val;
-                    }
-                    dst_data[t_h_i_offset] = result;
-                }
-            }
-        }
-    #else
-        for (int64_t t = 0; t < T; t++) {
-            int64_t t_offset = t * t_stride;
-            int64_t state_offset = head_size * C * (t / (T / n_seqs));
-            float * state_cur = state + state_offset;
-            float * state_prev = t % (T / n_seqs) ? state_cur : (float *)dst->src[6]->data + state_offset;
-
-            for (int64_t h = h_start; h < h_end; h++) {
-                int64_t h_offset = h * h_stride;
-                int64_t t_h_offset = t_offset + h_offset;
-                int64_t h_2d_offset = h * h_stride_2d;
-
-                for (int64_t ii = 0; ii < head_size; ii++) {
-                    int64_t t_h_i_offset = t_h_offset + ii;
-                    int64_t h_2d_i_offset = h_2d_offset + ii * h_stride;
-
-                    GGML_F32_VEC v_vec = GGML_F32_VEC_SET1(v[t_h_i_offset]);
-
-                    float sa = 0;
-                    {
-                        GGML_F32_VEC sum[GGML_F32_ARR] = { GGML_F32_VEC_ZERO };
-                        GGML_F32_VEC ax[GGML_F32_ARR];
-                        GGML_F32_VEC ay[GGML_F32_ARR];
-                        for (int64_t j = 0; j < head_size; j += GGML_F32_STEP) {
-                            for (int64_t kk = 0; kk < GGML_F32_ARR; kk++) {
-                                ax[kk] = GGML_F32_VEC_LOAD(&a[t_h_offset + j + kk * GGML_F32_EPR]);
-                                ay[kk] = GGML_F32_VEC_LOAD(&state_prev[h_2d_i_offset + j + kk * GGML_F32_EPR]);
-                                sum[kk] = GGML_F32_VEC_FMA(sum[kk], ax[kk], ay[kk]);
-                            }
+        for (int64_t h = h_start; h < h_end; h++) {
+            int64_t h_offset = h * h_stride;
+            int64_t t_h_offset = t_offset + h_offset;
+            int64_t h_2d_offset = h * h_stride_2d;
+
+            for (int64_t ii = 0; ii < head_size; ii++) {
+                int64_t t_h_i_offset = t_h_offset + ii;
+                int64_t h_2d_i_offset = h_2d_offset + ii * h_stride;
+
+                GGML_F32_VEC v_vec = GGML_F32_VEC_SET1(v[t_h_i_offset]);
+
+                float sa = 0;
+                {
+                    GGML_F32_VEC sum[GGML_F32_ARR] = { GGML_F32_VEC_ZERO };
+                    GGML_F32_VEC ax[GGML_F32_ARR];
+                    GGML_F32_VEC ay[GGML_F32_ARR];
+                    for (int64_t j = 0; j < head_size; j += GGML_F32_STEP) {
+                        for (int64_t kk = 0; kk < GGML_F32_ARR; kk++) {
+                            ax[kk] = GGML_F32_VEC_LOAD(&a[t_h_offset + j + kk * GGML_F32_EPR]);
+                            ay[kk] = GGML_F32_VEC_LOAD(&state_prev[h_2d_i_offset + j + kk * GGML_F32_EPR]);
+                            sum[kk] = GGML_F32_VEC_FMA(sum[kk], ax[kk], ay[kk]);
                         }
-                        GGML_F32_VEC_REDUCE(sa, sum);
                     }
+                    GGML_F32_VEC_REDUCE(sa, sum);
+                }

-                    GGML_F32_VEC sa_vec = GGML_F32_VEC_SET1(sa);
+                GGML_F32_VEC sa_vec = GGML_F32_VEC_SET1(sa);

-                    int64_t j = 0;
-                    GGML_F32_VEC result_vec[GGML_F32_ARR] = { GGML_F32_VEC_ZERO };
-                    for (; j < head_size; j += GGML_F32_STEP) {
-                        for (int64_t kk = 0; kk < GGML_F32_ARR; kk++) {
-                            int64_t t_h_j_offset = t_h_offset + j + kk * GGML_F32_EPR;
-                            int64_t h_2d_i_j_offset = h_2d_i_offset + j + kk * GGML_F32_EPR;
+                int64_t j = 0;
+                GGML_F32_VEC result_vec[GGML_F32_ARR] = { GGML_F32_VEC_ZERO };
+                for (; j < head_size; j += GGML_F32_STEP) {
+                    for (int64_t kk = 0; kk < GGML_F32_ARR; kk++) {
+                        int64_t t_h_j_offset = t_h_offset + j + kk * GGML_F32_EPR;
+                        int64_t h_2d_i_j_offset = h_2d_i_offset + j + kk * GGML_F32_EPR;

-                            GGML_F32_VEC r_vec = GGML_F32_VEC_LOAD(&r[t_h_j_offset]);
-                            GGML_F32_VEC w_vec = GGML_F32_VEC_LOAD(&w[t_h_j_offset]);
-                            GGML_F32_VEC k_vec = GGML_F32_VEC_LOAD(&k[t_h_j_offset]);
-                            GGML_F32_VEC b_vec = GGML_F32_VEC_LOAD(&b[t_h_j_offset]);
+                        GGML_F32_VEC r_vec = GGML_F32_VEC_LOAD(&r[t_h_j_offset]);
+                        GGML_F32_VEC w_vec = GGML_F32_VEC_LOAD(&w[t_h_j_offset]);
+                        GGML_F32_VEC k_vec = GGML_F32_VEC_LOAD(&k[t_h_j_offset]);
+                        GGML_F32_VEC b_vec = GGML_F32_VEC_LOAD(&b[t_h_j_offset]);

-                            k_vec = GGML_F32_VEC_MUL(v_vec, k_vec);
+                        k_vec = GGML_F32_VEC_MUL(v_vec, k_vec);

-                            GGML_F32_VEC state_vec = GGML_F32_VEC_LOAD(&state_prev[h_2d_i_j_offset]);
-                            // kv + s * decay + sa * b
-                            state_vec = GGML_F32_VEC_FMA(k_vec, state_vec, w_vec);
-                            state_vec = GGML_F32_VEC_FMA(state_vec, sa_vec, b_vec);
-                            GGML_F32_VEC_STORE(&state_cur[h_2d_i_j_offset], state_vec);
+                        GGML_F32_VEC state_vec = GGML_F32_VEC_LOAD(&state_prev[h_2d_i_j_offset]);
+                        // kv + s * decay + sa * b
+                        state_vec = GGML_F32_VEC_FMA(k_vec, state_vec, w_vec);
+                        state_vec = GGML_F32_VEC_FMA(state_vec, sa_vec, b_vec);
+                        GGML_F32_VEC_STORE(&state_cur[h_2d_i_j_offset], state_vec);

-                            result_vec[kk] = GGML_F32_VEC_FMA(result_vec[kk], state_vec, r_vec);
-                        }
-                    }
-                    GGML_F32_VEC_REDUCE(dst_data[t_h_i_offset], result_vec);
-
-                    // There shouldn't be left-overs though.
-                    for (; j < head_size; j++) {
-                        int64_t t_h_j_offset = t_h_offset + j;
-                        int64_t h_2d_i_j_offset = h_2d_i_offset + j;
-
-                        float r_val = r[t_h_j_offset];
-                        float w_val = w[t_h_j_offset];
-                        float k_val = k[t_h_j_offset];
-                        float b_val = b[t_h_j_offset];
-                        float kv_val = v[t_h_i_offset] * k_val;
-
-                        float prev_state_val = state_prev[h_2d_i_j_offset];
-                        state_cur[h_2d_i_j_offset] = prev_state_val * w_val + kv_val + sa * b_val;
-                        dst_data[t_h_i_offset] += state_cur[h_2d_i_j_offset] * r_val;
+                        result_vec[kk] = GGML_F32_VEC_FMA(result_vec[kk], state_vec, r_vec);
                     }
                 }
+                GGML_F32_VEC_REDUCE(dst_data[t_h_i_offset], result_vec);
+
+                // There shouldn't be left-overs though.
+                for (; j < head_size; j++) {
+                    int64_t t_h_j_offset = t_h_offset + j;
+                    int64_t h_2d_i_j_offset = h_2d_i_offset + j;
+
+                    float r_val = r[t_h_j_offset];
+                    float w_val = w[t_h_j_offset];
+                    float k_val = k[t_h_j_offset];
+                    float b_val = b[t_h_j_offset];
+                    float kv_val = v[t_h_i_offset] * k_val;
+
+                    float prev_state_val = state_prev[h_2d_i_j_offset];
+                    state_cur[h_2d_i_j_offset] = prev_state_val * w_val + kv_val + sa * b_val;
+                    dst_data[t_h_i_offset] += state_cur[h_2d_i_j_offset] * r_val;
+                }
             }
         }
-    #endif
+    }
 #else
     for (int64_t t = 0; t < T; t++) {
         int64_t t_offset = t * t_stride;
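
Note: both the removed scalar route and the generic GGML_F32_VEC path above implement the same WKV7 update per output element ("kv + s * decay + sa * b"). A minimal scalar sketch of that update follows; the helper name is hypothetical and the pointers are assumed to already point at the per-head/per-row slices the kernel indexes with its offsets.

#include <stdint.h>

// sa      = dot(a, state_prev)                      (the reduction done with GGML_F32_VEC_FMA above)
// s'[j]   = state_prev[j] * w[j] + v * k[j] + sa * b[j]
// returns   dot(s', r), which the kernel stores into dst_data[t_h_i_offset]
static float wkv7_row_sketch(float * state_cur, const float * state_prev,
                             const float * r, const float * w, const float * k,
                             const float * b, const float * a, float v_val,
                             int64_t head_size) {
    float sa = 0.0f;
    for (int64_t j = 0; j < head_size; ++j) {
        sa += a[j] * state_prev[j];
    }

    float result = 0.0f;
    for (int64_t j = 0; j < head_size; ++j) {
        const float kv_val = v_val * k[j];
        state_cur[j] = state_prev[j] * w[j] + kv_val + sa * b[j]; // kv + s * decay + sa * b
        result += state_cur[j] * r[j];
    }
    return result;
}
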