Skip to content

Commit a57fc34

Browse files
committed
Added SVE implementation for the vec_dot_fp16 kernel
1 parent 2241453 commit a57fc34

File tree

3 files changed

+381
-67
lines changed

3 files changed

+381
-67
lines changed

ggml/src/ggml-cpu/simd-mappings.h

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -220,6 +220,47 @@ inline static float ggml_lookup_fp16_to_fp32(ggml_fp16_t f) {
220220
#define GGML_F32_VEC_MUL GGML_F32xt_MUL
221221
#define GGML_F32_VEC_REDUCE GGML_F32xt_REDUCE
222222

223+
// F16 SVE
224+
#define DEFAULT_PG32 svptrue_b32()
225+
#define DEFAULT_PG16 svptrue_b16()
226+
227+
#define GGML_F32Cxt svfloat16_t
228+
#define GGML_F32Cxt_ZERO svdup_n_f16(0.0f)
229+
#define GGML_F32Cxt_SET1(x) svdup_n_f16(x)
230+
#define GGML_F32Cxt_LOAD(p) svld1_f16(DEFAULT_PG16, (const __fp16 *)(p))
231+
#define GGML_F32Cxt_STORE(dst_ptr, src_vec) svst1_f16(DEFAULT_PG16, (__fp16 *)(dst_ptr), (src_vec))
232+
233+
#define GGML_F32Cxt_FMA_IMPL(pg, a, b, c) svmad_f16_x(pg, a, b, c)
234+
#define GGML_F32Cxt_FMA(...) GGML_F32Cxt_FMA_IMPL(DEFAULT_PG16, __VA_ARGS__)
235+
#define GGML_F32Cxt_ADD_IMPL(pg, a, b) svadd_f16_x(pg, a, b)
236+
#define GGML_F32Cxt_ADD(...) GGML_F32Cxt_ADD_IMPL(DEFAULT_PG16, __VA_ARGS__)
237+
#define GGML_F32Cxt_MUL_IMPL(pg, a, b) svmul_f16_x(pg, a, b)
238+
#define GGML_F32Cxt_MUL(...) GGML_F32Cxt_MUL_IMPL(DEFAULT_PG16, __VA_ARGS__)
239+
#define GGML_F32Cxt_REDUCE GGML_F16xt_REDUCE_MIXED
240+
241+
#define GGML_F16x_VEC GGML_F32Cxt
242+
#define GGML_F16x_VEC_ZERO GGML_F32Cxt_ZERO
243+
#define GGML_F16x_VEC_SET1 GGML_F32Cxt_SET1
244+
#define GGML_F16x_VEC_LOAD(p, i) GGML_F32Cxt_LOAD(p)
245+
#define GGML_F16x_VEC_STORE(p, r, i) GGML_F32Cxt_STORE((__fp16 *)(p), r)
246+
#define GGML_F16x_VEC_FMA GGML_F32Cxt_FMA
247+
#define GGML_F16x_VEC_ADD GGML_F32Cxt_ADD
248+
#define GGML_F16x_VEC_MUL GGML_F32Cxt_MUL
249+
#define GGML_F16x_VEC_REDUCE GGML_F32Cxt_REDUCE
250+
251+
#define GGML_F16xt_REDUCE_ONE_IMPL(pg, a) svaddv_f16(pg, a)
252+
#define GGML_F16xt_REDUCE_ONE(...) GGML_F16xt_REDUCE_ONE_IMPL(DEFAULT_PG16, __VA_ARGS__)
253+
254+
#define GGML_F16xt_REDUCE_MIXED_IMPL(pg16, res, sum1, sum2, sum3, sum4) \
255+
{ \
256+
sum1 = svadd_f16_x(pg16, sum1, sum2); \
257+
sum3 = svadd_f16_x(pg16, sum3, sum4); \
258+
sum1 = svadd_f16_x(pg16, sum1, sum3); \
259+
__fp16 sum_f16 = svaddv_f16(pg16, sum1); \
260+
(res) = (ggml_float) sum_f16; \
261+
}
262+
#define GGML_F16xt_REDUCE_MIXED(...) GGML_F16xt_REDUCE_MIXED_IMPL(DEFAULT_PG16, __VA_ARGS__)
263+
223264
// F16 NEON
224265

225266
#if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC)

ggml/src/ggml-cpu/vec.cpp

Lines changed: 80 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -198,32 +198,93 @@ void ggml_vec_dot_f16(int n, float * GGML_RESTRICT s, size_t bs, ggml_fp16_t * G
198198
ggml_float sumf = 0.0;
199199

200200
#if defined(GGML_SIMD)
201-
const int np = (n & ~(GGML_F16_STEP - 1));
202-
203-
GGML_F16_VEC sum[GGML_F16_ARR] = { GGML_F16_VEC_ZERO };
201+
#if defined(__ARM_FEATURE_SVE)
202+
const int sve_register_length = svcntb() * 8;
203+
const int ggml_f16_epr = sve_register_length / 16; // running when 16
204+
const int ggml_f16_step = 8 * ggml_f16_epr; // choose 8 SVE registers
205+
206+
const int np= (n & ~(ggml_f16_step - 1));
207+
svfloat16_t sum1 = svdup_n_f16(0.0f);
208+
svfloat16_t sum2 = svdup_n_f16(0.0f);
209+
svfloat16_t sum3 = svdup_n_f16(0.0f);
210+
svfloat16_t sum4 = svdup_n_f16(0.0f);
211+
212+
svfloat16_t ax1, ax2, ax3, ax4, ax5, ax6, ax7, ax8;
213+
svfloat16_t ay1, ay2, ay3, ay4, ay5, ay6, ay7, ay8;
214+
for (int i = 0; i < np; i += ggml_f16_step) {
215+
ax1 = GGML_F16x_VEC_LOAD(x + i + 0 * ggml_f16_epr, 0);
216+
ay1 = GGML_F16x_VEC_LOAD(y + i + 0 * ggml_f16_epr, 0);
217+
sum1 = GGML_F16x_VEC_FMA(ax1, ay1, sum1);
218+
219+
ax2 = GGML_F16x_VEC_LOAD(x + i + 1 * ggml_f16_epr, 1);
220+
ay2 = GGML_F16x_VEC_LOAD(y + i + 1 * ggml_f16_epr, 1);
221+
sum2 = GGML_F16x_VEC_FMA(ax2, ay2, sum2);
222+
223+
ax3 = GGML_F16x_VEC_LOAD(x + i + 2 * ggml_f16_epr, 2);
224+
ay3 = GGML_F16x_VEC_LOAD(y + i + 2 * ggml_f16_epr, 2);
225+
sum3 = GGML_F16x_VEC_FMA(ax3, ay3, sum3);
226+
227+
ax4 = GGML_F16x_VEC_LOAD(x + i + 3 * ggml_f16_epr, 3);
228+
ay4 = GGML_F16x_VEC_LOAD(y + i + 3 * ggml_f16_epr, 3);
229+
sum4 = GGML_F16x_VEC_FMA(ax4, ay4, sum4);
230+
231+
ax5 = GGML_F16x_VEC_LOAD(x + i + 4 * ggml_f16_epr, 4);
232+
ay5 = GGML_F16x_VEC_LOAD(y + i + 4 * ggml_f16_epr, 4);
233+
sum1 = GGML_F16x_VEC_FMA(ax5, ay5, sum1);
234+
235+
ax6 = GGML_F16x_VEC_LOAD(x + i + 5 * ggml_f16_epr, 5);
236+
ay6 = GGML_F16x_VEC_LOAD(y + i + 5 * ggml_f16_epr, 5);
237+
sum2 = GGML_F16x_VEC_FMA(ax6, ay6, sum2);
238+
239+
ax7 = GGML_F16x_VEC_LOAD(x + i + 6 * ggml_f16_epr, 6);
240+
ay7 = GGML_F16x_VEC_LOAD(y + i + 6 * ggml_f16_epr, 6);
241+
sum3 = GGML_F16x_VEC_FMA(ax7, ay7, sum3);
242+
243+
ax8 = GGML_F16x_VEC_LOAD(x + i + 7 * ggml_f16_epr, 7);
244+
ay8 = GGML_F16x_VEC_LOAD(y + i + 7 * ggml_f16_epr, 7);
245+
sum4 = GGML_F16x_VEC_FMA(ax8, ay8, sum4);
246+
}
204247

205-
GGML_F16_VEC ax[GGML_F16_ARR];
206-
GGML_F16_VEC ay[GGML_F16_ARR];
248+
const int np2 = (n & ~(ggml_f16_epr - 1)); // round down to multiple of 8
249+
for (int k = np; k < np2; k += ggml_f16_epr) {
250+
svfloat16_t rx = GGML_F16x_VEC_LOAD(x + k, 0);
251+
svfloat16_t ry = GGML_F16x_VEC_LOAD(y + k, 0);
252+
sum1 = GGML_F16x_VEC_FMA(rx, ry, sum1);
253+
}
207254

208-
for (int i = 0; i < np; i += GGML_F16_STEP) {
209-
for (int j = 0; j < GGML_F16_ARR; j++) {
210-
ax[j] = GGML_F16_VEC_LOAD(x + i + j*GGML_F16_EPR, j);
211-
ay[j] = GGML_F16_VEC_LOAD(y + i + j*GGML_F16_EPR, j);
255+
if (np2 < n) {
256+
svbool_t pg = svwhilelt_b16(np2,n);
257+
svfloat16_t hx = svld1_f16(pg, (const __fp16 *)(x + np2));
258+
svfloat16_t hy = svld1_f16(pg, (const __fp16 *)(y + np2));
212259

213-
sum[j] = GGML_F16_VEC_FMA(sum[j], ax[j], ay[j]);
260+
sum1 = svmad_f16_x(pg, hx, hy, sum1);
214261
}
215-
}
262+
GGML_F16x_VEC_REDUCE(sumf, sum1, sum2, sum3, sum4);
263+
#else
264+
const int np = (n & ~(GGML_F16_STEP - 1));
216265

217-
// reduce sum0..sum3 to sum0
218-
GGML_F16_VEC_REDUCE(sumf, sum);
266+
GGML_F16_VEC sum[GGML_F16_ARR] = { GGML_F16_VEC_ZERO };
219267

220-
// leftovers
221-
for (int i = np; i < n; ++i) {
222-
sumf += (ggml_float)(GGML_CPU_FP16_TO_FP32(x[i])*GGML_CPU_FP16_TO_FP32(y[i]));
223-
}
268+
GGML_F16_VEC ax[GGML_F16_ARR];
269+
GGML_F16_VEC ay[GGML_F16_ARR];
270+
271+
for (int i = 0; i < np; i += GGML_F16_STEP) {
272+
for (int j = 0; j < GGML_F16_ARR; j++) {
273+
ax[j] = GGML_F16_VEC_LOAD(x + i + j*GGML_F16_EPR, j);
274+
ay[j] = GGML_F16_VEC_LOAD(y + i + j*GGML_F16_EPR, j);
224275

225-
// if you hit this, you are likely running outside the FP range
226-
assert(!isnan(sumf) && !isinf(sumf));
276+
sum[j] = GGML_F16_VEC_FMA(sum[j], ax[j], ay[j]);
277+
}
278+
}
279+
280+
// reduce sum0..sum3 to sum0
281+
GGML_F16_VEC_REDUCE(sumf, sum);
282+
283+
// leftovers
284+
for (int i = np; i < n; ++i) {
285+
sumf += (ggml_float)(GGML_CPU_FP16_TO_FP32(x[i])*GGML_CPU_FP16_TO_FP32(y[i]));
286+
}
287+
#endif
227288
#else
228289
for (int i = 0; i < n; ++i) {
229290
sumf += (ggml_float)(GGML_CPU_FP16_TO_FP32(x[i])*GGML_CPU_FP16_TO_FP32(y[i]));

0 commit comments

Comments
 (0)