Skip to content

Commit c0344b8

Browse files
committed
F32-vec-SVE
1 parent 64bf15f commit c0344b8

File tree

1 file changed

+141
-70
lines changed

1 file changed

+141
-70
lines changed

ggml/src/ggml-cpu/ops.cpp

Lines changed: 141 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -7963,6 +7963,14 @@ static void ggml_compute_forward_rwkv_wkv6_f32(
79637963
#define GGML_F32X_MUL GGML_F32x16_MUL
79647964
#define GGML_F32X_FMA GGML_F32x16_FMA
79657965
#define WKV_VECTOR_SIZE 16
7966+
#elif defined(__ARM_FEATURE_SVE) && defined(__aarch64__)
7967+
#define GGML_F32X GGML_F32xt
7968+
#define GGML_F32X_SET1 GGML_F32xt_SET1
7969+
#define GGML_F32X_LOAD GGML_F32xt_LOAD
7970+
#define GGML_F32X_STORE GGML_F32xt_STORE
7971+
#define GGML_F32X_MUL GGML_F32xt_MUL
7972+
#define GGML_F32X_FMA GGML_F32xt_FMA
7973+
#define WKV_VECTOR_SIZE 8
79667974
#elif defined(__ARM_NEON) && defined(__aarch64__)
79677975
#define GGML_F32X GGML_F32x4
79687976
#define GGML_F32X_SET1 GGML_F32x4_SET1
@@ -7973,8 +7981,14 @@ static void ggml_compute_forward_rwkv_wkv6_f32(
79737981
#define WKV_VECTOR_SIZE 4
79747982
#endif
79757983

7984+
int wkv_vector_size;
79767985
#ifdef WKV_VECTOR_SIZE
7977-
const int64_t vec_count = head_size / WKV_VECTOR_SIZE;
7986+
#if defined(__ARM_FEATURE_SVE)
7987+
wkv_vector_size = svcntw();
7988+
#else
7989+
wkv_vector_size = WKV_VECTOR_SIZE;
7990+
#endif
7991+
const int64_t vec_count = head_size / wkv_vector_size;
79787992

79797993
for (int64_t t = 0; t < T; t++) {
79807994
size_t t_offset = t * t_stride;
@@ -8004,7 +8018,7 @@ static void ggml_compute_forward_rwkv_wkv6_f32(
80048018
GGML_F32X time_decay_vec = GGML_F32X_SET1(time_decay_val);
80058019

80068020
for (int64_t j = 0; j < vec_count; j++) {
8007-
size_t base_j = j * WKV_VECTOR_SIZE;
8021+
size_t base_j = j * wkv_vector_size;
80088022
size_t t_h_j_offset = t_h_offset + base_j;
80098023
size_t h_2d_i_j_offset = h_2d_i_offset + base_j;
80108024

@@ -8029,7 +8043,7 @@ static void ggml_compute_forward_rwkv_wkv6_f32(
80298043
}
80308044

80318045
// Handle remaining elements, this will not be used.
8032-
for (int64_t j = vec_count * WKV_VECTOR_SIZE; j < head_size; j++) {
8046+
for (int64_t j = vec_count * wkv_vector_size; j < head_size; j++) {
80338047
size_t t_h_j_offset = t_h_offset + j;
80348048
size_t h_2d_i_j_offset = h_2d_i_offset + j;
80358049
float v_val = v[t_h_j_offset];
@@ -8165,6 +8179,14 @@ static void ggml_compute_forward_gla_f32(
81658179
#define GGML_F32X_MUL GGML_F32x16_MUL
81668180
#define GGML_F32X_FMA GGML_F32x16_FMA
81678181
#define GLA_VECTOR_SIZE 16
8182+
#elif defined(__ARM_FEATURE_SVE) && defined(__aarch64__)
8183+
#define GGML_F32X GGML_F32xt
8184+
#define GGML_F32X_SET1 GGML_F32xt_SET1
8185+
#define GGML_F32X_LOAD GGML_F32xt_LOAD
8186+
#define GGML_F32X_STORE GGML_F32xt_STORE
8187+
#define GGML_F32X_MUL GGML_F32xt_MUL
8188+
#define GGML_F32X_FMA GGML_F32xt_FMA
8189+
#define GLA_VECTOR_SIZE 8
81688190
#elif defined(__ARM_NEON) && defined(__aarch64__)
81698191
#define GGML_F32X GGML_F32x4
81708192
#define GGML_F32X_SET1 GGML_F32x4_SET1
@@ -8175,8 +8197,14 @@ static void ggml_compute_forward_gla_f32(
81758197
#define GLA_VECTOR_SIZE 4
81768198
#endif
81778199

8200+
int gla_vector_size;
81788201
#ifdef GLA_VECTOR_SIZE
8179-
const int64_t vec_count = head_size / GLA_VECTOR_SIZE;
8202+
#if defined(__ARM_FEATURE_SVE)
8203+
gla_vector_size = svcntw();
8204+
#else
8205+
gla_vector_size = GLA_VECTOR_SIZE;
8206+
#endif
8207+
const int64_t vec_count = head_size / gla_vector_size;
81808208

81818209
for (int64_t t = 0; t < T; t++) {
81828210
size_t t_offset = t * t_stride;
@@ -8203,7 +8231,7 @@ static void ggml_compute_forward_gla_f32(
82038231
GGML_F32X g_vec = GGML_F32X_SET1(g_val);
82048232

82058233
for (int64_t j = 0; j < vec_count; j++) {
8206-
size_t base_j = j * GLA_VECTOR_SIZE;
8234+
size_t base_j = j * gla_vector_size;
82078235
size_t t_h_j_offset = t_h_offset + base_j;
82088236
size_t h_2d_i_j_offset = h_2d_i_offset + base_j;
82098237

@@ -8227,7 +8255,7 @@ static void ggml_compute_forward_gla_f32(
82278255
}
82288256

82298257
// Handle remaining elements, this will not be used.
8230-
for (int64_t j = vec_count * GLA_VECTOR_SIZE; j < head_size; j++) {
8258+
for (int64_t j = vec_count * gla_vector_size; j < head_size; j++) {
82318259
size_t t_h_j_offset = t_h_offset + j;
82328260
size_t h_2d_i_j_offset = h_2d_i_offset + j;
82338261
float v_val = v[t_h_j_offset];
@@ -8336,83 +8364,126 @@ static void ggml_compute_forward_rwkv_wkv7_f32(
83368364
int64_t h_stride_2d = head_size * head_size;
83378365

83388366
#if defined(GGML_SIMD)
8339-
for (int64_t t = 0; t < T; t++) {
8340-
int64_t t_offset = t * t_stride;
8341-
int64_t state_offset = head_size * C * (t / (T / n_seqs));
8342-
float * state_cur = state + state_offset;
8343-
float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[6]->data + state_offset;
8344-
8345-
for (int64_t h = h_start; h < h_end; h++) {
8346-
int64_t h_offset = h * h_stride;
8347-
int64_t t_h_offset = t_offset + h_offset;
8348-
int64_t h_2d_offset = h * h_stride_2d;
8349-
8350-
for (int64_t ii = 0; ii < head_size; ii++) {
8351-
int64_t t_h_i_offset = t_h_offset + ii;
8352-
int64_t h_2d_i_offset = h_2d_offset + ii * h_stride;
8353-
8354-
GGML_F32_VEC v_vec = GGML_F32_VEC_SET1(v[t_h_i_offset]);
8367+
#if defined(__ARM_FEATURE_SVE)
8368+
// scalar Route to scalar implementation //TODO: Write SVE code
8369+
for (int64_t t = 0; t < T; t++) {
8370+
int64_t t_offset = t * t_stride;
8371+
int64_t state_offset = head_size * C * (t / (T / n_seqs));
8372+
float * state_cur = state + state_offset;
8373+
float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[6]->data + state_offset;
8374+
8375+
for (int64_t h = h_start; h < h_end; h++) {
8376+
int64_t h_offset = h * h_stride;
8377+
int64_t t_h_offset = t_offset + h_offset;
8378+
int64_t h_2d_offset = h * h_stride_2d;
8379+
8380+
for (int64_t i = 0; i < head_size; i++) {
8381+
int64_t t_h_i_offset = t_h_offset + i;
8382+
int64_t h_2d_i_offset = h_2d_offset + i * h_stride;
8383+
8384+
float v_val = v[t_h_i_offset];
8385+
8386+
float sa = 0, result = 0;
8387+
for (int64_t j = 0; j < head_size; j++) {
8388+
sa += a[t_h_offset + j] * state_prev[h_2d_i_offset + j];
8389+
}
83558390

8356-
float sa = 0;
8357-
{
8358-
GGML_F32_VEC sum[GGML_F32_ARR] = { GGML_F32_VEC_ZERO };
8359-
GGML_F32_VEC ax[GGML_F32_ARR];
8360-
GGML_F32_VEC ay[GGML_F32_ARR];
8361-
for (int64_t j = 0; j < head_size; j += GGML_F32_STEP) {
8362-
for (int64_t kk = 0; kk < GGML_F32_ARR; kk++) {
8363-
ax[kk] = GGML_F32_VEC_LOAD(&a[t_h_offset + j + kk * GGML_F32_EPR]);
8364-
ay[kk] = GGML_F32_VEC_LOAD(&state_prev[h_2d_i_offset + j + kk * GGML_F32_EPR]);
8365-
sum[kk] = GGML_F32_VEC_FMA(sum[kk], ax[kk], ay[kk]);
8366-
}
8391+
for (int64_t j = 0; j < head_size; j++) {
8392+
int64_t t_h_j_offset = t_h_offset + j;
8393+
int64_t h_2d_i_j_offset = h_2d_i_offset + j;
8394+
8395+
float r_val = r[t_h_j_offset];
8396+
float w_val = w[t_h_j_offset];
8397+
float k_val = k[t_h_j_offset];
8398+
float b_val = b[t_h_j_offset];
8399+
float kv_val = v_val * k_val;
8400+
float prev_state_val = state_prev[h_2d_i_j_offset];
8401+
state_cur[h_2d_i_j_offset] = prev_state_val * w_val + kv_val + sa * b_val;
8402+
result += state_cur[h_2d_i_j_offset] * r_val;
83678403
}
8368-
GGML_F32_VEC_REDUCE(sa, sum);
8404+
dst_data[t_h_i_offset] = result;
83698405
}
8406+
}
8407+
}
8408+
#else
8409+
for (int64_t t = 0; t < T; t++) {
8410+
int64_t t_offset = t * t_stride;
8411+
int64_t state_offset = head_size * C * (t / (T / n_seqs));
8412+
float * state_cur = state + state_offset;
8413+
float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[6]->data + state_offset;
8414+
8415+
for (int64_t h = h_start; h < h_end; h++) {
8416+
int64_t h_offset = h * h_stride;
8417+
int64_t t_h_offset = t_offset + h_offset;
8418+
int64_t h_2d_offset = h * h_stride_2d;
8419+
8420+
for (int64_t ii = 0; ii < head_size; ii++) {
8421+
int64_t t_h_i_offset = t_h_offset + ii;
8422+
int64_t h_2d_i_offset = h_2d_offset + ii * h_stride;
8423+
8424+
GGML_F32_VEC v_vec = GGML_F32_VEC_SET1(v[t_h_i_offset]);
8425+
8426+
float sa = 0;
8427+
{
8428+
GGML_F32_VEC sum[GGML_F32_ARR] = { GGML_F32_VEC_ZERO };
8429+
GGML_F32_VEC ax[GGML_F32_ARR];
8430+
GGML_F32_VEC ay[GGML_F32_ARR];
8431+
for (int64_t j = 0; j < head_size; j += GGML_F32_STEP) {
8432+
for (int64_t kk = 0; kk < GGML_F32_ARR; kk++) {
8433+
ax[kk] = GGML_F32_VEC_LOAD(&a[t_h_offset + j + kk * GGML_F32_EPR]);
8434+
ay[kk] = GGML_F32_VEC_LOAD(&state_prev[h_2d_i_offset + j + kk * GGML_F32_EPR]);
8435+
sum[kk] = GGML_F32_VEC_FMA(sum[kk], ax[kk], ay[kk]);
8436+
}
8437+
}
8438+
GGML_F32_VEC_REDUCE(sa, sum);
8439+
}
83708440

8371-
GGML_F32_VEC sa_vec = GGML_F32_VEC_SET1(sa);
8441+
GGML_F32_VEC sa_vec = GGML_F32_VEC_SET1(sa);
83728442

8373-
int64_t j = 0;
8374-
GGML_F32_VEC result_vec[GGML_F32_ARR] = { GGML_F32_VEC_ZERO };
8375-
for (; j < head_size; j += GGML_F32_STEP) {
8376-
for (int64_t kk = 0; kk < GGML_F32_ARR; kk++) {
8377-
int64_t t_h_j_offset = t_h_offset + j + kk * GGML_F32_EPR;
8378-
int64_t h_2d_i_j_offset = h_2d_i_offset + j + kk * GGML_F32_EPR;
8443+
int64_t j = 0;
8444+
GGML_F32_VEC result_vec[GGML_F32_ARR] = { GGML_F32_VEC_ZERO };
8445+
for (; j < head_size; j += GGML_F32_STEP) {
8446+
for (int64_t kk = 0; kk < GGML_F32_ARR; kk++) {
8447+
int64_t t_h_j_offset = t_h_offset + j + kk * GGML_F32_EPR;
8448+
int64_t h_2d_i_j_offset = h_2d_i_offset + j + kk * GGML_F32_EPR;
83798449

8380-
GGML_F32_VEC r_vec = GGML_F32_VEC_LOAD(&r[t_h_j_offset]);
8381-
GGML_F32_VEC w_vec = GGML_F32_VEC_LOAD(&w[t_h_j_offset]);
8382-
GGML_F32_VEC k_vec = GGML_F32_VEC_LOAD(&k[t_h_j_offset]);
8383-
GGML_F32_VEC b_vec = GGML_F32_VEC_LOAD(&b[t_h_j_offset]);
8450+
GGML_F32_VEC r_vec = GGML_F32_VEC_LOAD(&r[t_h_j_offset]);
8451+
GGML_F32_VEC w_vec = GGML_F32_VEC_LOAD(&w[t_h_j_offset]);
8452+
GGML_F32_VEC k_vec = GGML_F32_VEC_LOAD(&k[t_h_j_offset]);
8453+
GGML_F32_VEC b_vec = GGML_F32_VEC_LOAD(&b[t_h_j_offset]);
83848454

8385-
k_vec = GGML_F32_VEC_MUL(v_vec, k_vec);
8455+
k_vec = GGML_F32_VEC_MUL(v_vec, k_vec);
83868456

8387-
GGML_F32_VEC state_vec = GGML_F32_VEC_LOAD(&state_prev[h_2d_i_j_offset]);
8388-
// kv + s * decay + sa * b
8389-
state_vec = GGML_F32_VEC_FMA(k_vec, state_vec, w_vec);
8390-
state_vec = GGML_F32_VEC_FMA(state_vec, sa_vec, b_vec);
8391-
GGML_F32_VEC_STORE(&state_cur[h_2d_i_j_offset], state_vec);
8457+
GGML_F32_VEC state_vec = GGML_F32_VEC_LOAD(&state_prev[h_2d_i_j_offset]);
8458+
// kv + s * decay + sa * b
8459+
state_vec = GGML_F32_VEC_FMA(k_vec, state_vec, w_vec);
8460+
state_vec = GGML_F32_VEC_FMA(state_vec, sa_vec, b_vec);
8461+
GGML_F32_VEC_STORE(&state_cur[h_2d_i_j_offset], state_vec);
83928462

8393-
result_vec[kk] = GGML_F32_VEC_FMA(result_vec[kk], state_vec, r_vec);
8463+
result_vec[kk] = GGML_F32_VEC_FMA(result_vec[kk], state_vec, r_vec);
8464+
}
8465+
}
8466+
GGML_F32_VEC_REDUCE(dst_data[t_h_i_offset], result_vec);
8467+
8468+
// There shouldn't be left-overs though.
8469+
for (; j < head_size; j++) {
8470+
int64_t t_h_j_offset = t_h_offset + j;
8471+
int64_t h_2d_i_j_offset = h_2d_i_offset + j;
8472+
8473+
float r_val = r[t_h_j_offset];
8474+
float w_val = w[t_h_j_offset];
8475+
float k_val = k[t_h_j_offset];
8476+
float b_val = b[t_h_j_offset];
8477+
float kv_val = v[t_h_i_offset] * k_val;
8478+
8479+
float prev_state_val = state_prev[h_2d_i_j_offset];
8480+
state_cur[h_2d_i_j_offset] = prev_state_val * w_val + kv_val + sa * b_val;
8481+
dst_data[t_h_i_offset] += state_cur[h_2d_i_j_offset] * r_val;
83948482
}
8395-
}
8396-
GGML_F32_VEC_REDUCE(dst_data[t_h_i_offset], result_vec);
8397-
8398-
// There shouldn't be left-overs though.
8399-
for (; j < head_size; j++) {
8400-
int64_t t_h_j_offset = t_h_offset + j;
8401-
int64_t h_2d_i_j_offset = h_2d_i_offset + j;
8402-
8403-
float r_val = r[t_h_j_offset];
8404-
float w_val = w[t_h_j_offset];
8405-
float k_val = k[t_h_j_offset];
8406-
float b_val = b[t_h_j_offset];
8407-
float kv_val = v[t_h_i_offset] * k_val;
8408-
8409-
float prev_state_val = state_prev[h_2d_i_j_offset];
8410-
state_cur[h_2d_i_j_offset] = prev_state_val * w_val + kv_val + sa * b_val;
8411-
dst_data[t_h_i_offset] += state_cur[h_2d_i_j_offset] * r_val;
84128483
}
84138484
}
84148485
}
8415-
}
8486+
#endif
84168487
#else
84178488
for (int64_t t = 0; t < T; t++) {
84188489
int64_t t_offset = t * t_stride;

0 commit comments

Comments
 (0)