
Commit 6ab4e50

xctang and ggerganov authored
ggml-cpu : add RISC-V Zvfh impl for ggml_vec_mad_f16 (#17448)
* ggml-cpu : add RISC-V Zvfh impl for ggml_vec_mad_f16

* ggml-cpu : dedup scalar impl

* Update ggml/src/ggml-cpu/vec.h

---------

Co-authored-by: Georgi Gerganov <[email protected]>
1 parent 2336cc4 commit 6ab4e50
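
The new RVV path is a strip-mined loop over the fp16 arrays. Below is a minimal standalone sketch of the same pattern, assuming a RISC-V toolchain with the V and Zvfh extensions enabled (e.g. -march=rv64gcv_zvfh) and <riscv_vector.h>; the function name mad_f16_zvfh is illustrative, not part of ggml:

#include <riscv_vector.h>

// y[i] += x[i]*v over n fp16 elements; vsetvl picks how many elements
// each pass covers (avl), so the tail needs no separate scalar loop.
static void mad_f16_zvfh(int n, _Float16 *y, const _Float16 *x, float v) {
    const _Float16 hv = (_Float16)v;                          // scalar multiplier in fp16
    for (int i = 0, avl; i < n; i += avl) {
        avl = __riscv_vsetvl_e16m8(n - i);                    // elements this pass
        vfloat16m8_t ax = __riscv_vle16_v_f16m8(&x[i], avl);  // load x chunk
        vfloat16m8_t ay = __riscv_vle16_v_f16m8(&y[i], avl);  // load y chunk
        ay = __riscv_vfmadd_vf_f16m8(ax, hv, ay, avl);        // ax*hv + ay
        __riscv_vse16_v_f16m8(&y[i], ay, avl);                // store back
    }
}

Because vsetvl already handles the final partial vector, this branch can report np = n and skip the shared scalar leftovers loop entirely, as the diff below shows.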

File tree

1 file changed: +84 -85 lines changed

ggml/src/ggml-cpu/vec.h

Lines changed: 84 additions & 85 deletions
@@ -397,119 +397,118 @@ inline static void ggml_vec_mad_f32(const int n, float * GGML_RESTRICT y, const
 }
 
 inline static void ggml_vec_mad_f16(const int n, ggml_fp16_t * GGML_RESTRICT y, const ggml_fp16_t * GGML_RESTRICT x, const float v) {
-#if defined(GGML_SIMD)
-    #if defined(__ARM_FEATURE_SVE)
-        const int sve_register_length = svcntb() * 8;
-        const int ggml_f16_epr = sve_register_length / 16;
-        const int ggml_f16_step = 8 * ggml_f16_epr;
+#if defined(GGML_SIMD) && defined(__ARM_FEATURE_SVE)
+    const int sve_register_length = svcntb() * 8;
+    const int ggml_f16_epr = sve_register_length / 16;
+    const int ggml_f16_step = 8 * ggml_f16_epr;
 
-        GGML_F16x_VEC vx = GGML_F16x_VEC_SET1(v);
+    GGML_F16x_VEC vx = GGML_F16x_VEC_SET1(v);
 
-        const int np= (n & ~(ggml_f16_step - 1));
+    int np = (n & ~(ggml_f16_step - 1));
 
-        svfloat16_t ax1, ax2, ax3, ax4, ax5, ax6, ax7, ax8;
-        svfloat16_t ay1, ay2, ay3, ay4, ay5, ay6, ay7, ay8;
-        for (int i = 0; i < np; i += ggml_f16_step) {
-            ax1 = GGML_F16x_VEC_LOAD(x + i + 0 * ggml_f16_epr, 0);
-            ay1 = GGML_F16x_VEC_LOAD(y + i + 0 * ggml_f16_epr, 0);
-            ay1 = GGML_F16x_VEC_FMA(ay1, ax1, vx);
+    svfloat16_t ax1, ax2, ax3, ax4, ax5, ax6, ax7, ax8;
+    svfloat16_t ay1, ay2, ay3, ay4, ay5, ay6, ay7, ay8;
+    for (int i = 0; i < np; i += ggml_f16_step) {
+        ax1 = GGML_F16x_VEC_LOAD(x + i + 0 * ggml_f16_epr, 0);
+        ay1 = GGML_F16x_VEC_LOAD(y + i + 0 * ggml_f16_epr, 0);
+        ay1 = GGML_F16x_VEC_FMA(ay1, ax1, vx);
 
-            GGML_F16x_VEC_STORE(y + i + 0 * ggml_f16_epr, ay1, 0);
+        GGML_F16x_VEC_STORE(y + i + 0 * ggml_f16_epr, ay1, 0);
 
-            ax2 = GGML_F16x_VEC_LOAD(x + i + 1 * ggml_f16_epr, 1);
-            ay2 = GGML_F16x_VEC_LOAD(y + i + 1 * ggml_f16_epr, 1);
-            ay2 = GGML_F16x_VEC_FMA(ay2, ax2, vx);
+        ax2 = GGML_F16x_VEC_LOAD(x + i + 1 * ggml_f16_epr, 1);
+        ay2 = GGML_F16x_VEC_LOAD(y + i + 1 * ggml_f16_epr, 1);
+        ay2 = GGML_F16x_VEC_FMA(ay2, ax2, vx);
 
-            GGML_F16x_VEC_STORE(y + i + 1 * ggml_f16_epr, ay2, 1);
+        GGML_F16x_VEC_STORE(y + i + 1 * ggml_f16_epr, ay2, 1);
 
-            ax3 = GGML_F16x_VEC_LOAD(x + i + 2 * ggml_f16_epr, 2);
-            ay3 = GGML_F16x_VEC_LOAD(y + i + 2 * ggml_f16_epr, 2);
-            ay3 = GGML_F16x_VEC_FMA(ay3, ax3, vx);
+        ax3 = GGML_F16x_VEC_LOAD(x + i + 2 * ggml_f16_epr, 2);
+        ay3 = GGML_F16x_VEC_LOAD(y + i + 2 * ggml_f16_epr, 2);
+        ay3 = GGML_F16x_VEC_FMA(ay3, ax3, vx);
 
-            GGML_F16x_VEC_STORE(y + i + 2 * ggml_f16_epr, ay3, 2);
+        GGML_F16x_VEC_STORE(y + i + 2 * ggml_f16_epr, ay3, 2);
 
-            ax4 = GGML_F16x_VEC_LOAD(x + i + 3 * ggml_f16_epr, 3);
-            ay4 = GGML_F16x_VEC_LOAD(y + i + 3 * ggml_f16_epr, 3);
-            ay4 = GGML_F16x_VEC_FMA(ay4, ax4, vx);
+        ax4 = GGML_F16x_VEC_LOAD(x + i + 3 * ggml_f16_epr, 3);
+        ay4 = GGML_F16x_VEC_LOAD(y + i + 3 * ggml_f16_epr, 3);
+        ay4 = GGML_F16x_VEC_FMA(ay4, ax4, vx);
 
-            GGML_F16x_VEC_STORE(y + i + 3 * ggml_f16_epr, ay4, 3);
+        GGML_F16x_VEC_STORE(y + i + 3 * ggml_f16_epr, ay4, 3);
 
-            ax5 = GGML_F16x_VEC_LOAD(x + i + 4 * ggml_f16_epr, 4);
-            ay5 = GGML_F16x_VEC_LOAD(y + i + 4 * ggml_f16_epr, 4);
-            ay5 = GGML_F16x_VEC_FMA(ay5, ax5, vx);
+        ax5 = GGML_F16x_VEC_LOAD(x + i + 4 * ggml_f16_epr, 4);
+        ay5 = GGML_F16x_VEC_LOAD(y + i + 4 * ggml_f16_epr, 4);
+        ay5 = GGML_F16x_VEC_FMA(ay5, ax5, vx);
 
-            GGML_F16x_VEC_STORE(y + i + 4 * ggml_f16_epr, ay5, 4);
+        GGML_F16x_VEC_STORE(y + i + 4 * ggml_f16_epr, ay5, 4);
 
-            ax6 = GGML_F16x_VEC_LOAD(x + i + 5 * ggml_f16_epr, 5);
-            ay6 = GGML_F16x_VEC_LOAD(y + i + 5 * ggml_f16_epr, 5);
-            ay6 = GGML_F16x_VEC_FMA(ay6, ax6, vx);
+        ax6 = GGML_F16x_VEC_LOAD(x + i + 5 * ggml_f16_epr, 5);
+        ay6 = GGML_F16x_VEC_LOAD(y + i + 5 * ggml_f16_epr, 5);
+        ay6 = GGML_F16x_VEC_FMA(ay6, ax6, vx);
 
-            GGML_F16x_VEC_STORE(y + i + 5 * ggml_f16_epr, ay6, 5);
+        GGML_F16x_VEC_STORE(y + i + 5 * ggml_f16_epr, ay6, 5);
 
-            ax7 = GGML_F16x_VEC_LOAD(x + i + 6 * ggml_f16_epr, 6);
-            ay7 = GGML_F16x_VEC_LOAD(y + i + 6 * ggml_f16_epr, 6);
-            ay7 = GGML_F16x_VEC_FMA(ay7, ax7, vx);
+        ax7 = GGML_F16x_VEC_LOAD(x + i + 6 * ggml_f16_epr, 6);
+        ay7 = GGML_F16x_VEC_LOAD(y + i + 6 * ggml_f16_epr, 6);
+        ay7 = GGML_F16x_VEC_FMA(ay7, ax7, vx);
 
-            GGML_F16x_VEC_STORE(y + i + 6 * ggml_f16_epr, ay7, 6);
+        GGML_F16x_VEC_STORE(y + i + 6 * ggml_f16_epr, ay7, 6);
 
-            ax8 = GGML_F16x_VEC_LOAD(x + i + 7 * ggml_f16_epr, 7);
-            ay8 = GGML_F16x_VEC_LOAD(y + i + 7 * ggml_f16_epr, 7);
-            ay8 = GGML_F16x_VEC_FMA(ay8, ax8, vx);
+        ax8 = GGML_F16x_VEC_LOAD(x + i + 7 * ggml_f16_epr, 7);
+        ay8 = GGML_F16x_VEC_LOAD(y + i + 7 * ggml_f16_epr, 7);
+        ay8 = GGML_F16x_VEC_FMA(ay8, ax8, vx);
 
-            GGML_F16x_VEC_STORE(y + i + 7 * ggml_f16_epr, ay8, 7);
-        }
-        const int np2 = (n & ~(ggml_f16_epr - 1));
-        for (int k = np; k < np2; k += ggml_f16_epr) {
-            svfloat16_t rx = GGML_F16x_VEC_LOAD(x + k, 0);
-            svfloat16_t ry = GGML_F16x_VEC_LOAD(y + k, 0);
-            ry = GGML_F16x_VEC_FMA(ry, rx, vx);
-
-            GGML_F16x_VEC_STORE(y + k, ry, 0);
-        }
-
-        if (np2 < n) {
-            svbool_t pg = svwhilelt_b16(np2, n);
-            svfloat16_t hx = svld1_f16(pg, (const __fp16 *)(x + np2));
-            svfloat16_t hy = svld1_f16(pg, (const __fp16 *)(y + np2));
-            hy = svmad_f16_x(pg, hx, vx, hy);
-            svst1_f16(pg, (__fp16 *)(y + np2), hy);
-        }
+        GGML_F16x_VEC_STORE(y + i + 7 * ggml_f16_epr, ay8, 7);
+    }
+    const int np2 = (n & ~(ggml_f16_epr - 1));
+    for (int k = np; k < np2; k += ggml_f16_epr) {
+        svfloat16_t rx = GGML_F16x_VEC_LOAD(x + k, 0);
+        svfloat16_t ry = GGML_F16x_VEC_LOAD(y + k, 0);
+        ry = GGML_F16x_VEC_FMA(ry, rx, vx);
 
-    #elif defined(__riscv_v_intrinsic)
-        // todo: RVV impl
-        // scalar
-        for (int i = 0; i < n; ++i) {
-            y[i] = GGML_CPU_FP32_TO_FP16(GGML_CPU_FP16_TO_FP32(y[i]) + GGML_CPU_FP16_TO_FP32(x[i])*v);
-        }
-    #else
-        const int np = (n & ~(GGML_F16_STEP - 1));
+        GGML_F16x_VEC_STORE(y + k, ry, 0);
+    }
 
-        GGML_F16_VEC vx = GGML_F16_VEC_SET1(v);
+    if (np2 < n) {
+        svbool_t pg = svwhilelt_b16(np2, n);
+        svfloat16_t hx = svld1_f16(pg, (const __fp16 *)(x + np2));
+        svfloat16_t hy = svld1_f16(pg, (const __fp16 *)(y + np2));
+        hy = svmad_f16_x(pg, hx, vx, hy);
+        svst1_f16(pg, (__fp16 *)(y + np2), hy);
+    }
+    np = n;
+#elif defined(__riscv_zvfh) // implies __riscv_v_intrinsic
+    const int np = n;
+    _Float16 hv = (_Float16)v;
+    for (int i = 0, avl; i < n; i += avl) {
+        avl = __riscv_vsetvl_e16m8(n - i);
+        vfloat16m8_t ax = __riscv_vle16_v_f16m8((const _Float16 *)&x[i], avl);
+        vfloat16m8_t ay = __riscv_vle16_v_f16m8((_Float16 *)&y[i], avl);
+        vfloat16m8_t ny = __riscv_vfmadd_vf_f16m8(ax, hv, ay, avl);
+        __riscv_vse16_v_f16m8((_Float16 *)&y[i], ny, avl);
+    }
+#elif defined(GGML_SIMD)
+    const int np = (n & ~(GGML_F16_STEP - 1));
 
-        GGML_F16_VEC ax[GGML_F16_ARR];
-        GGML_F16_VEC ay[GGML_F16_ARR];
+    GGML_F16_VEC vx = GGML_F16_VEC_SET1(v);
 
-        for (int i = 0; i < np; i += GGML_F16_STEP) {
-            for (int j = 0; j < GGML_F16_ARR; j++) {
-                ax[j] = GGML_F16_VEC_LOAD(x + i + j*GGML_F16_EPR, j);
-                ay[j] = GGML_F16_VEC_LOAD(y + i + j*GGML_F16_EPR, j);
-                ay[j] = GGML_F16_VEC_FMA(ay[j], ax[j], vx);
+    GGML_F16_VEC ax[GGML_F16_ARR];
+    GGML_F16_VEC ay[GGML_F16_ARR];
 
-                GGML_F16_VEC_STORE(y + i + j*GGML_F16_EPR, ay, j);
-            }
-        }
+    for (int i = 0; i < np; i += GGML_F16_STEP) {
+        for (int j = 0; j < GGML_F16_ARR; j++) {
+            ax[j] = GGML_F16_VEC_LOAD(x + i + j*GGML_F16_EPR, j);
+            ay[j] = GGML_F16_VEC_LOAD(y + i + j*GGML_F16_EPR, j);
+            ay[j] = GGML_F16_VEC_FMA(ay[j], ax[j], vx);
 
-        // leftovers
-        for (int i = np; i < n; ++i) {
-            y[i] = GGML_CPU_FP32_TO_FP16(GGML_CPU_FP16_TO_FP32(y[i]) + GGML_CPU_FP16_TO_FP32(x[i])*v);
+            GGML_F16_VEC_STORE(y + i + j*GGML_F16_EPR, ay, j);
         }
-    #endif
+    }
 #else
-    // scalar
-    for (int i = 0; i < n; ++i) {
+    const int np = 0;
+#endif
+
+    // leftovers
+    for (int i = np; i < n; ++i) {
         y[i] = GGML_CPU_FP32_TO_FP16(GGML_CPU_FP16_TO_FP32(y[i]) + GGML_CPU_FP16_TO_FP32(x[i])*v);
     }
-#endif
 }
 
 // xs and vs are byte strides of x and v
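
With the dedup, every branch above only computes np (the number of elements it vectorized) and the single leftovers loop finishes the rest. For cross-checking any of the branches, here is a hedged plain-C reference of the full operation, with _Float16 casts standing in for ggml's GGML_CPU_FP16_TO_FP32/GGML_CPU_FP32_TO_FP16 helpers; vec_mad_f16_ref is an illustrative name, not part of ggml:

// Reference semantics of ggml_vec_mad_f16: y[i] += x[i]*v, computed in
// f32 and rounded back to fp16, matching the scalar leftovers loop.
static void vec_mad_f16_ref(int n, _Float16 *y, const _Float16 *x, float v) {
    for (int i = 0; i < n; ++i) {
        y[i] = (_Float16)((float)y[i] + (float)x[i]*v);
    }
}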
