
Commit 8581c89

F32-Mamba-SVE
1 parent 3e0be1c commit 8581c89

4 files changed: +478 -104 lines changed

ggml/src/ggml-cpu/ops.cpp

Lines changed: 110 additions & 37 deletions
@@ -7530,39 +7530,85 @@ static void ggml_compute_forward_ssm_scan_f32(
     const int ir1 = MIN(ir0 + dr, nr);
     const int ir  = ir1 - ir0;
 
-    for (int i3 = 0; i3 < n_s; ++i3) {
-        for (int i2 = 0; i2 < n_t; ++i2) {
-            const float * s0 = (const float *) ((const char *) src0->data + ir0*(src0->nb[1]) + i3*(src0->nb[2])); // {d_state, d_inner, n_s}
-            const float * x  = (const float *) ((const char *) src1->data + ir0*(src1->nb[0]) + i2*(src1->nb[1]) + i3*(src1->nb[2])); // {d_inner, n_t, n_s}
-            const float * dt = (const float *) ((const char *) src2->data + ir0*(src2->nb[0]) + i2*(src2->nb[1]) + i3*(src2->nb[2])); // {d_inner, n_t, n_s}
-            const float * A  = (const float *) ((const char *) src3->data + ir0*(src3->nb[1])); // {d_state, d_inner}
-            const float * B  = (const float *) ((const char *) src4->data + i2*(src4->nb[1]) + i3*(src4->nb[2])); // {d_state, n_t, n_s}
-            const float * C  = (const float *) ((const char *) src5->data + i2*(src5->nb[1]) + i3*(src5->nb[2])); // {d_state, n_t, n_s}
-            float * y = ( float *) (( char *) dst->data + ir0*(src1->nb[0]) + i2*(src1->nb[1]) + i3*(src1->nb[2])); // {d_inner, n_t, n_s}
-            float * s = ( float *) (( char *) dst->data + ir0*(src0->nb[1]) + i3*(src0->nb[2]) + src1->nb[3]); // {d_state, d_inner, n_s}
-
-            // use the output as the source for the next token-wise iterations
-            if (i2 > 0) { s0 = s; }
-
-            // d_inner
-            for (int i1 = 0; i1 < ir; ++i1) {
-                // ref: https://github.com/state-spaces/mamba/blob/34076d664838588a3c97727b263478ab9f621a07/mamba_ssm/ops/triton/selective_state_update.py#L78
-                float dt_soft_plus = dt[i1] <= 20.0f ? log1pf(expf(dt[i1])) : dt[i1];
-                float x_dt = x[i1] * dt_soft_plus;
-                float sumf = 0.0f;
-                // d_state
-                for (int i0 = 0; i0 < nc; ++i0) {
-                    int i = i0 + i1*nc;
-                    // state = prev_state * dA + dB * x
-                    float state = (s0[i] * expf(dt_soft_plus * A[i])) + (B[i0] * x_dt);
-                    // y = rowwise_dotprod(state, C)
-                    sumf += state * C[i0];
-                    s[i] = state;
+#ifdef __ARM_FEATURE_SVE
+    for (int i3 = 0; i3 < n_s; ++i3) {
+        for (int i2 = 0; i2 < n_t; ++i2) {
+            const float * s0 = (const float *) ((const char *) src0->data + ir0*(src0->nb[1]) + i3*(src0->nb[2])); // {d_state, d_inner, n_s}
+            const float * x  = (const float *) ((const char *) src1->data + ir0*(src1->nb[0]) + i2*(src1->nb[1]) + i3*(src1->nb[2])); // {d_inner, n_t, n_s}
+            const float * dt = (const float *) ((const char *) src2->data + ir0*(src2->nb[0]) + i2*(src2->nb[1]) + i3*(src2->nb[2])); // {d_inner, n_t, n_s}
+            const float * A  = (const float *) ((const char *) src3->data + ir0*(src3->nb[1])); // {d_state, d_inner}
+            const float * B  = (const float *) ((const char *) src4->data + i2*(src4->nb[1]) + i3*(src4->nb[2])); // {d_state, n_t, n_s}
+            const float * C  = (const float *) ((const char *) src5->data + i2*(src5->nb[1]) + i3*(src5->nb[2])); // {d_state, n_t, n_s}
+            float * y = ( float *) (( char *) dst->data + ir0*(src1->nb[0]) + i2*(src1->nb[1]) + i3*(src1->nb[2])); // {d_inner, n_t, n_s}
+            float * s = ( float *) (( char *) dst->data + ir0*(src0->nb[1]) + i3*(src0->nb[2]) + src1->nb[3]); // {d_state, d_inner, n_s}
+
+            // use the output as the source for the next token-wise iterations
+            if (i2 > 0) { s0 = s; }
+
+            // d_inner
+            for (int i1 = 0; i1 < ir; ++i1) {
+
+                float dt_soft_plus = dt[i1] <= 20.0f ? log1pf(expf(dt[i1])) : dt[i1];
+                float x_dt = x[i1] * dt_soft_plus;
+                svfloat32_t vx_dt = GGML_F32_VEC_SET1(x_dt);
+                svfloat32_t vdt_soft_plus = GGML_F32_VEC_SET1(dt_soft_plus);
+                svfloat32_t r1_vector = GGML_F32_VEC_ZERO;
+
+                for (int64_t k = 0; k < nc; k += svcntw()) {
+
+                    svfloat32_t vA  = GGML_F32_VEC_LOAD(&A[i1*nc+k]);
+                    svfloat32_t vB  = GGML_F32_VEC_LOAD(&B[k]);
+                    svfloat32_t vC  = GGML_F32_VEC_LOAD(&C[k]);
+                    svfloat32_t vs0 = GGML_F32_VEC_LOAD(&s0[i1*nc+k]);
+
+                    svfloat32_t t1 = GGML_F32_VEC_MUL(vdt_soft_plus, vA);
+                    t1 = exp_ps_sve(svptrue_b32(), t1);
+                    svfloat32_t t2 = GGML_F32_VEC_MUL(vx_dt, vB);
+
+                    vs0 = GGML_F32_VEC_FMA(vs0, t1, t2);
+                    r1_vector = GGML_F32_VEC_ADD(GGML_F32_VEC_MUL(vs0, vC), r1_vector);
+
+                    GGML_F32_VEC_STORE(&s[i1*nc+k], vs0);
                 }
+                y[i1] = GGML_F32xt_REDUCE_ONE(r1_vector);
             }
-                y[i1] = sumf;
         }
-            }
     }
+#else
+    for (int i3 = 0; i3 < n_s; ++i3) {
+        for (int i2 = 0; i2 < n_t; ++i2) {
+            const float * s0 = (const float *) ((const char *) src0->data + ir0*(src0->nb[1]) + i3*(src0->nb[2])); // {d_state, d_inner, n_s}
+            const float * x  = (const float *) ((const char *) src1->data + ir0*(src1->nb[0]) + i2*(src1->nb[1]) + i3*(src1->nb[2])); // {d_inner, n_t, n_s}
+            const float * dt = (const float *) ((const char *) src2->data + ir0*(src2->nb[0]) + i2*(src2->nb[1]) + i3*(src2->nb[2])); // {d_inner, n_t, n_s}
+            const float * A  = (const float *) ((const char *) src3->data + ir0*(src3->nb[1])); // {d_state, d_inner}
+            const float * B  = (const float *) ((const char *) src4->data + i2*(src4->nb[1]) + i3*(src4->nb[2])); // {d_state, n_t, n_s}
+            const float * C  = (const float *) ((const char *) src5->data + i2*(src5->nb[1]) + i3*(src5->nb[2])); // {d_state, n_t, n_s}
+            float * y = ( float *) (( char *) dst->data + ir0*(src1->nb[0]) + i2*(src1->nb[1]) + i3*(src1->nb[2])); // {d_inner, n_t, n_s}
+            float * s = ( float *) (( char *) dst->data + ir0*(src0->nb[1]) + i3*(src0->nb[2]) + src1->nb[3]); // {d_state, d_inner, n_s}
+
+            // use the output as the source for the next token-wise iterations
+            if (i2 > 0) { s0 = s; }
+
+            // d_inner
+            for (int i1 = 0; i1 < ir; ++i1) {
+                // ref: https://github.com/state-spaces/mamba/blob/34076d664838588a3c97727b263478ab9f621a07/mamba_ssm/ops/triton/selective_state_update.py#L78
+                float dt_soft_plus = dt[i1] <= 20.0f ? log1pf(expf(dt[i1])) : dt[i1];
+                float x_dt = x[i1] * dt_soft_plus;
+                float sumf = 0.0f;
+                // d_state
+                for (int i0 = 0; i0 < nc; ++i0) {
+                    int i = i0 + i1*nc;
+                    // state = prev_state * dA + dB * x
+                    float state = (s0[i] * expf(dt_soft_plus * A[i])) + (B[i0] * x_dt);
+                    // y = rowwise_dotprod(state, C)
+                    sumf += state * C[i0];
+                    s[i] = state;
+                }
+                y[i1] = sumf;
+            }
+        }
+    }
+#endif
 }
 
 void ggml_compute_forward_ssm_scan(
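
For reference, GGML_F32xt_FMA maps to svmad_f32_m(pg, a, b, c) (see simd-mappings.h below), which computes a*b + c per lane, so GGML_F32_VEC_FMA(vs0, t1, t2) evaluates s0*exp(dt_soft_plus*A) + x_dt*B — the same state update as the scalar #else path. The SVE loop also assumes nc (d_state) is a multiple of the vector width, since it uses a full predicate and has no tail handling. A minimal scalar sketch of what each SVE lane computes; the helper name is illustrative, not part of the patch:

// Hedged sketch: the per-lane math of the SVE inner loop above, identical to the
// scalar #else branch; ssm_lane_update is an illustrative name, not from the patch.
#include <math.h>

static inline float ssm_lane_update(float s0, float A, float B, float C,
                                    float dt_soft_plus, float x_dt, float * state_out) {
    // GGML_F32_VEC_FMA(vs0, t1, t2) -> svmad_f32_m: vs0*t1 + t2
    float state = s0 * expf(dt_soft_plus * A) + B * x_dt;  // state = prev_state * dA + dB * x
    *state_out = state;
    return state * C;  // summed across lanes into r1_vector, reduced into y[i1]
}
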
@@ -7963,6 +8009,14 @@ static void ggml_compute_forward_rwkv_wkv6_f32(
         #define GGML_F32X_MUL GGML_F32x16_MUL
         #define GGML_F32X_FMA GGML_F32x16_FMA
         #define WKV_VECTOR_SIZE 16
+    #elif defined(__ARM_FEATURE_SVE) && defined(__aarch64__)
+        #define GGML_F32X GGML_F32xt
+        #define GGML_F32X_SET1 GGML_F32xt_SET1
+        #define GGML_F32X_LOAD GGML_F32xt_LOAD
+        #define GGML_F32X_STORE GGML_F32xt_STORE
+        #define GGML_F32X_MUL GGML_F32xt_MUL
+        #define GGML_F32X_FMA GGML_F32xt_FMA
+        #define WKV_VECTOR_SIZE 8
     #elif defined(__ARM_NEON) && defined(__aarch64__)
         #define GGML_F32X GGML_F32x4
         #define GGML_F32X_SET1 GGML_F32x4_SET1
@@ -7973,8 +8027,14 @@ static void ggml_compute_forward_rwkv_wkv6_f32(
         #define WKV_VECTOR_SIZE 4
     #endif
 
+    int wkv_vector_size;
     #ifdef WKV_VECTOR_SIZE
-        const int64_t vec_count = head_size / WKV_VECTOR_SIZE;
+        #if defined(__ARM_FEATURE_SVE)
+            wkv_vector_size = svcntw();
+        #else
+            wkv_vector_size = WKV_VECTOR_SIZE;
+        #endif
+        const int64_t vec_count = head_size / wkv_vector_size;
 
         for (int64_t t = 0; t < T; t++) {
             size_t t_offset = t * t_stride;
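
On SVE the compile-time WKV_VECTOR_SIZE (8) only selects the #ifdef branch; the actual lane count is read at run time with svcntw(), because the SVE vector length is implementation-defined (128 to 2048 bits in 128-bit steps). A minimal standalone sketch, not part of the patch, compiled with an SVE-enabled target such as -march=armv8-a+sve:

#include <arm_sve.h>
#include <stdio.h>

int main(void) {
    // svcntw() returns the number of 32-bit lanes per SVE register at run time:
    // 4 on a 128-bit implementation, 8 on 256-bit, 16 on 512-bit, and so on.
    printf("f32 lanes per SVE register: %llu\n", (unsigned long long) svcntw());
    return 0;
}
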
@@ -8004,7 +8064,7 @@ static void ggml_compute_forward_rwkv_wkv6_f32(
                 GGML_F32X time_decay_vec = GGML_F32X_SET1(time_decay_val);
 
                 for (int64_t j = 0; j < vec_count; j++) {
-                    size_t base_j = j * WKV_VECTOR_SIZE;
+                    size_t base_j = j * wkv_vector_size;
                     size_t t_h_j_offset = t_h_offset + base_j;
                     size_t h_2d_i_j_offset = h_2d_i_offset + base_j;
 
@@ -8029,7 +8089,7 @@ static void ggml_compute_forward_rwkv_wkv6_f32(
                 }
 
                 // Handle remaining elements, this will not be used.
-                for (int64_t j = vec_count * WKV_VECTOR_SIZE; j < head_size; j++) {
+                for (int64_t j = vec_count * wkv_vector_size; j < head_size; j++) {
                     size_t t_h_j_offset = t_h_offset + j;
                     size_t h_2d_i_j_offset = h_2d_i_offset + j;
                     float v_val = v[t_h_j_offset];
@@ -8165,6 +8225,14 @@ static void ggml_compute_forward_gla_f32(
         #define GGML_F32X_MUL GGML_F32x16_MUL
         #define GGML_F32X_FMA GGML_F32x16_FMA
         #define GLA_VECTOR_SIZE 16
+    #elif defined(__ARM_FEATURE_SVE) && defined(__aarch64__)
+        #define GGML_F32X GGML_F32xt
+        #define GGML_F32X_SET1 GGML_F32xt_SET1
+        #define GGML_F32X_LOAD GGML_F32xt_LOAD
+        #define GGML_F32X_STORE GGML_F32xt_STORE
+        #define GGML_F32X_MUL GGML_F32xt_MUL
+        #define GGML_F32X_FMA GGML_F32xt_FMA
+        #define GLA_VECTOR_SIZE 8
     #elif defined(__ARM_NEON) && defined(__aarch64__)
         #define GGML_F32X GGML_F32x4
         #define GGML_F32X_SET1 GGML_F32x4_SET1
@@ -8174,9 +8242,14 @@ static void ggml_compute_forward_gla_f32(
         #define GGML_F32X_FMA GGML_F32x4_FMA
         #define GLA_VECTOR_SIZE 4
     #endif
-
+    int gla_vector_size;
     #ifdef GLA_VECTOR_SIZE
-        const int64_t vec_count = head_size / GLA_VECTOR_SIZE;
+        #if defined(__ARM_FEATURE_SVE)
+            gla_vector_size = svcntw();
+        #else
+            gla_vector_size = GLA_VECTOR_SIZE;
+        #endif
+        const int64_t vec_count = head_size / gla_vector_size;
 
         for (int64_t t = 0; t < T; t++) {
             size_t t_offset = t * t_stride;
@@ -8203,7 +8276,7 @@ static void ggml_compute_forward_gla_f32(
                 GGML_F32X g_vec = GGML_F32X_SET1(g_val);
 
                 for (int64_t j = 0; j < vec_count; j++) {
-                    size_t base_j = j * GLA_VECTOR_SIZE;
+                    size_t base_j = j * gla_vector_size;
                     size_t t_h_j_offset = t_h_offset + base_j;
                     size_t h_2d_i_j_offset = h_2d_i_offset + base_j;
 
@@ -8227,7 +8300,7 @@ static void ggml_compute_forward_gla_f32(
                 }
 
                 // Handle remaining elements, this will not be used.
-                for (int64_t j = vec_count * GLA_VECTOR_SIZE; j < head_size; j++) {
+                for (int64_t j = vec_count * gla_vector_size; j < head_size; j++) {
                     size_t t_h_j_offset = t_h_offset + j;
                     size_t h_2d_i_j_offset = h_2d_i_offset + j;
                     float v_val = v[t_h_j_offset];
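
The WKV6 and GLA loops above share the same shape: head_size is split into vec_count full vectors of wkv_vector_size/gla_vector_size lanes plus a scalar tail ("Handle remaining elements"). A minimal sketch of that pattern with a runtime SVE lane count; the function and its name are illustrative, not the patched kernels:

#include <arm_sve.h>
#include <stdint.h>

// Scale dst[0..n) by s: full SVE vectors first, then a scalar tail for any remainder.
static void scale_f32_sve(float * dst, float s, int64_t n) {
    const int64_t lanes     = svcntw();   // runtime number of f32 lanes per register
    const int64_t vec_count = n / lanes;  // full vectors, like vec_count in the kernels
    svfloat32_t vs = svdup_n_f32(s);
    for (int64_t j = 0; j < vec_count; j++) {
        svfloat32_t v = svld1_f32(svptrue_b32(), dst + j*lanes);
        svst1_f32(svptrue_b32(), dst + j*lanes, svmul_f32_m(svptrue_b32(), v, vs));
    }
    for (int64_t j = vec_count*lanes; j < n; j++) {  // scalar tail, "Handle remaining elements"
        dst[j] *= s;
    }
}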

ggml/src/ggml-cpu/simd-mappings.h

Lines changed: 116 additions & 1 deletion
@@ -17,7 +17,122 @@
 // number of elements to fit in a single register
 //
 
-#if defined(__ARM_NEON) && defined(__ARM_FEATURE_FMA)
+#if defined(__ARM_FEATURE_SVE) && defined(__ARM_FEATURE_FMA)
+
+#define GGML_SIMD
+
+// F32 SVE
+#define DEFAULT_PG svptrue_b32()
+
+#define GGML_F32xt                        svfloat32_t
+#define GGML_F32xt_ZERO                   svdup_n_f32(0.0f)
+#define GGML_F32xt_SET1(x)                svdup_n_f32(x)
+#define GGML_F32xt_LOAD_IMPL(pg, a, ...)  svld1_f32(pg, a)
+#define GGML_F32xt_LOAD(...)              GGML_F32xt_LOAD_IMPL(DEFAULT_PG, __VA_ARGS__)
+#define GGML_F32xt_STORE_IMPL(pg, a, b)   svst1_f32(pg, a, b)
+#define GGML_F32xt_STORE(...)             GGML_F32xt_STORE_IMPL(DEFAULT_PG, __VA_ARGS__)
+#define GGML_F32xt_FMA_IMPL(pg, a, b, c)  svmad_f32_m(pg, a, b, c)
+#define GGML_F32xt_FMA(...)               GGML_F32xt_FMA_IMPL(DEFAULT_PG, __VA_ARGS__)
+#define GGML_F32xt_ADD_IMPL(pg, a, b)     svadd_f32_m(pg, a, b)
+#define GGML_F32xt_ADD(...)               GGML_F32xt_ADD_IMPL(DEFAULT_PG, __VA_ARGS__)
+#define GGML_F32xt_MUL_IMPL(pg, a, b)     svmul_f32_m(pg, a, b)
+#define GGML_F32xt_MUL(...)               GGML_F32xt_MUL_IMPL(DEFAULT_PG, __VA_ARGS__)
+#define GGML_F32xt_REDUCE_ONE_IMPL(pg, a) svaddv(pg, a)
+#define GGML_F32xt_REDUCE_ONE(...)        GGML_F32xt_REDUCE_ONE_IMPL(DEFAULT_PG, __VA_ARGS__)
+#define GGML_F32xt_REDUCE_IMPL(pg, res, sum1, sum2, sum3, sum4, sum5, sum6, sum7, sum8) \
+{                                                      \
+    sum1 = svadd_f32_m(DEFAULT_PG, sum1, sum2);        \
+    sum3 = svadd_f32_m(DEFAULT_PG, sum3, sum4);        \
+    sum5 = svadd_f32_m(DEFAULT_PG, sum5, sum6);        \
+    sum7 = svadd_f32_m(DEFAULT_PG, sum7, sum8);        \
+    sum1 = svadd_f32_m(DEFAULT_PG, sum1, sum3);        \
+    sum5 = svadd_f32_m(DEFAULT_PG, sum5, sum7);        \
+    sum1 = svadd_f32_m(DEFAULT_PG, sum1, sum5);        \
+    (res) = (ggml_float) GGML_F32xt_REDUCE_ONE(sum1);  \
+}
+#define GGML_F32xt_REDUCE(...)  GGML_F32xt_REDUCE_IMPL(DEFAULT_PG, __VA_ARGS__)
+
+#define GGML_F32_VEC         GGML_F32xt
+#define GGML_F32_VEC_ZERO    GGML_F32xt_ZERO
+#define GGML_F32_VEC_SET1    GGML_F32xt_SET1
+#define GGML_F32_VEC_LOAD    GGML_F32xt_LOAD
+#define GGML_F32_VEC_STORE   GGML_F32xt_STORE
+#define GGML_F32_VEC_FMA     GGML_F32xt_FMA
+#define GGML_F32_VEC_ADD     GGML_F32xt_ADD
+#define GGML_F32_VEC_MUL     GGML_F32xt_MUL
+#define GGML_F32_VEC_REDUCE  GGML_F32xt_REDUCE
+
+// F16 NEON
+
+#if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC)
+    #define GGML_F16_STEP 32
+    #define GGML_F16_EPR  8
+
+    #define GGML_F16x8              float16x8_t
+    #define GGML_F16x8_ZERO         vdupq_n_f16(0.0f)
+    #define GGML_F16x8_SET1(x)      vdupq_n_f16(x)
+    #define GGML_F16x8_LOAD(x)      vld1q_f16((const ggml_fp16_internal_t *)(x))
+    #define GGML_F16x8_STORE        vst1q_f16
+    #define GGML_F16x8_FMA(a, b, c) vfmaq_f16(a, b, c)
+    #define GGML_F16x8_ADD          vaddq_f16
+    #define GGML_F16x8_MUL          vmulq_f16
+    #define GGML_F16x8_REDUCE(res, x)                               \
+    do {                                                            \
+        int offset = GGML_F16_ARR >> 1;                             \
+        for (int i = 0; i < offset; ++i) {                          \
+            (x)[i] = vaddq_f16((x)[i], (x)[offset+i]);              \
+        }                                                           \
+        offset >>= 1;                                               \
+        for (int i = 0; i < offset; ++i) {                          \
+            (x)[i] = vaddq_f16((x)[i], (x)[offset+i]);              \
+        }                                                           \
+        offset >>= 1;                                               \
+        for (int i = 0; i < offset; ++i) {                          \
+            (x)[i] = vaddq_f16((x)[i], (x)[offset+i]);              \
+        }                                                           \
+        const float32x4_t t0 = vcvt_f32_f16(vget_low_f16 ((x)[0])); \
+        const float32x4_t t1 = vcvt_f32_f16(vget_high_f16((x)[0])); \
+        (res) = (ggml_float) vaddvq_f32(vaddq_f32(t0, t1));         \
+    } while (0)
+
+    #define GGML_F16_VEC                GGML_F16x8
+    #define GGML_F16_VEC_ZERO           GGML_F16x8_ZERO
+    #define GGML_F16_VEC_SET1           GGML_F16x8_SET1
+    #define GGML_F16_VEC_LOAD(p, i)     GGML_F16x8_LOAD(p)
+    #define GGML_F16_VEC_STORE(p, r, i) GGML_F16x8_STORE((ggml_fp16_internal_t *)(p), (r)[i])
+    #define GGML_F16_VEC_FMA            GGML_F16x8_FMA
+    #define GGML_F16_VEC_ADD            GGML_F16x8_ADD
+    #define GGML_F16_VEC_MUL            GGML_F16x8_MUL
+    #define GGML_F16_VEC_REDUCE         GGML_F16x8_REDUCE
+#else
+    // if FP16 vector arithmetic is not supported, we use FP32 instead
+    // and take advantage of the vcvt_ functions to convert to/from FP16
+
+    #define GGML_F16_STEP 16
+    #define GGML_F16_EPR  4
+
+    #define GGML_F32Cx4              float32x4_t
+    #define GGML_F32Cx4_ZERO         vdupq_n_f32(0.0f)
+    #define GGML_F32Cx4_SET1(x)      vdupq_n_f32(x)
+    #define GGML_F32Cx4_LOAD(x)      vcvt_f32_f16(vld1_f16((const ggml_fp16_internal_t *)(x)))
+    #define GGML_F32Cx4_STORE(x, y)  vst1_f16(x, vcvt_f16_f32(y))
+    #define GGML_F32Cx4_FMA(a, b, c) vfmaq_f32(a, b, c)
+    #define GGML_F32Cx4_ADD          vaddq_f32
+    #define GGML_F32Cx4_MUL          vmulq_f32
+    #define GGML_F32Cx4_REDUCE       GGML_F32x4_REDUCE
+
+    #define GGML_F16_VEC                GGML_F32Cx4
+    #define GGML_F16_VEC_ZERO           GGML_F32Cx4_ZERO
+    #define GGML_F16_VEC_SET1           GGML_F32Cx4_SET1
+    #define GGML_F16_VEC_LOAD(p, i)     GGML_F32Cx4_LOAD(p)
+    #define GGML_F16_VEC_STORE(p, r, i) GGML_F32Cx4_STORE((ggml_fp16_internal_t *)(p), r[i])
+    #define GGML_F16_VEC_FMA            GGML_F32Cx4_FMA
+    #define GGML_F16_VEC_ADD            GGML_F32Cx4_ADD
+    #define GGML_F16_VEC_MUL            GGML_F32Cx4_MUL
+    #define GGML_F16_VEC_REDUCE         GGML_F32Cx4_REDUCE
+#endif
+
+#elif defined(__ARM_NEON) && defined(__ARM_FEATURE_FMA)
 
 #define GGML_SIMD
 
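
Taken together, the new F32 SVE mappings follow the same load / multiply-accumulate / reduce shape as the existing NEON and AVX paths. A hedged sketch of how they compose into a dot product, written with the underlying ACLE intrinsics (dot_f32_sve is an illustrative helper, not part of ggml), assuming n is a multiple of the vector width:

#include <arm_sve.h>
#include <stdint.h>

static float dot_f32_sve(const float * a, const float * b, int64_t n) {
    svfloat32_t acc = svdup_n_f32(0.0f);                   // GGML_F32_VEC_ZERO
    for (int64_t k = 0; k < n; k += svcntw()) {            // svcntw() = f32 lanes per register
        svfloat32_t va = svld1_f32(svptrue_b32(), &a[k]);  // GGML_F32_VEC_LOAD
        svfloat32_t vb = svld1_f32(svptrue_b32(), &b[k]);  // GGML_F32_VEC_LOAD
        acc = svmla_f32_m(svptrue_b32(), acc, va, vb);     // acc += va * vb
    }
    return svaddv_f32(svptrue_b32(), acc);                 // GGML_F32xt_REDUCE_ONE (svaddv)
}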
