@@ -7633,39 +7633,83 @@ static void ggml_compute_forward_ssm_scan_f32(
76337633 const int ir1 = MIN (ir0 + dr, nr);
76347634 const int ir = ir1 - ir0;
76357635
7636- for (int i3 = 0 ; i3 < n_s; ++i3) {
7637- for (int i2 = 0 ; i2 < n_t ; ++i2) {
7638- const float * s0 = (const float *) ((const char *) src0->data + ir0*(src0->nb [1 ]) + i3*(src0->nb [2 ])); // {d_state, d_inner, n_s}
7639- const float * x = (const float *) ((const char *) src1->data + ir0*(src1->nb [0 ]) + i2*(src1->nb [1 ]) + i3*(src1->nb [2 ])); // {d_inner, n_t, n_s}
7640- const float * dt = (const float *) ((const char *) src2->data + ir0*(src2->nb [0 ]) + i2*(src2->nb [1 ]) + i3*(src2->nb [2 ])); // {d_inner, n_t, n_s}
7641- const float * A = (const float *) ((const char *) src3->data + ir0*(src3->nb [1 ])); // {d_state, d_inner}
7642- const float * B = (const float *) ((const char *) src4->data + i2*(src4->nb [1 ]) + i3*(src4->nb [2 ])); // {d_state, n_t, n_s}
7643- const float * C = (const float *) ((const char *) src5->data + i2*(src5->nb [1 ]) + i3*(src5->nb [2 ])); // {d_state, n_t, n_s}
7644- float * y = ( float *) (( char *) dst->data + ir0*(src1->nb [0 ]) + i2*(src1->nb [1 ]) + i3*(src1->nb [2 ])); // {d_inner, n_t, n_s}
7645- float * s = ( float *) (( char *) dst->data + ir0*(src0->nb [1 ]) + i3*(src0->nb [2 ]) + src1->nb [3 ]); // {d_state, d_inner, n_s}
7646-
7647- // use the output as the source for the next token-wise iterations
7648- if (i2 > 0 ) { s0 = s; }
7649-
7650- // d_inner
7651- for (int i1 = 0 ; i1 < ir; ++i1) {
7652- // ref: https://github.com/state-spaces/mamba/blob/34076d664838588a3c97727b263478ab9f621a07/mamba_ssm/ops/triton/selective_state_update.py#L78
7653- float dt_soft_plus = dt[i1] <= 20 .0f ? log1pf (expf (dt[i1])) : dt[i1];
7654- float x_dt = x[i1] * dt_soft_plus;
7655- float sumf = 0 .0f ;
7656- // d_state
7657- for (int i0 = 0 ; i0 < nc; ++i0) {
7658- int i = i0 + i1*nc;
7659- // state = prev_state * dA + dB * x
7660- float state = (s0[i] * expf (dt_soft_plus * A[i])) + (B[i0] * x_dt);
7661- // y = rowwise_dotprod(state, C)
7662- sumf += state * C[i0];
7663- s[i] = state;
7636+ #ifdef __ARM_FEATURE_SVE
7637+ for (int i3 = 0 ; i3 < n_s; ++i3) {
7638+ for (int i2 = 0 ; i2 < n_t ; ++i2) {
7639+ const float * s0 = (const float *) ((const char *) src0->data + ir0*(src0->nb [1 ]) + i3*(src0->nb [2 ])); // {d_state, d_inner, n_s}
7640+ const float * x = (const float *) ((const char *) src1->data + ir0*(src1->nb [0 ]) + i2*(src1->nb [1 ]) + i3*(src1->nb [2 ])); // {d_inner, n_t, n_s}
7641+ const float * dt = (const float *) ((const char *) src2->data + ir0*(src2->nb [0 ]) + i2*(src2->nb [1 ]) + i3*(src2->nb [2 ])); // {d_inner, n_t, n_s}
7642+ const float * A = (const float *) ((const char *) src3->data + ir0*(src3->nb [1 ])); // {d_state, d_inner}
7643+ const float * B = (const float *) ((const char *) src4->data + i2*(src4->nb [1 ]) + i3*(src4->nb [2 ])); // {d_state, n_t, n_s}
7644+ const float * C = (const float *) ((const char *) src5->data + i2*(src5->nb [1 ]) + i3*(src5->nb [2 ])); // {d_state, n_t, n_s}
7645+ float * y = ( float *) (( char *) dst->data + ir0*(src1->nb [0 ]) + i2*(src1->nb [1 ]) + i3*(src1->nb [2 ])); // {d_inner, n_t, n_s}
7646+ float * s = ( float *) (( char *) dst->data + ir0*(src0->nb [1 ]) + i3*(src0->nb [2 ]) + src1->nb [3 ]); // {d_state, d_inner, n_s}
7647+
7648+ // use the output as the source for the next token-wise iterations
7649+ if (i2 > 0 ) { s0 = s; }
7650+
7651+ // d_inner
7652+ for (int i1 = 0 ; i1 < ir; ++i1) {
7653+ float dt_soft_plus = dt[i1] <= 20 .0f ? log1pf (expf (dt[i1])) : dt[i1];
7654+ float x_dt = x[i1] * dt_soft_plus;
7655+ svfloat32_t vx_dt = GGML_F32_VEC_SET1 (x_dt);
7656+ svfloat32_t vdt_soft_plus = GGML_F32_VEC_SET1 (dt_soft_plus);
7657+ svfloat32_t r1_vector = GGML_F32_VEC_ZERO;
7658+
7659+ for (int64_t k = 0 ; k < nc; k += svcntw ()) {
7660+ svfloat32_t vA = GGML_F32_VEC_LOAD (&A[i1*nc + k]);
7661+ svfloat32_t vB = GGML_F32_VEC_LOAD (&B[k]);
7662+ svfloat32_t vC = GGML_F32_VEC_LOAD (&C[k]);
7663+ svfloat32_t vs0 = GGML_F32_VEC_LOAD (&s0[i1*nc + k]);
7664+
7665+ svfloat32_t t1 = GGML_F32_VEC_MUL (vdt_soft_plus, vA);
7666+ t1 = exp_ps_sve (svptrue_b32 (), t1);
7667+ svfloat32_t t2 = GGML_F32_VEC_MUL (vx_dt, vB);
7668+
7669+ vs0 = GGML_F32_VEC_FMA (vs0, t1, t2);
7670+ r1_vector = GGML_F32_VEC_ADD (GGML_F32_VEC_MUL (vs0, vC), r1_vector);
7671+
7672+ GGML_F32_VEC_STORE (&s[i1*nc + k], vs0);
7673+ }
7674+ y[i1] = GGML_F32xt_REDUCE_ONE (r1_vector);
76647675 }
7665- y[i1] = sumf;
76667676 }
76677677 }
7668- }
7678+ #else
7679+ for (int i3 = 0 ; i3 < n_s; ++i3) {
7680+ for (int i2 = 0 ; i2 < n_t ; ++i2) {
7681+ const float * s0 = (const float *) ((const char *) src0->data + ir0*(src0->nb [1 ]) + i3*(src0->nb [2 ])); // {d_state, d_inner, n_s}
7682+ const float * x = (const float *) ((const char *) src1->data + ir0*(src1->nb [0 ]) + i2*(src1->nb [1 ]) + i3*(src1->nb [2 ])); // {d_inner, n_t, n_s}
7683+ const float * dt = (const float *) ((const char *) src2->data + ir0*(src2->nb [0 ]) + i2*(src2->nb [1 ]) + i3*(src2->nb [2 ])); // {d_inner, n_t, n_s}
7684+ const float * A = (const float *) ((const char *) src3->data + ir0*(src3->nb [1 ])); // {d_state, d_inner}
7685+ const float * B = (const float *) ((const char *) src4->data + i2*(src4->nb [1 ]) + i3*(src4->nb [2 ])); // {d_state, n_t, n_s}
7686+ const float * C = (const float *) ((const char *) src5->data + i2*(src5->nb [1 ]) + i3*(src5->nb [2 ])); // {d_state, n_t, n_s}
7687+ float * y = ( float *) (( char *) dst->data + ir0*(src1->nb [0 ]) + i2*(src1->nb [1 ]) + i3*(src1->nb [2 ])); // {d_inner, n_t, n_s}
7688+ float * s = ( float *) (( char *) dst->data + ir0*(src0->nb [1 ]) + i3*(src0->nb [2 ]) + src1->nb [3 ]); // {d_state, d_inner, n_s}
7689+
7690+ // use the output as the source for the next token-wise iterations
7691+ if (i2 > 0 ) { s0 = s; }
7692+
7693+ // d_inner
7694+ for (int i1 = 0 ; i1 < ir; ++i1) {
7695+ // ref: https://github.com/state-spaces/mamba/blob/34076d664838588a3c97727b263478ab9f621a07/mamba_ssm/ops/triton/selective_state_update.py#L78
7696+ float dt_soft_plus = dt[i1] <= 20 .0f ? log1pf (expf (dt[i1])) : dt[i1];
7697+ float x_dt = x[i1] * dt_soft_plus;
7698+ float sumf = 0 .0f ;
7699+ // d_state
7700+ for (int i0 = 0 ; i0 < nc; ++i0) {
7701+ int i = i0 + i1*nc;
7702+ // state = prev_state * dA + dB * x
7703+ float state = (s0[i] * expf (dt_soft_plus * A[i])) + (B[i0] * x_dt);
7704+ // y = rowwise_dotprod(state, C)
7705+ sumf += state * C[i0];
7706+ s[i] = state;
7707+ }
7708+ y[i1] = sumf;
7709+ }
7710+ }
7711+ }
7712+ #endif
76697713}
76707714
76717715void ggml_compute_forward_ssm_scan (
0 commit comments