Skip to content

Commit a0c2b20

Browse files
Vithulep and pvname authored
ggml: aarch64: Implement SVE F16 kernels for vector functions (ggml-org#15115)
* Added sve implementation for vec_dot_fp16 Kernel * removed white spaces * Added comment * removed white spaces * changed GGML_F16x_VEC_FMA for code consistency * Update vec.h --------- Co-authored-by: vithulep <[email protected]>
1 parent 4b20d8b commit a0c2b20

File tree

3 files changed

+403
-91
lines changed

3 files changed

+403
-91
lines changed

ggml/src/ggml-cpu/simd-mappings.h

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -215,6 +215,47 @@ inline static float ggml_lookup_fp16_to_fp32(ggml_fp16_t f) {
215215
#define GGML_F32_VEC_MUL GGML_F32xt_MUL
216216
#define GGML_F32_VEC_REDUCE GGML_F32xt_REDUCE
217217

218+
// F16 SVE
219+
#define DEFAULT_PG32 svptrue_b32()
220+
#define DEFAULT_PG16 svptrue_b16()
221+
222+
#define GGML_F32Cxt svfloat16_t
223+
#define GGML_F32Cxt_ZERO svdup_n_f16(0.0f)
224+
#define GGML_F32Cxt_SET1(x) svdup_n_f16(x)
225+
#define GGML_F32Cxt_LOAD(p) svld1_f16(DEFAULT_PG16, (const __fp16 *)(p))
226+
#define GGML_F32Cxt_STORE(dst_ptr, src_vec) svst1_f16(DEFAULT_PG16, (__fp16 *)(dst_ptr), (src_vec))
227+
228+
#define GGML_F32Cxt_FMA_IMPL(pg, a, b, c) svmad_f16_x(pg, b, c, a)
229+
#define GGML_F32Cxt_FMA(...) GGML_F32Cxt_FMA_IMPL(DEFAULT_PG16, __VA_ARGS__)
230+
#define GGML_F32Cxt_ADD_IMPL(pg, a, b) svadd_f16_x(pg, a, b)
231+
#define GGML_F32Cxt_ADD(...) GGML_F32Cxt_ADD_IMPL(DEFAULT_PG16, __VA_ARGS__)
232+
#define GGML_F32Cxt_MUL_IMPL(pg, a, b) svmul_f16_x(pg, a, b)
233+
#define GGML_F32Cxt_MUL(...) GGML_F32Cxt_MUL_IMPL(DEFAULT_PG16, __VA_ARGS__)
234+
#define GGML_F32Cxt_REDUCE GGML_F16xt_REDUCE_MIXED
235+
236+
#define GGML_F16x_VEC GGML_F32Cxt
237+
#define GGML_F16x_VEC_ZERO GGML_F32Cxt_ZERO
238+
#define GGML_F16x_VEC_SET1 GGML_F32Cxt_SET1
239+
#define GGML_F16x_VEC_LOAD(p, i) GGML_F32Cxt_LOAD(p)
240+
#define GGML_F16x_VEC_STORE(p, r, i) GGML_F32Cxt_STORE((__fp16 *)(p), r)
241+
#define GGML_F16x_VEC_FMA GGML_F32Cxt_FMA
242+
#define GGML_F16x_VEC_ADD GGML_F32Cxt_ADD
243+
#define GGML_F16x_VEC_MUL GGML_F32Cxt_MUL
244+
#define GGML_F16x_VEC_REDUCE GGML_F32Cxt_REDUCE
245+
246+
#define GGML_F16xt_REDUCE_ONE_IMPL(pg, a) svaddv_f16(pg, a)
247+
#define GGML_F16xt_REDUCE_ONE(...) GGML_F16xt_REDUCE_ONE_IMPL(DEFAULT_PG16, __VA_ARGS__)
248+
249+
#define GGML_F16xt_REDUCE_MIXED_IMPL(pg16, res, sum1, sum2, sum3, sum4) \
250+
{ \
251+
sum1 = svadd_f16_x(pg16, sum1, sum2); \
252+
sum3 = svadd_f16_x(pg16, sum3, sum4); \
253+
sum1 = svadd_f16_x(pg16, sum1, sum3); \
254+
__fp16 sum_f16 = svaddv_f16(pg16, sum1); \
255+
(res) = (ggml_float) sum_f16; \
256+
}
257+
#define GGML_F16xt_REDUCE_MIXED(...) GGML_F16xt_REDUCE_MIXED_IMPL(DEFAULT_PG16, __VA_ARGS__)
258+
218259
// F16 NEON
219260

220261
#if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC)

ggml/src/ggml-cpu/vec.cpp

Lines changed: 83 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -207,33 +207,97 @@ void ggml_vec_dot_f16(int n, float * GGML_RESTRICT s, size_t bs, ggml_fp16_t * G
207207

208208
ggml_float sumf = 0.0;
209209

210-
#if defined(GGML_SIMD) && !defined(__riscv_v_intrinsic)
211-
const int np = (n & ~(GGML_F16_STEP - 1));
212210

213-
GGML_F16_VEC sum[GGML_F16_ARR] = { GGML_F16_VEC_ZERO };
211+
#if defined(GGML_SIMD) && !defined(__riscv_v_intrinsic)
212+
#if defined(__ARM_FEATURE_SVE)
213+
const int sve_register_length = svcntb() * 8; //get vector length
214+
const int ggml_f16_epr = sve_register_length / 16; // running when 16
215+
const int ggml_f16_step = 8 * ggml_f16_epr; // choose 8 SVE registers
216+
217+
const int np= (n & ~(ggml_f16_step - 1));
218+
svfloat16_t sum1 = svdup_n_f16(0.0f);
219+
svfloat16_t sum2 = svdup_n_f16(0.0f);
220+
svfloat16_t sum3 = svdup_n_f16(0.0f);
221+
svfloat16_t sum4 = svdup_n_f16(0.0f);
222+
223+
svfloat16_t ax1, ax2, ax3, ax4, ax5, ax6, ax7, ax8;
224+
svfloat16_t ay1, ay2, ay3, ay4, ay5, ay6, ay7, ay8;
225+
for (int i = 0; i < np; i += ggml_f16_step) {
226+
ax1 = GGML_F16x_VEC_LOAD(x + i + 0 * ggml_f16_epr, 0);
227+
ay1 = GGML_F16x_VEC_LOAD(y + i + 0 * ggml_f16_epr, 0);
228+
sum1 = GGML_F16x_VEC_FMA(sum1, ax1, ay1);
229+
230+
ax2 = GGML_F16x_VEC_LOAD(x + i + 1 * ggml_f16_epr, 1);
231+
ay2 = GGML_F16x_VEC_LOAD(y + i + 1 * ggml_f16_epr, 1);
232+
sum2 = GGML_F16x_VEC_FMA(sum2, ax2, ay2);
233+
234+
ax3 = GGML_F16x_VEC_LOAD(x + i + 2 * ggml_f16_epr, 2);
235+
ay3 = GGML_F16x_VEC_LOAD(y + i + 2 * ggml_f16_epr, 2);
236+
sum3 = GGML_F16x_VEC_FMA(sum3, ax3, ay3);
237+
238+
ax4 = GGML_F16x_VEC_LOAD(x + i + 3 * ggml_f16_epr, 3);
239+
ay4 = GGML_F16x_VEC_LOAD(y + i + 3 * ggml_f16_epr, 3);
240+
sum4 = GGML_F16x_VEC_FMA(sum4, ax4, ay4);
241+
242+
ax5 = GGML_F16x_VEC_LOAD(x + i + 4 * ggml_f16_epr, 4);
243+
ay5 = GGML_F16x_VEC_LOAD(y + i + 4 * ggml_f16_epr, 4);
244+
sum1 = GGML_F16x_VEC_FMA(sum1, ax5, ay5);
245+
246+
ax6 = GGML_F16x_VEC_LOAD(x + i + 5 * ggml_f16_epr, 5);
247+
ay6 = GGML_F16x_VEC_LOAD(y + i + 5 * ggml_f16_epr, 5);
248+
sum2 = GGML_F16x_VEC_FMA(sum2, ax6, ay6);
249+
250+
ax7 = GGML_F16x_VEC_LOAD(x + i + 6 * ggml_f16_epr, 6);
251+
ay7 = GGML_F16x_VEC_LOAD(y + i + 6 * ggml_f16_epr, 6);
252+
sum3 = GGML_F16x_VEC_FMA(sum3, ax7, ay7);
253+
254+
ax8 = GGML_F16x_VEC_LOAD(x + i + 7 * ggml_f16_epr, 7);
255+
ay8 = GGML_F16x_VEC_LOAD(y + i + 7 * ggml_f16_epr, 7);
256+
sum4 = GGML_F16x_VEC_FMA(sum4, ax8, ay8);
257+
}
214258

215-
GGML_F16_VEC ax[GGML_F16_ARR];
216-
GGML_F16_VEC ay[GGML_F16_ARR];
259+
const int np2 = (n & ~(ggml_f16_epr - 1)); // round down to multiple of 8
260+
for (int k = np; k < np2; k += ggml_f16_epr) {
261+
svfloat16_t rx = GGML_F16x_VEC_LOAD(x + k, 0);
262+
svfloat16_t ry = GGML_F16x_VEC_LOAD(y + k, 0);
263+
sum1 = GGML_F16x_VEC_FMA(sum1, rx, ry);
264+
}
217265

218-
for (int i = 0; i < np; i += GGML_F16_STEP) {
219-
for (int j = 0; j < GGML_F16_ARR; j++) {
220-
ax[j] = GGML_F16_VEC_LOAD(x + i + j*GGML_F16_EPR, j);
221-
ay[j] = GGML_F16_VEC_LOAD(y + i + j*GGML_F16_EPR, j);
266+
if (np2 < n) {
267+
svbool_t pg = svwhilelt_b16(np2, n);
268+
svfloat16_t hx = svld1_f16(pg, (const __fp16 *)(x + np2));
269+
svfloat16_t hy = svld1_f16(pg, (const __fp16 *)(y + np2));
222270

223-
sum[j] = GGML_F16_VEC_FMA(sum[j], ax[j], ay[j]);
271+
sum1 = svmad_f16_x(pg, hx, hy, sum1);
224272
}
225-
}
273+
GGML_F16x_VEC_REDUCE(sumf, sum1, sum2, sum3, sum4);
274+
#else
275+
const int np = (n & ~(GGML_F16_STEP - 1));
226276

227-
// reduce sum0..sum3 to sum0
228-
GGML_F16_VEC_REDUCE(sumf, sum);
277+
GGML_F16_VEC sum[GGML_F16_ARR] = { GGML_F16_VEC_ZERO };
229278

230-
// leftovers
231-
for (int i = np; i < n; ++i) {
232-
sumf += (ggml_float)(GGML_CPU_FP16_TO_FP32(x[i])*GGML_CPU_FP16_TO_FP32(y[i]));
233-
}
279+
GGML_F16_VEC ax[GGML_F16_ARR];
280+
GGML_F16_VEC ay[GGML_F16_ARR];
234281

235-
// if you hit this, you are likely running outside the FP range
236-
assert(!isnan(sumf) && !isinf(sumf));
282+
for (int i = 0; i < np; i += GGML_F16_STEP) {
283+
for (int j = 0; j < GGML_F16_ARR; j++) {
284+
ax[j] = GGML_F16_VEC_LOAD(x + i + j*GGML_F16_EPR, j);
285+
ay[j] = GGML_F16_VEC_LOAD(y + i + j*GGML_F16_EPR, j);
286+
287+
sum[j] = GGML_F16_VEC_FMA(sum[j], ax[j], ay[j]);
288+
}
289+
}
290+
291+
// reduce sum0..sum3 to sum0
292+
GGML_F16_VEC_REDUCE(sumf, sum);
293+
294+
// leftovers
295+
for (int i = np; i < n; ++i) {
296+
sumf += (ggml_float)(GGML_CPU_FP16_TO_FP32(x[i])*GGML_CPU_FP16_TO_FP32(y[i]));
297+
}
298+
// if you hit this, you are likely running outside the FP range
299+
assert(!isnan(sumf) && !isinf(sumf));
300+
#endif
237301
#else
238302
for (int i = 0; i < n; ++i) {
239303
sumf += (ggml_float)(GGML_CPU_FP16_TO_FP32(x[i])*GGML_CPU_FP16_TO_FP32(y[i]));

0 commit comments

Comments (0)