
Commit e6787bb

ggml : remove SVE paths
1 parent d9e0e7c

4 files changed: +74 −790 lines changed


ggml/src/ggml-cpu/ops.cpp

Lines changed: 64 additions & 167 deletions
@@ -8647,40 +8647,6 @@ static void ggml_compute_forward_ssm_scan_f32(
         const float x_dt = x[ii] * dt_soft_plus;
         float sumf = 0.0f;
 #if defined(GGML_SIMD)
-#if defined(__ARM_FEATURE_SVE)
-        const int ggml_f32_epr = svcntw();
-        const int ggml_f32_step = 1 * ggml_f32_epr;
-
-        const int np = (nc & ~(ggml_f32_step - 1));
-
-        GGML_F32_VEC sum = GGML_F32_VEC_ZERO;
-
-        GGML_F32_VEC adA = GGML_F32_VEC_SET1(dA);
-        GGML_F32_VEC axdt = GGML_F32_VEC_SET1(x_dt);
-
-        for (int i = 0; i < np; i += ggml_f32_step) {
-            // TODO: maybe unroll more?
-            for (int j = 0; j < 1; j++) {
-                GGML_F32_VEC t0 = GGML_F32_VEC_LOAD(s0 + i + j*ggml_f32_epr + ii*nc);
-                GGML_F32_VEC t1 = GGML_F32_VEC_LOAD(B + i + j*ggml_f32_epr + g*nc);
-                GGML_F32_VEC t2 = GGML_F32_VEC_LOAD(C + i + j*ggml_f32_epr + g*nc);
-
-                t0 = GGML_F32_VEC_MUL(t0, adA);
-                t1 = GGML_F32_VEC_MUL(t1, axdt);
-
-                t0 = GGML_F32_VEC_ADD(t0, t1);
-
-                sum = GGML_F32_VEC_FMA(sum, t0, t2);
-
-                GGML_F32_VEC_STORE(s + i + j*ggml_f32_epr + ii*nc, t0);
-            }
-        }
-
-        sumf = GGML_F32xt_REDUCE_ONE(sum);
-#elif defined(__riscv_v_intrinsic)
-        // todo: RVV implementation
-        const int np = 0;
-#else
         const int np = (nc & ~(GGML_F32_STEP - 1));
 
         GGML_F32_VEC sum[GGML_F32_ARR] = { GGML_F32_VEC_ZERO };
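
Note: the removed SVE block and the generic GGML_SIMD path that remains both compute the same per-element recurrence over the d_state dimension. As a reference point, a scalar sketch of that recurrence, written against the local variables of the surrounding function (s0, s, B, C, dA, x_dt, nc, ii, g) and assuming GGML_F32_VEC_FMA(a, b, c) means a + b*c; this is an illustration, not code from the commit:

    // scalar sketch: decay the previous state by dA, add the input term
    // B*x_dt, store the new state, and accumulate the output sum(s*C)
    for (int i = 0; i < nc; ++i) {
        float state = s0[ii*nc + i]*dA + B[g*nc + i]*x_dt;
        sumf += state * C[g*nc + i];
        s[ii*nc + i] = state;
    }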
@@ -8711,7 +8677,6 @@ static void ggml_compute_forward_ssm_scan_f32(
 
         // reduce sum0..sum3 to sum0
         GGML_F32_VEC_REDUCE(sumf, sum);
-#endif
 #else
         const int np = 0;
 #endif
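
The surviving generic path uses ggml's unrolled-accumulator convention: each iteration consumes GGML_F32_STEP floats as GGML_F32_ARR vectors of GGML_F32_EPR lanes, keeping one partial sum per vector, and GGML_F32_VEC_REDUCE folds the partial sums into a scalar at the end. A scalar analogue of that pattern, with the unroll factor fixed at 4 for illustration (dot_unrolled4 is a hypothetical helper, not part of ggml):

    // independent partial accumulators stand in for sum[GGML_F32_ARR];
    // the final fold stands in for GGML_F32_VEC_REDUCE
    static float dot_unrolled4(const float * a, const float * b, int n) {
        float acc[4] = { 0.0f, 0.0f, 0.0f, 0.0f };
        const int np = n & ~3; // largest multiple of the step, as in (nc & ~(GGML_F32_STEP - 1))
        for (int i = 0; i < np; i += 4) {
            for (int k = 0; k < 4; ++k) {
                acc[k] += a[i + k] * b[i + k];
            }
        }
        float total = acc[0] + acc[1] + acc[2] + acc[3];
        for (int i = np; i < n; ++i) { // scalar tail for the leftovers
            total += a[i] * b[i];
        }
        return total;
    }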
@@ -8741,30 +8706,6 @@ static void ggml_compute_forward_ssm_scan_f32(
         for (int i1 = 0; i1 < nr; ++i1) {
             const int ii = i1 + h*nr;
             const float x_dt = x[ii] * dt_soft_plus;
-#if defined(__ARM_FEATURE_SVE)
-            svfloat32_t vx_dt = GGML_F32_VEC_SET1(x_dt);
-            svfloat32_t vdt_soft_plus = GGML_F32_VEC_SET1(dt_soft_plus);
-            svfloat32_t r1_vector = GGML_F32_VEC_ZERO;
-
-            // d_state
-            // TODO: what happens when (d_state % svcntw()) != 0?
-            for (int64_t k = 0; k < nc; k += svcntw()) {
-                svfloat32_t vA = GGML_F32_VEC_LOAD(&A[h*nc + k]);
-                svfloat32_t vB = GGML_F32_VEC_LOAD(&B[k + g*nc]);
-                svfloat32_t vC = GGML_F32_VEC_LOAD(&C[k + g*nc]);
-                svfloat32_t vs0 = GGML_F32_VEC_LOAD(&s0[ii*nc + k]);
-
-                svfloat32_t t1 = GGML_F32_VEC_MUL(vdt_soft_plus, vA);
-                t1 = exp_ps_sve(svptrue_b32(), t1);
-                svfloat32_t t2 = GGML_F32_VEC_MUL(vx_dt, vB);
-
-                vs0 = GGML_F32_VEC_FMA(t2, vs0, t1);
-                r1_vector = GGML_F32_VEC_ADD(GGML_F32_VEC_MUL(vs0, vC), r1_vector);
-
-                GGML_F32_VEC_STORE(&s[ii*nc + k], vs0);
-            }
-            y[ii] = GGML_F32xt_REDUCE_ONE(r1_vector);
-#else
             float sumf = 0.0f;
             // NOTE: can't really use GGML_SIMD here because d_state is usually 16
             // and also because expf is used within the loop.
@@ -8779,7 +8720,6 @@ static void ggml_compute_forward_ssm_scan_f32(
                 s[i] = state;
             }
             y[ii] = sumf;
-#endif
         }
     }
 }
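
The scalar loop kept here calls expf per element because A varies along d_state, which is what the NOTE above refers to. A scalar sketch of the full update, reconstructed from the removed SVE block (illustration only, again taking GGML_F32_VEC_FMA(a, b, c) as a + b*c):

    float sumf = 0.0f;
    for (int i = 0; i < nc; ++i) {
        float dA    = expf(dt_soft_plus * A[h*nc + i]); // per-element decay
        float state = s0[ii*nc + i]*dA + x_dt*B[g*nc + i];
        sumf += state * C[g*nc + i];
        s[ii*nc + i] = state;
    }
    y[ii] = sumf;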
@@ -9632,126 +9572,83 @@ static void ggml_compute_forward_rwkv_wkv7_f32(
     int64_t h_stride_2d = head_size * head_size;
 
 #if defined(GGML_SIMD)
-    #if defined(__ARM_FEATURE_SVE) || defined(__riscv_v_intrinsic)
-        // scalar Route to scalar implementation //TODO: Write SVE code and RVV code
-        for (int64_t t = 0; t < T; t++) {
-            int64_t t_offset = t * t_stride;
-            int64_t state_offset = head_size * C * (t / (T / n_seqs));
-            float * state_cur = state + state_offset;
-            float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[6]->data + state_offset;
-
-            for (int64_t h = h_start; h < h_end; h++) {
-                int64_t h_offset = h * h_stride;
-                int64_t t_h_offset = t_offset + h_offset;
-                int64_t h_2d_offset = h * h_stride_2d;
-
-                for (int64_t i = 0; i < head_size; i++) {
-                    int64_t t_h_i_offset = t_h_offset + i;
-                    int64_t h_2d_i_offset = h_2d_offset + i * h_stride;
-
-                    float v_val = v[t_h_i_offset];
-
-                    float sa = 0, result = 0;
-                    for (int64_t j = 0; j < head_size; j++) {
-                        sa += a[t_h_offset + j] * state_prev[h_2d_i_offset + j];
-                    }
-
-                    for (int64_t j = 0; j < head_size; j++) {
-                        int64_t t_h_j_offset = t_h_offset + j;
-                        int64_t h_2d_i_j_offset = h_2d_i_offset + j;
-
-                        float r_val = r[t_h_j_offset];
-                        float w_val = w[t_h_j_offset];
-                        float k_val = k[t_h_j_offset];
-                        float b_val = b[t_h_j_offset];
-                        float kv_val = v_val * k_val;
-                        float prev_state_val = state_prev[h_2d_i_j_offset];
-                        state_cur[h_2d_i_j_offset] = prev_state_val * w_val + kv_val + sa * b_val;
-                        result += state_cur[h_2d_i_j_offset] * r_val;
-                    }
-                    dst_data[t_h_i_offset] = result;
-                }
-            }
-        }
-    #else
-        for (int64_t t = 0; t < T; t++) {
-            int64_t t_offset = t * t_stride;
-            int64_t state_offset = head_size * C * (t / (T / n_seqs));
-            float * state_cur = state + state_offset;
-            float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[6]->data + state_offset;
-
-            for (int64_t h = h_start; h < h_end; h++) {
-                int64_t h_offset = h * h_stride;
-                int64_t t_h_offset = t_offset + h_offset;
-                int64_t h_2d_offset = h * h_stride_2d;
-
-                for (int64_t ii = 0; ii < head_size; ii++) {
-                    int64_t t_h_i_offset = t_h_offset + ii;
-                    int64_t h_2d_i_offset = h_2d_offset + ii * h_stride;
-
-                    GGML_F32_VEC v_vec = GGML_F32_VEC_SET1(v[t_h_i_offset]);
-
-                    float sa = 0;
-                    {
-                        GGML_F32_VEC sum[GGML_F32_ARR] = { GGML_F32_VEC_ZERO };
-                        GGML_F32_VEC ax[GGML_F32_ARR];
-                        GGML_F32_VEC ay[GGML_F32_ARR];
-                        for (int64_t j = 0; j < head_size; j += GGML_F32_STEP) {
-                            for (int64_t kk = 0; kk < GGML_F32_ARR; kk++) {
-                                ax[kk] = GGML_F32_VEC_LOAD(&a[t_h_offset + j + kk * GGML_F32_EPR]);
-                                ay[kk] = GGML_F32_VEC_LOAD(&state_prev[h_2d_i_offset + j + kk * GGML_F32_EPR]);
-                                sum[kk] = GGML_F32_VEC_FMA(sum[kk], ax[kk], ay[kk]);
-                            }
-                        }
-                        GGML_F32_VEC_REDUCE(sa, sum);
-                    }
-
-                    GGML_F32_VEC sa_vec = GGML_F32_VEC_SET1(sa);
-
-                    int64_t j = 0;
-                    GGML_F32_VEC result_vec[GGML_F32_ARR] = { GGML_F32_VEC_ZERO };
-                    for (; j < head_size; j += GGML_F32_STEP) {
-                        for (int64_t kk = 0; kk < GGML_F32_ARR; kk++) {
-                            int64_t t_h_j_offset = t_h_offset + j + kk * GGML_F32_EPR;
-                            int64_t h_2d_i_j_offset = h_2d_i_offset + j + kk * GGML_F32_EPR;
-
-                            GGML_F32_VEC r_vec = GGML_F32_VEC_LOAD(&r[t_h_j_offset]);
-                            GGML_F32_VEC w_vec = GGML_F32_VEC_LOAD(&w[t_h_j_offset]);
-                            GGML_F32_VEC k_vec = GGML_F32_VEC_LOAD(&k[t_h_j_offset]);
-                            GGML_F32_VEC b_vec = GGML_F32_VEC_LOAD(&b[t_h_j_offset]);
-
-                            k_vec = GGML_F32_VEC_MUL(v_vec, k_vec);
-
-                            GGML_F32_VEC state_vec = GGML_F32_VEC_LOAD(&state_prev[h_2d_i_j_offset]);
-                            // kv + s * decay + sa * b
-                            state_vec = GGML_F32_VEC_FMA(k_vec, state_vec, w_vec);
-                            state_vec = GGML_F32_VEC_FMA(state_vec, sa_vec, b_vec);
-                            GGML_F32_VEC_STORE(&state_cur[h_2d_i_j_offset], state_vec);
-
-                            result_vec[kk] = GGML_F32_VEC_FMA(result_vec[kk], state_vec, r_vec);
-                        }
-                    }
-                    GGML_F32_VEC_REDUCE(dst_data[t_h_i_offset], result_vec);
-
-                    // There shouldn't be left-overs though.
-                    for (; j < head_size; j++) {
-                        int64_t t_h_j_offset = t_h_offset + j;
-                        int64_t h_2d_i_j_offset = h_2d_i_offset + j;
-
-                        float r_val = r[t_h_j_offset];
-                        float w_val = w[t_h_j_offset];
-                        float k_val = k[t_h_j_offset];
-                        float b_val = b[t_h_j_offset];
-                        float kv_val = v[t_h_i_offset] * k_val;
-
-                        float prev_state_val = state_prev[h_2d_i_j_offset];
-                        state_cur[h_2d_i_j_offset] = prev_state_val * w_val + kv_val + sa * b_val;
-                        dst_data[t_h_i_offset] += state_cur[h_2d_i_j_offset] * r_val;
-                    }
-                }
-            }
-        }
-    #endif
+    for (int64_t t = 0; t < T; t++) {
+        int64_t t_offset = t * t_stride;
+        int64_t state_offset = head_size * C * (t / (T / n_seqs));
+        float * state_cur = state + state_offset;
+        float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[6]->data + state_offset;
+
+        for (int64_t h = h_start; h < h_end; h++) {
+            int64_t h_offset = h * h_stride;
+            int64_t t_h_offset = t_offset + h_offset;
+            int64_t h_2d_offset = h * h_stride_2d;
+
+            for (int64_t ii = 0; ii < head_size; ii++) {
+                int64_t t_h_i_offset = t_h_offset + ii;
+                int64_t h_2d_i_offset = h_2d_offset + ii * h_stride;
+
+                GGML_F32_VEC v_vec = GGML_F32_VEC_SET1(v[t_h_i_offset]);
+
+                float sa = 0;
+                {
+                    GGML_F32_VEC sum[GGML_F32_ARR] = { GGML_F32_VEC_ZERO };
+                    GGML_F32_VEC ax[GGML_F32_ARR];
+                    GGML_F32_VEC ay[GGML_F32_ARR];
+                    for (int64_t j = 0; j < head_size; j += GGML_F32_STEP) {
+                        for (int64_t kk = 0; kk < GGML_F32_ARR; kk++) {
+                            ax[kk] = GGML_F32_VEC_LOAD(&a[t_h_offset + j + kk * GGML_F32_EPR]);
+                            ay[kk] = GGML_F32_VEC_LOAD(&state_prev[h_2d_i_offset + j + kk * GGML_F32_EPR]);
+                            sum[kk] = GGML_F32_VEC_FMA(sum[kk], ax[kk], ay[kk]);
+                        }
+                    }
+                    GGML_F32_VEC_REDUCE(sa, sum);
+                }
+
+                GGML_F32_VEC sa_vec = GGML_F32_VEC_SET1(sa);
+
+                int64_t j = 0;
+                GGML_F32_VEC result_vec[GGML_F32_ARR] = { GGML_F32_VEC_ZERO };
+                for (; j < head_size; j += GGML_F32_STEP) {
+                    for (int64_t kk = 0; kk < GGML_F32_ARR; kk++) {
+                        int64_t t_h_j_offset = t_h_offset + j + kk * GGML_F32_EPR;
+                        int64_t h_2d_i_j_offset = h_2d_i_offset + j + kk * GGML_F32_EPR;
+
+                        GGML_F32_VEC r_vec = GGML_F32_VEC_LOAD(&r[t_h_j_offset]);
+                        GGML_F32_VEC w_vec = GGML_F32_VEC_LOAD(&w[t_h_j_offset]);
+                        GGML_F32_VEC k_vec = GGML_F32_VEC_LOAD(&k[t_h_j_offset]);
+                        GGML_F32_VEC b_vec = GGML_F32_VEC_LOAD(&b[t_h_j_offset]);
+
+                        k_vec = GGML_F32_VEC_MUL(v_vec, k_vec);
+
+                        GGML_F32_VEC state_vec = GGML_F32_VEC_LOAD(&state_prev[h_2d_i_j_offset]);
+                        // kv + s * decay + sa * b
+                        state_vec = GGML_F32_VEC_FMA(k_vec, state_vec, w_vec);
+                        state_vec = GGML_F32_VEC_FMA(state_vec, sa_vec, b_vec);
+                        GGML_F32_VEC_STORE(&state_cur[h_2d_i_j_offset], state_vec);
+
+                        result_vec[kk] = GGML_F32_VEC_FMA(result_vec[kk], state_vec, r_vec);
+                    }
+                }
+                GGML_F32_VEC_REDUCE(dst_data[t_h_i_offset], result_vec);
+
+                // There shouldn't be left-overs though.
+                for (; j < head_size; j++) {
+                    int64_t t_h_j_offset = t_h_offset + j;
+                    int64_t h_2d_i_j_offset = h_2d_i_offset + j;
+
+                    float r_val = r[t_h_j_offset];
+                    float w_val = w[t_h_j_offset];
+                    float k_val = k[t_h_j_offset];
+                    float b_val = b[t_h_j_offset];
+                    float kv_val = v[t_h_i_offset] * k_val;
+
+                    float prev_state_val = state_prev[h_2d_i_j_offset];
+                    state_cur[h_2d_i_j_offset] = prev_state_val * w_val + kv_val + sa * b_val;
+                    dst_data[t_h_i_offset] += state_cur[h_2d_i_j_offset] * r_val;
+                }
+            }
+        }
+    }
 #else
     for (int64_t t = 0; t < T; t++) {
         int64_t t_offset = t * t_stride;
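
For reference, the removed scalar route and the retained vectorized path compute the same per-row wkv7 update. A scalar sketch using the function's local names, matching the "// kv + s * decay + sa * b" comment above:

    // sa = dot(a, previous state row)
    float sa = 0.0f;
    for (int64_t j = 0; j < head_size; j++) {
        sa += a[t_h_offset + j] * state_prev[h_2d_i_offset + j];
    }
    // each state entry decays by w, absorbs the outer-product term v*k,
    // and adds sa*b; the output is the dot product of the new row with r
    float result = 0.0f;
    for (int64_t j = 0; j < head_size; j++) {
        float s_new = state_prev[h_2d_i_offset + j] * w[t_h_offset + j]
                    + v[t_h_i_offset] * k[t_h_offset + j]
                    + sa * b[t_h_offset + j];
        state_cur[h_2d_i_offset + j] = s_new;
        result += s_new * r[t_h_offset + j];
    }
    dst_data[t_h_i_offset] = result;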
