Skip to content

Commit 55b5545

Browse files
committed
F32-Mamba-SVE
1 parent 8581c89 commit 55b5545

File tree

3 files changed

+115
-70
lines changed

3 files changed

+115
-70
lines changed

ggml/src/ggml-cpu/ops.cpp

Lines changed: 107 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -8409,83 +8409,126 @@ static void ggml_compute_forward_rwkv_wkv7_f32(
84098409
int64_t h_stride_2d = head_size * head_size;
84108410

84118411
#if defined(GGML_SIMD)
8412-
for (int64_t t = 0; t < T; t++) {
8413-
int64_t t_offset = t * t_stride;
8414-
int64_t state_offset = head_size * C * (t / (T / n_seqs));
8415-
float * state_cur = state + state_offset;
8416-
float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[6]->data + state_offset;
8417-
8418-
for (int64_t h = h_start; h < h_end; h++) {
8419-
int64_t h_offset = h * h_stride;
8420-
int64_t t_h_offset = t_offset + h_offset;
8421-
int64_t h_2d_offset = h * h_stride_2d;
8422-
8423-
for (int64_t ii = 0; ii < head_size; ii++) {
8424-
int64_t t_h_i_offset = t_h_offset + ii;
8425-
int64_t h_2d_i_offset = h_2d_offset + ii * h_stride;
8426-
8427-
GGML_F32_VEC v_vec = GGML_F32_VEC_SET1(v[t_h_i_offset]);
8412+
#if defined(__ARM_FEATURE_SVE)
8413+
// Route to scalar implementation for now. TODO: write SVE code
8414+
for (int64_t t = 0; t < T; t++) {
8415+
int64_t t_offset = t * t_stride;
8416+
int64_t state_offset = head_size * C * (t / (T / n_seqs));
8417+
float * state_cur = state + state_offset;
8418+
float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[6]->data + state_offset;
8419+
8420+
for (int64_t h = h_start; h < h_end; h++) {
8421+
int64_t h_offset = h * h_stride;
8422+
int64_t t_h_offset = t_offset + h_offset;
8423+
int64_t h_2d_offset = h * h_stride_2d;
8424+
8425+
for (int64_t i = 0; i < head_size; i++) {
8426+
int64_t t_h_i_offset = t_h_offset + i;
8427+
int64_t h_2d_i_offset = h_2d_offset + i * h_stride;
8428+
8429+
float v_val = v[t_h_i_offset];
8430+
8431+
float sa = 0, result = 0;
8432+
for (int64_t j = 0; j < head_size; j++) {
8433+
sa += a[t_h_offset + j] * state_prev[h_2d_i_offset + j];
8434+
}
84288435

8429-
float sa = 0;
8430-
{
8431-
GGML_F32_VEC sum[GGML_F32_ARR] = { GGML_F32_VEC_ZERO };
8432-
GGML_F32_VEC ax[GGML_F32_ARR];
8433-
GGML_F32_VEC ay[GGML_F32_ARR];
8434-
for (int64_t j = 0; j < head_size; j += GGML_F32_STEP) {
8435-
for (int64_t kk = 0; kk < GGML_F32_ARR; kk++) {
8436-
ax[kk] = GGML_F32_VEC_LOAD(&a[t_h_offset + j + kk * GGML_F32_EPR]);
8437-
ay[kk] = GGML_F32_VEC_LOAD(&state_prev[h_2d_i_offset + j + kk * GGML_F32_EPR]);
8438-
sum[kk] = GGML_F32_VEC_FMA(sum[kk], ax[kk], ay[kk]);
8439-
}
8436+
for (int64_t j = 0; j < head_size; j++) {
8437+
int64_t t_h_j_offset = t_h_offset + j;
8438+
int64_t h_2d_i_j_offset = h_2d_i_offset + j;
8439+
8440+
float r_val = r[t_h_j_offset];
8441+
float w_val = w[t_h_j_offset];
8442+
float k_val = k[t_h_j_offset];
8443+
float b_val = b[t_h_j_offset];
8444+
float kv_val = v_val * k_val;
8445+
float prev_state_val = state_prev[h_2d_i_j_offset];
8446+
state_cur[h_2d_i_j_offset] = prev_state_val * w_val + kv_val + sa * b_val;
8447+
result += state_cur[h_2d_i_j_offset] * r_val;
84408448
}
8441-
GGML_F32_VEC_REDUCE(sa, sum);
8449+
dst_data[t_h_i_offset] = result;
84428450
}
8451+
}
8452+
}
8453+
#else
8454+
for (int64_t t = 0; t < T; t++) {
8455+
int64_t t_offset = t * t_stride;
8456+
int64_t state_offset = head_size * C * (t / (T / n_seqs));
8457+
float * state_cur = state + state_offset;
8458+
float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[6]->data + state_offset;
8459+
8460+
for (int64_t h = h_start; h < h_end; h++) {
8461+
int64_t h_offset = h * h_stride;
8462+
int64_t t_h_offset = t_offset + h_offset;
8463+
int64_t h_2d_offset = h * h_stride_2d;
8464+
8465+
for (int64_t ii = 0; ii < head_size; ii++) {
8466+
int64_t t_h_i_offset = t_h_offset + ii;
8467+
int64_t h_2d_i_offset = h_2d_offset + ii * h_stride;
8468+
8469+
GGML_F32_VEC v_vec = GGML_F32_VEC_SET1(v[t_h_i_offset]);
8470+
8471+
float sa = 0;
8472+
{
8473+
GGML_F32_VEC sum[GGML_F32_ARR] = { GGML_F32_VEC_ZERO };
8474+
GGML_F32_VEC ax[GGML_F32_ARR];
8475+
GGML_F32_VEC ay[GGML_F32_ARR];
8476+
for (int64_t j = 0; j < head_size; j += GGML_F32_STEP) {
8477+
for (int64_t kk = 0; kk < GGML_F32_ARR; kk++) {
8478+
ax[kk] = GGML_F32_VEC_LOAD(&a[t_h_offset + j + kk * GGML_F32_EPR]);
8479+
ay[kk] = GGML_F32_VEC_LOAD(&state_prev[h_2d_i_offset + j + kk * GGML_F32_EPR]);
8480+
sum[kk] = GGML_F32_VEC_FMA(sum[kk], ax[kk], ay[kk]);
8481+
}
8482+
}
8483+
GGML_F32_VEC_REDUCE(sa, sum);
8484+
}
84438485

8444-
GGML_F32_VEC sa_vec = GGML_F32_VEC_SET1(sa);
8486+
GGML_F32_VEC sa_vec = GGML_F32_VEC_SET1(sa);
84458487

8446-
int64_t j = 0;
8447-
GGML_F32_VEC result_vec[GGML_F32_ARR] = { GGML_F32_VEC_ZERO };
8448-
for (; j < head_size; j += GGML_F32_STEP) {
8449-
for (int64_t kk = 0; kk < GGML_F32_ARR; kk++) {
8450-
int64_t t_h_j_offset = t_h_offset + j + kk * GGML_F32_EPR;
8451-
int64_t h_2d_i_j_offset = h_2d_i_offset + j + kk * GGML_F32_EPR;
8488+
int64_t j = 0;
8489+
GGML_F32_VEC result_vec[GGML_F32_ARR] = { GGML_F32_VEC_ZERO };
8490+
for (; j < head_size; j += GGML_F32_STEP) {
8491+
for (int64_t kk = 0; kk < GGML_F32_ARR; kk++) {
8492+
int64_t t_h_j_offset = t_h_offset + j + kk * GGML_F32_EPR;
8493+
int64_t h_2d_i_j_offset = h_2d_i_offset + j + kk * GGML_F32_EPR;
84528494

8453-
GGML_F32_VEC r_vec = GGML_F32_VEC_LOAD(&r[t_h_j_offset]);
8454-
GGML_F32_VEC w_vec = GGML_F32_VEC_LOAD(&w[t_h_j_offset]);
8455-
GGML_F32_VEC k_vec = GGML_F32_VEC_LOAD(&k[t_h_j_offset]);
8456-
GGML_F32_VEC b_vec = GGML_F32_VEC_LOAD(&b[t_h_j_offset]);
8495+
GGML_F32_VEC r_vec = GGML_F32_VEC_LOAD(&r[t_h_j_offset]);
8496+
GGML_F32_VEC w_vec = GGML_F32_VEC_LOAD(&w[t_h_j_offset]);
8497+
GGML_F32_VEC k_vec = GGML_F32_VEC_LOAD(&k[t_h_j_offset]);
8498+
GGML_F32_VEC b_vec = GGML_F32_VEC_LOAD(&b[t_h_j_offset]);
84578499

8458-
k_vec = GGML_F32_VEC_MUL(v_vec, k_vec);
8500+
k_vec = GGML_F32_VEC_MUL(v_vec, k_vec);
84598501

8460-
GGML_F32_VEC state_vec = GGML_F32_VEC_LOAD(&state_prev[h_2d_i_j_offset]);
8461-
// kv + s * decay + sa * b
8462-
state_vec = GGML_F32_VEC_FMA(k_vec, state_vec, w_vec);
8463-
state_vec = GGML_F32_VEC_FMA(state_vec, sa_vec, b_vec);
8464-
GGML_F32_VEC_STORE(&state_cur[h_2d_i_j_offset], state_vec);
8502+
GGML_F32_VEC state_vec = GGML_F32_VEC_LOAD(&state_prev[h_2d_i_j_offset]);
8503+
// kv + s * decay + sa * b
8504+
state_vec = GGML_F32_VEC_FMA(k_vec, state_vec, w_vec);
8505+
state_vec = GGML_F32_VEC_FMA(state_vec, sa_vec, b_vec);
8506+
GGML_F32_VEC_STORE(&state_cur[h_2d_i_j_offset], state_vec);
84658507

8466-
result_vec[kk] = GGML_F32_VEC_FMA(result_vec[kk], state_vec, r_vec);
8508+
result_vec[kk] = GGML_F32_VEC_FMA(result_vec[kk], state_vec, r_vec);
8509+
}
8510+
}
8511+
GGML_F32_VEC_REDUCE(dst_data[t_h_i_offset], result_vec);
8512+
8513+
// There shouldn't be left-overs though.
8514+
for (; j < head_size; j++) {
8515+
int64_t t_h_j_offset = t_h_offset + j;
8516+
int64_t h_2d_i_j_offset = h_2d_i_offset + j;
8517+
8518+
float r_val = r[t_h_j_offset];
8519+
float w_val = w[t_h_j_offset];
8520+
float k_val = k[t_h_j_offset];
8521+
float b_val = b[t_h_j_offset];
8522+
float kv_val = v[t_h_i_offset] * k_val;
8523+
8524+
float prev_state_val = state_prev[h_2d_i_j_offset];
8525+
state_cur[h_2d_i_j_offset] = prev_state_val * w_val + kv_val + sa * b_val;
8526+
dst_data[t_h_i_offset] += state_cur[h_2d_i_j_offset] * r_val;
84678527
}
8468-
}
8469-
GGML_F32_VEC_REDUCE(dst_data[t_h_i_offset], result_vec);
8470-
8471-
// There shouldn't be left-overs though.
8472-
for (; j < head_size; j++) {
8473-
int64_t t_h_j_offset = t_h_offset + j;
8474-
int64_t h_2d_i_j_offset = h_2d_i_offset + j;
8475-
8476-
float r_val = r[t_h_j_offset];
8477-
float w_val = w[t_h_j_offset];
8478-
float k_val = k[t_h_j_offset];
8479-
float b_val = b[t_h_j_offset];
8480-
float kv_val = v[t_h_i_offset] * k_val;
8481-
8482-
float prev_state_val = state_prev[h_2d_i_j_offset];
8483-
state_cur[h_2d_i_j_offset] = prev_state_val * w_val + kv_val + sa * b_val;
8484-
dst_data[t_h_i_offset] += state_cur[h_2d_i_j_offset] * r_val;
84858528
}
84868529
}
84878530
}
8488-
}
8531+
#endif
84898532
#else
84908533
for (int64_t t = 0; t < T; t++) {
84918534
int64_t t_offset = t * t_stride;

ggml/src/ggml-cpu/simd-mappings.h

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
#define GGML_SIMD
2323

2424
// F32 SVE
25+
#define GGML_F32_EPR 8
2526
#define DEFAULT_PG svptrue_b32()
2627

2728
#define GGML_F32xt svfloat32_t
@@ -71,7 +72,7 @@
7172
#define GGML_F16x8 float16x8_t
7273
#define GGML_F16x8_ZERO vdupq_n_f16(0.0f)
7374
#define GGML_F16x8_SET1(x) vdupq_n_f16(x)
74-
#define GGML_F16x8_LOAD(x) vld1q_f16((const ggml_fp16_internal_t *)(x))
75+
#define GGML_F16x8_LOAD(x) vld1q_f16((const __fp16 *)(x))
7576
#define GGML_F16x8_STORE vst1q_f16
7677
#define GGML_F16x8_FMA(a, b, c) vfmaq_f16(a, b, c)
7778
#define GGML_F16x8_ADD vaddq_f16
@@ -99,7 +100,7 @@
99100
#define GGML_F16_VEC_ZERO GGML_F16x8_ZERO
100101
#define GGML_F16_VEC_SET1 GGML_F16x8_SET1
101102
#define GGML_F16_VEC_LOAD(p, i) GGML_F16x8_LOAD(p)
102-
#define GGML_F16_VEC_STORE(p, r, i) GGML_F16x8_STORE((ggml_fp16_internal_t *)(p), (r)[i])
103+
#define GGML_F16_VEC_STORE(p, r, i) GGML_F16x8_STORE((__fp16 *)(p), (r)[i])
103104
#define GGML_F16_VEC_FMA GGML_F16x8_FMA
104105
#define GGML_F16_VEC_ADD GGML_F16x8_ADD
105106
#define GGML_F16_VEC_MUL GGML_F16x8_MUL
@@ -114,7 +115,7 @@
114115
#define GGML_F32Cx4 float32x4_t
115116
#define GGML_F32Cx4_ZERO vdupq_n_f32(0.0f)
116117
#define GGML_F32Cx4_SET1(x) vdupq_n_f32(x)
117-
#define GGML_F32Cx4_LOAD(x) vcvt_f32_f16(vld1_f16((const ggml_fp16_internal_t *)(x)))
118+
#define GGML_F32Cx4_LOAD(x) vcvt_f32_f16(vld1_f16((const __fp16 *)(x)))
118119
#define GGML_F32Cx4_STORE(x, y) vst1_f16(x, vcvt_f16_f32(y))
119120
#define GGML_F32Cx4_FMA(a, b, c) vfmaq_f32(a, b, c)
120121
#define GGML_F32Cx4_ADD vaddq_f32
@@ -125,7 +126,7 @@
125126
#define GGML_F16_VEC_ZERO GGML_F32Cx4_ZERO
126127
#define GGML_F16_VEC_SET1 GGML_F32Cx4_SET1
127128
#define GGML_F16_VEC_LOAD(p, i) GGML_F32Cx4_LOAD(p)
128-
#define GGML_F16_VEC_STORE(p, r, i) GGML_F32Cx4_STORE((ggml_fp16_internal_t *)(p), r[i])
129+
#define GGML_F16_VEC_STORE(p, r, i) GGML_F32Cx4_STORE((__fp16 *)(p), r[i])
129130
#define GGML_F16_VEC_FMA GGML_F32Cx4_FMA
130131
#define GGML_F16_VEC_ADD GGML_F32Cx4_ADD
131132
#define GGML_F16_VEC_MUL GGML_F32Cx4_MUL

ggml/src/ggml-cpu/vec.h

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -592,7 +592,7 @@ inline static ggml_fp16_t ggml_silu_f16(ggml_fp16_t x) {
592592
/* Below function was borrowed from the GitHub repository:
593593
https://github.com/openvinotoolkit/openvino/blob/master/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/common.hpp */
594594
#if defined(__ARM_FEATURE_SVE) && defined(__aarch64__)
595-
svfloat32_t exp_ps_sve(svbool_t pg, svfloat32_t src) {
595+
inline static svfloat32_t exp_ps_sve(svbool_t pg, svfloat32_t src) {
596596
// Constants
597597
const svfloat32_t log2_e = svdup_n_f32(1.4426950409f);
598598
const svfloat32_t ln2 = svdup_n_f32(0.6931473921f);
@@ -623,8 +623,9 @@ inline static ggml_fp16_t ggml_silu_f16(ggml_fp16_t x) {
623623

624624
return t0;
625625
}
626+
#endif
626627

627-
#elif defined(__ARM_NEON) && defined(__aarch64__)
628+
#if defined(__ARM_NEON) && defined(__aarch64__)
628629

629630
// adapted from arm limited optimized routine
630631
// the maximum error is 1.45358 plus 0.5 ulps

0 commit comments

Comments
 (0)