@@ -7530,39 +7530,85 @@ static void ggml_compute_forward_ssm_scan_f32(
     const int ir1 = MIN(ir0 + dr, nr);
     const int ir  = ir1 - ir0;

-    for (int i3 = 0; i3 < n_s; ++i3) {
-        for (int i2 = 0; i2 < n_t; ++i2) {
-            const float * s0 = (const float *) ((const char *) src0->data + ir0*(src0->nb[1]) + i3*(src0->nb[2])); // {d_state, d_inner, n_s}
-            const float * x  = (const float *) ((const char *) src1->data + ir0*(src1->nb[0]) + i2*(src1->nb[1]) + i3*(src1->nb[2])); // {d_inner, n_t, n_s}
-            const float * dt = (const float *) ((const char *) src2->data + ir0*(src2->nb[0]) + i2*(src2->nb[1]) + i3*(src2->nb[2])); // {d_inner, n_t, n_s}
-            const float * A  = (const float *) ((const char *) src3->data + ir0*(src3->nb[1])); // {d_state, d_inner}
-            const float * B  = (const float *) ((const char *) src4->data + i2*(src4->nb[1]) + i3*(src4->nb[2])); // {d_state, n_t, n_s}
-            const float * C  = (const float *) ((const char *) src5->data + i2*(src5->nb[1]) + i3*(src5->nb[2])); // {d_state, n_t, n_s}
-            float * y = (float *) ((char *) dst->data + ir0*(src1->nb[0]) + i2*(src1->nb[1]) + i3*(src1->nb[2])); // {d_inner, n_t, n_s}
-            float * s = (float *) ((char *) dst->data + ir0*(src0->nb[1]) + i3*(src0->nb[2]) + src1->nb[3]);      // {d_state, d_inner, n_s}
-
-            // use the output as the source for the next token-wise iterations
-            if (i2 > 0) { s0 = s; }
-
-            // d_inner
-            for (int i1 = 0; i1 < ir; ++i1) {
-                // ref: https://github.com/state-spaces/mamba/blob/34076d664838588a3c97727b263478ab9f621a07/mamba_ssm/ops/triton/selective_state_update.py#L78
-                float dt_soft_plus = dt[i1] <= 20.0f ? log1pf(expf(dt[i1])) : dt[i1];
-                float x_dt = x[i1] * dt_soft_plus;
-                float sumf = 0.0f;
-                // d_state
-                for (int i0 = 0; i0 < nc; ++i0) {
-                    int i = i0 + i1*nc;
-                    // state = prev_state * dA + dB * x
-                    float state = (s0[i] * expf(dt_soft_plus * A[i])) + (B[i0] * x_dt);
-                    // y = rowwise_dotprod(state, C)
-                    sumf += state * C[i0];
-                    s[i] = state;
+#ifdef __ARM_FEATURE_SVE
+    for (int i3 = 0; i3 < n_s; ++i3) {
+        for (int i2 = 0; i2 < n_t; ++i2) {
+            const float * s0 = (const float *) ((const char *) src0->data + ir0*(src0->nb[1]) + i3*(src0->nb[2])); // {d_state, d_inner, n_s}
+            const float * x  = (const float *) ((const char *) src1->data + ir0*(src1->nb[0]) + i2*(src1->nb[1]) + i3*(src1->nb[2])); // {d_inner, n_t, n_s}
+            const float * dt = (const float *) ((const char *) src2->data + ir0*(src2->nb[0]) + i2*(src2->nb[1]) + i3*(src2->nb[2])); // {d_inner, n_t, n_s}
+            const float * A  = (const float *) ((const char *) src3->data + ir0*(src3->nb[1])); // {d_state, d_inner}
+            const float * B  = (const float *) ((const char *) src4->data + i2*(src4->nb[1]) + i3*(src4->nb[2])); // {d_state, n_t, n_s}
+            const float * C  = (const float *) ((const char *) src5->data + i2*(src5->nb[1]) + i3*(src5->nb[2])); // {d_state, n_t, n_s}
+            float * y = (float *) ((char *) dst->data + ir0*(src1->nb[0]) + i2*(src1->nb[1]) + i3*(src1->nb[2])); // {d_inner, n_t, n_s}
+            float * s = (float *) ((char *) dst->data + ir0*(src0->nb[1]) + i3*(src0->nb[2]) + src1->nb[3]);      // {d_state, d_inner, n_s}
+
+            // use the output as the source for the next token-wise iterations
+            if (i2 > 0) { s0 = s; }
+
+            // d_inner
+            for (int i1 = 0; i1 < ir; ++i1) {
+                // ref: https://github.com/state-spaces/mamba/blob/34076d664838588a3c97727b263478ab9f621a07/mamba_ssm/ops/triton/selective_state_update.py#L78
+                float dt_soft_plus = dt[i1] <= 20.0f ? log1pf(expf(dt[i1])) : dt[i1];
+                float x_dt = x[i1] * dt_soft_plus;
+                svfloat32_t vx_dt         = GGML_F32_VEC_SET1(x_dt);
+                svfloat32_t vdt_soft_plus = GGML_F32_VEC_SET1(dt_soft_plus);
+                svfloat32_t r1_vector     = GGML_F32_VEC_ZERO;
+
+                // d_state, processed svcntw() f32 lanes per iteration
+                for (int64_t k = 0; k < nc; k += svcntw()) {
+
+                    svfloat32_t vA  = GGML_F32_VEC_LOAD(&A[i1*nc + k]);
+                    svfloat32_t vB  = GGML_F32_VEC_LOAD(&B[k]);
+                    svfloat32_t vC  = GGML_F32_VEC_LOAD(&C[k]);
+                    svfloat32_t vs0 = GGML_F32_VEC_LOAD(&s0[i1*nc + k]);
+
+                    // dA = exp(dt_soft_plus * A), dBx = B * x_dt
+                    svfloat32_t t1 = GGML_F32_VEC_MUL(vdt_soft_plus, vA);
+                    t1 = exp_ps_sve(svptrue_b32(), t1);
+                    svfloat32_t t2 = GGML_F32_VEC_MUL(vx_dt, vB);
+
+                    // state = prev_state * dA + dBx, then accumulate state*C into y
+                    vs0 = GGML_F32_VEC_FMA(vs0, t1, t2);
+                    r1_vector = GGML_F32_VEC_ADD(GGML_F32_VEC_MUL(vs0, vC), r1_vector);
+
+                    GGML_F32_VEC_STORE(&s[i1*nc + k], vs0);
+                }
+                y[i1] = GGML_F32xt_REDUCE_ONE(r1_vector);
             }
-                y[i1] = sumf;
         }
-        }
     }
+#else
+    for (int i3 = 0; i3 < n_s; ++i3) {
+        for (int i2 = 0; i2 < n_t; ++i2) {
+            const float * s0 = (const float *) ((const char *) src0->data + ir0*(src0->nb[1]) + i3*(src0->nb[2])); // {d_state, d_inner, n_s}
+            const float * x  = (const float *) ((const char *) src1->data + ir0*(src1->nb[0]) + i2*(src1->nb[1]) + i3*(src1->nb[2])); // {d_inner, n_t, n_s}
+            const float * dt = (const float *) ((const char *) src2->data + ir0*(src2->nb[0]) + i2*(src2->nb[1]) + i3*(src2->nb[2])); // {d_inner, n_t, n_s}
+            const float * A  = (const float *) ((const char *) src3->data + ir0*(src3->nb[1])); // {d_state, d_inner}
+            const float * B  = (const float *) ((const char *) src4->data + i2*(src4->nb[1]) + i3*(src4->nb[2])); // {d_state, n_t, n_s}
+            const float * C  = (const float *) ((const char *) src5->data + i2*(src5->nb[1]) + i3*(src5->nb[2])); // {d_state, n_t, n_s}
+            float * y = (float *) ((char *) dst->data + ir0*(src1->nb[0]) + i2*(src1->nb[1]) + i3*(src1->nb[2])); // {d_inner, n_t, n_s}
+            float * s = (float *) ((char *) dst->data + ir0*(src0->nb[1]) + i3*(src0->nb[2]) + src1->nb[3]);      // {d_state, d_inner, n_s}
+
+            // use the output as the source for the next token-wise iterations
+            if (i2 > 0) { s0 = s; }
+
+            // d_inner
+            for (int i1 = 0; i1 < ir; ++i1) {
+                // ref: https://github.com/state-spaces/mamba/blob/34076d664838588a3c97727b263478ab9f621a07/mamba_ssm/ops/triton/selective_state_update.py#L78
+                float dt_soft_plus = dt[i1] <= 20.0f ? log1pf(expf(dt[i1])) : dt[i1];
+                float x_dt = x[i1] * dt_soft_plus;
+                float sumf = 0.0f;
+                // d_state
+                for (int i0 = 0; i0 < nc; ++i0) {
+                    int i = i0 + i1*nc;
+                    // state = prev_state * dA + dB * x
+                    float state = (s0[i] * expf(dt_soft_plus * A[i])) + (B[i0] * x_dt);
+                    // y = rowwise_dotprod(state, C)
+                    sumf += state * C[i0];
+                    s[i] = state;
+                }
+                y[i1] = sumf;
+            }
+        }
+    }
+#endif
 }

 void ggml_compute_forward_ssm_scan(
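For reference, both branches of this hunk compute the same selective-scan recurrence; below is a minimal scalar model of one inner channel (illustrative only, not part of the patch — the name ssm_scan_channel is made up):

#include <math.h>

// softplus(dt) gates x; the state decays by exp(dt*A) and accumulates B*x;
// the channel's output is the dot product of the new state with C.
static float ssm_scan_channel(const float *s0, const float *A, const float *B,
                              const float *C, float *s, float x, float dt, int nc) {
    float dt_sp = dt <= 20.0f ? log1pf(expf(dt)) : dt; // softplus, large-input shortcut
    float x_dt  = x * dt_sp;
    float y     = 0.0f;
    for (int i0 = 0; i0 < nc; ++i0) {
        float state = s0[i0] * expf(dt_sp * A[i0]) + B[i0] * x_dt;
        y     += state * C[i0];
        s[i0]  = state;
    }
    return y;
}

The SVE branch performs exactly this per channel, but walks the nc state elements svcntw() lanes at a time, replaces expf with the vector approximation exp_ps_sve, and accumulates the dot product in r1_vector, collapsing it once per channel with GGML_F32xt_REDUCE_ONE.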
@@ -7963,6 +8009,14 @@ static void ggml_compute_forward_rwkv_wkv6_f32(
         #define GGML_F32X_MUL GGML_F32x16_MUL
         #define GGML_F32X_FMA GGML_F32x16_FMA
         #define WKV_VECTOR_SIZE 16
+    #elif defined(__ARM_FEATURE_SVE) && defined(__aarch64__)
+        #define GGML_F32X GGML_F32xt
+        #define GGML_F32X_SET1 GGML_F32xt_SET1
+        #define GGML_F32X_LOAD GGML_F32xt_LOAD
+        #define GGML_F32X_STORE GGML_F32xt_STORE
+        #define GGML_F32X_MUL GGML_F32xt_MUL
+        #define GGML_F32X_FMA GGML_F32xt_FMA
+        #define WKV_VECTOR_SIZE 8
     #elif defined(__ARM_NEON) && defined(__aarch64__)
         #define GGML_F32X GGML_F32x4
         #define GGML_F32X_SET1 GGML_F32x4_SET1
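The GGML_F32X_* names form one generic vocabulary that each ISA block aliases to its own types, so the kernel body is written once against the generic names. A trimmed sketch of the pattern (simplified names, not the exact ggml definitions):

#if defined(__ARM_FEATURE_SVE) && defined(__aarch64__)
    #include <arm_sve.h>
    typedef svfloat32_t f32x;                                   // vector-length agnostic
    #define F32X_MUL(a, b) svmul_f32_x(svptrue_b32(), (a), (b))
#elif defined(__ARM_NEON) && defined(__aarch64__)
    #include <arm_neon.h>
    typedef float32x4_t f32x;                                   // fixed 4 lanes
    #define F32X_MUL(a, b) vmulq_f32((a), (b))
#endif

One wrinkle: for the fixed-width ISAs the element count is a compile-time constant, but the SVE lane count must be read with svcntw() at run time, which is why the hunks below shadow WKV_VECTOR_SIZE and GLA_VECTOR_SIZE with runtime variables.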
@@ -7973,8 +8027,14 @@ static void ggml_compute_forward_rwkv_wkv6_f32(
         #define WKV_VECTOR_SIZE 4
     #endif

+    int wkv_vector_size;
     #ifdef WKV_VECTOR_SIZE
-        const int64_t vec_count = head_size / WKV_VECTOR_SIZE;
+        #if defined(__ARM_FEATURE_SVE)
+            wkv_vector_size = svcntw();
+        #else
+            wkv_vector_size = WKV_VECTOR_SIZE;
+        #endif
+        const int64_t vec_count = head_size / wkv_vector_size;

         for (int64_t t = 0; t < T; t++) {
             size_t t_offset = t * t_stride;
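Note on the runtime width: unlike AVX512/NEON, the SVE vector length is an implementation choice (a multiple of 128 bits, up to 2048), so the compile-time WKV_VECTOR_SIZE of 8 only gates the vectorized path; the actual lane count is taken from svcntw() when the kernel runs. A stand-alone probe (assumes a toolchain with SVE enabled, e.g. -march=armv8-a+sve):

#include <arm_sve.h>
#include <stdio.h>

int main(void) {
    // svcntw() = number of 32-bit elements per SVE vector:
    // 4 on 128-bit hardware, 8 on 256-bit, 16 on 512-bit, ...
    printf("f32 lanes per SVE vector: %llu\n", (unsigned long long) svcntw());
    return 0;
}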
@@ -8004,7 +8064,7 @@ static void ggml_compute_forward_rwkv_wkv6_f32(
                     GGML_F32X time_decay_vec = GGML_F32X_SET1(time_decay_val);

                     for (int64_t j = 0; j < vec_count; j++) {
-                        size_t base_j = j * WKV_VECTOR_SIZE;
+                        size_t base_j = j * wkv_vector_size;
                         size_t t_h_j_offset = t_h_offset + base_j;
                         size_t h_2d_i_j_offset = h_2d_i_offset + base_j;

@@ -8029,7 +8089,7 @@ static void ggml_compute_forward_rwkv_wkv6_f32(
                     }

                     // Handle remaining elements, this will not be used.
-                    for (int64_t j = vec_count * WKV_VECTOR_SIZE; j < head_size; j++) {
+                    for (int64_t j = vec_count * wkv_vector_size; j < head_size; j++) {
                         size_t t_h_j_offset = t_h_offset + j;
                         size_t h_2d_i_j_offset = h_2d_i_offset + j;
                         float v_val = v[t_h_j_offset];
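The vector loop and this scalar tail partition head_size between them; the tail only runs when head_size is not a multiple of the lane count, which the comment above notes does not occur for the models this kernel serves. A sketch of the split (values assumed for illustration):

// head_size = 64, lanes = 8  (SVE-256): 8 full vectors, empty tail.
// head_size = 64, lanes = 16 (SVE-512): 4 full vectors, empty tail.
int64_t vec_count = head_size / wkv_vector_size; // full vectors
int64_t tail      = head_size % wkv_vector_size; // elements left for the scalar loop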
@@ -8165,6 +8225,14 @@ static void ggml_compute_forward_gla_f32(
         #define GGML_F32X_MUL GGML_F32x16_MUL
         #define GGML_F32X_FMA GGML_F32x16_FMA
         #define GLA_VECTOR_SIZE 16
+    #elif defined(__ARM_FEATURE_SVE) && defined(__aarch64__)
+        #define GGML_F32X GGML_F32xt
+        #define GGML_F32X_SET1 GGML_F32xt_SET1
+        #define GGML_F32X_LOAD GGML_F32xt_LOAD
+        #define GGML_F32X_STORE GGML_F32xt_STORE
+        #define GGML_F32X_MUL GGML_F32xt_MUL
+        #define GGML_F32X_FMA GGML_F32xt_FMA
+        #define GLA_VECTOR_SIZE 8
     #elif defined(__ARM_NEON) && defined(__aarch64__)
         #define GGML_F32X GGML_F32x4
         #define GGML_F32X_SET1 GGML_F32x4_SET1
@@ -8174,9 +8242,14 @@ static void ggml_compute_forward_gla_f32(
         #define GGML_F32X_FMA GGML_F32x4_FMA
         #define GLA_VECTOR_SIZE 4
     #endif
-
+    int gla_vector_size;
     #ifdef GLA_VECTOR_SIZE
-        const int64_t vec_count = head_size / GLA_VECTOR_SIZE;
+        #if defined(__ARM_FEATURE_SVE)
+            gla_vector_size = svcntw();
+        #else
+            gla_vector_size = GLA_VECTOR_SIZE;
+        #endif
+        const int64_t vec_count = head_size / gla_vector_size;

         for (int64_t t = 0; t < T; t++) {
             size_t t_offset = t * t_stride;
@@ -8203,7 +8276,7 @@ static void ggml_compute_forward_gla_f32(
                     GGML_F32X g_vec = GGML_F32X_SET1(g_val);

                     for (int64_t j = 0; j < vec_count; j++) {
-                        size_t base_j = j * GLA_VECTOR_SIZE;
+                        size_t base_j = j * gla_vector_size;
                         size_t t_h_j_offset = t_h_offset + base_j;
                         size_t h_2d_i_j_offset = h_2d_i_offset + base_j;

@@ -8227,7 +8300,7 @@ static void ggml_compute_forward_gla_f32(
                     }

                     // Handle remaining elements, this will not be used.
-                    for (int64_t j = vec_count * GLA_VECTOR_SIZE; j < head_size; j++) {
+                    for (int64_t j = vec_count * gla_vector_size; j < head_size; j++) {
                         size_t t_h_j_offset = t_h_offset + j;
                         size_t h_2d_i_j_offset = h_2d_i_offset + j;
                         float v_val = v[t_h_j_offset];