@@ -18,73 +18,7 @@ void ggml_vec_dot_f32(int n, float * GGML_RESTRICT s, size_t bs, const float * G
1818#if defined(GGML_SIMD)
1919 float sumf = 0 .0f ;
2020
21- #if defined(__ARM_FEATURE_SVE)
22- const int sve_register_length = ggml_cpu_get_sve_cnt () * 8 ;
23- const int ggml_f32_epr = sve_register_length / 32 ;// 8;//svcntw(); // SVE128:4, SVE256:8, SVE512:16
24- const int ggml_f32_step = 8 * ggml_f32_epr; // choose 8 SVE registers
25-
26- const int np = (n & ~(ggml_f32_step - 1 ));
27- svfloat32_t sum1 = svdup_n_f32 (0 .0f );
28- svfloat32_t sum2 = svdup_n_f32 (0 .0f );
29- svfloat32_t sum3 = svdup_n_f32 (0 .0f );
30- svfloat32_t sum4 = svdup_n_f32 (0 .0f );
31- svfloat32_t sum5 = svdup_n_f32 (0 .0f );
32- svfloat32_t sum6 = svdup_n_f32 (0 .0f );
33- svfloat32_t sum7 = svdup_n_f32 (0 .0f );
34- svfloat32_t sum8 = svdup_n_f32 (0 .0f );
35- svfloat32_t ax1,ax2,ax3,ax4,ax5,ax6,ax7,ax8;
36- svfloat32_t ay1,ay2,ay3,ay4,ay5,ay6,ay7,ay8;
37- for (int i = 0 ; i < np; i += ggml_f32_step) {
38- ax1 = GGML_F32_VEC_LOAD (x + i);
39- ay1 = GGML_F32_VEC_LOAD (y + i);
40- sum1 = GGML_F32_VEC_FMA (sum1, ax1, ay1);
41-
42- ax2 = GGML_F32_VEC_LOAD (x + i + 1 *ggml_f32_epr);
43- ay2 = GGML_F32_VEC_LOAD (y + i + 1 *ggml_f32_epr);
44- sum2 = GGML_F32_VEC_FMA (sum2, ax2, ay2);
45-
46- ax3 = GGML_F32_VEC_LOAD (x + i + 2 *ggml_f32_epr);
47- ay3 = GGML_F32_VEC_LOAD (y + i + 2 *ggml_f32_epr);
48- sum3 = GGML_F32_VEC_FMA (sum3, ax3, ay3);
49-
50- ax4 = GGML_F32_VEC_LOAD (x + i + 3 *ggml_f32_epr);
51- ay4 = GGML_F32_VEC_LOAD (y + i + 3 *ggml_f32_epr);
52- sum4 = GGML_F32_VEC_FMA (sum4, ax4, ay4);
53-
54- ax5 = GGML_F32_VEC_LOAD (x + i + 4 *ggml_f32_epr);
55- ay5 = GGML_F32_VEC_LOAD (y + i + 4 *ggml_f32_epr);
56- sum5 = GGML_F32_VEC_FMA (sum5, ax5, ay5);
57-
58- ax6 = GGML_F32_VEC_LOAD (x + i + 5 *ggml_f32_epr);
59- ay6 = GGML_F32_VEC_LOAD (y + i + 5 *ggml_f32_epr);
60- sum6 = GGML_F32_VEC_FMA (sum6, ax6, ay6);
61-
62- ax7 = GGML_F32_VEC_LOAD (x + i + 6 *ggml_f32_epr);
63- ay7 = GGML_F32_VEC_LOAD (y + i + 6 *ggml_f32_epr);
64- sum7 = GGML_F32_VEC_FMA (sum7, ax7, ay7);
65-
66- ax8 = GGML_F32_VEC_LOAD (x + i + 7 *ggml_f32_epr);
67- ay8 = GGML_F32_VEC_LOAD (y + i + 7 *ggml_f32_epr);
68- sum8 = GGML_F32_VEC_FMA (sum8, ax8, ay8);
69- }
70- // leftovers
71- // Since 8 unrolls are done in above loop, leftovers lie in range [0, ggml_f32_step] which is handled in below loop
72- const int np2 = (n & ~(ggml_f32_epr - 1 ));
73- for (int i = np; i < np2; i += ggml_f32_epr) {
74- ax1 = GGML_F32_VEC_LOAD (x + i);
75- ay1 = GGML_F32_VEC_LOAD (y + i);
76- sum1 = GGML_F32_VEC_FMA (sum1, ax1, ay1);
77- }
78- // maximum number of leftover elements will be less that ggml_f32_epr. Apply predicated svmad on available elements only
79- if (np2 < n) {
80- svbool_t pg = svwhilelt_b32 (np2, n);
81- ax1 = svld1_f32 (pg, x + np2);
82- ay1 = svld1_f32 (pg, y + np2);
83- sum1 = svmad_f32_m (pg, ax1, ay1, sum1);
84- }
85- // reduce sum1,sum2 to sum1
86- GGML_F32_VEC_REDUCE (sumf, sum1, sum2, sum3, sum4, sum5, sum6, sum7, sum8);
87- #elif defined(__riscv_v_intrinsic)
21+ #if defined(__riscv_v_intrinsic)
8822 int vl = __riscv_vsetvlmax_e32m8 ();
8923 vfloat32m1_t vs = __riscv_vfmv_v_f_f32m1 (0 .0f , 1 );
9024 vfloat32m8_t vsum;
@@ -215,69 +149,7 @@ void ggml_vec_dot_f16(int n, float * GGML_RESTRICT s, size_t bs, ggml_fp16_t * G
215149
216150
217151#if defined(GGML_SIMD)
218- #if defined(__ARM_FEATURE_SVE)
219- const int sve_register_length = svcntb () * 8 ; // get vector length
220- const int ggml_f16_epr = sve_register_length / 16 ; // running when 16
221- const int ggml_f16_step = 8 * ggml_f16_epr; // choose 8 SVE registers
222-
223- const int np= (n & ~(ggml_f16_step - 1 ));
224- svfloat16_t sum1 = svdup_n_f16 (0 .0f );
225- svfloat16_t sum2 = svdup_n_f16 (0 .0f );
226- svfloat16_t sum3 = svdup_n_f16 (0 .0f );
227- svfloat16_t sum4 = svdup_n_f16 (0 .0f );
228-
229- svfloat16_t ax1, ax2, ax3, ax4, ax5, ax6, ax7, ax8;
230- svfloat16_t ay1, ay2, ay3, ay4, ay5, ay6, ay7, ay8;
231- for (int i = 0 ; i < np; i += ggml_f16_step) {
232- ax1 = GGML_F16x_VEC_LOAD (x + i + 0 * ggml_f16_epr, 0 );
233- ay1 = GGML_F16x_VEC_LOAD (y + i + 0 * ggml_f16_epr, 0 );
234- sum1 = GGML_F16x_VEC_FMA (sum1, ax1, ay1);
235-
236- ax2 = GGML_F16x_VEC_LOAD (x + i + 1 * ggml_f16_epr, 1 );
237- ay2 = GGML_F16x_VEC_LOAD (y + i + 1 * ggml_f16_epr, 1 );
238- sum2 = GGML_F16x_VEC_FMA (sum2, ax2, ay2);
239-
240- ax3 = GGML_F16x_VEC_LOAD (x + i + 2 * ggml_f16_epr, 2 );
241- ay3 = GGML_F16x_VEC_LOAD (y + i + 2 * ggml_f16_epr, 2 );
242- sum3 = GGML_F16x_VEC_FMA (sum3, ax3, ay3);
243-
244- ax4 = GGML_F16x_VEC_LOAD (x + i + 3 * ggml_f16_epr, 3 );
245- ay4 = GGML_F16x_VEC_LOAD (y + i + 3 * ggml_f16_epr, 3 );
246- sum4 = GGML_F16x_VEC_FMA (sum4, ax4, ay4);
247-
248- ax5 = GGML_F16x_VEC_LOAD (x + i + 4 * ggml_f16_epr, 4 );
249- ay5 = GGML_F16x_VEC_LOAD (y + i + 4 * ggml_f16_epr, 4 );
250- sum1 = GGML_F16x_VEC_FMA (sum1, ax5, ay5);
251-
252- ax6 = GGML_F16x_VEC_LOAD (x + i + 5 * ggml_f16_epr, 5 );
253- ay6 = GGML_F16x_VEC_LOAD (y + i + 5 * ggml_f16_epr, 5 );
254- sum2 = GGML_F16x_VEC_FMA (sum2, ax6, ay6);
255-
256- ax7 = GGML_F16x_VEC_LOAD (x + i + 6 * ggml_f16_epr, 6 );
257- ay7 = GGML_F16x_VEC_LOAD (y + i + 6 * ggml_f16_epr, 6 );
258- sum3 = GGML_F16x_VEC_FMA (sum3, ax7, ay7);
259-
260- ax8 = GGML_F16x_VEC_LOAD (x + i + 7 * ggml_f16_epr, 7 );
261- ay8 = GGML_F16x_VEC_LOAD (y + i + 7 * ggml_f16_epr, 7 );
262- sum4 = GGML_F16x_VEC_FMA (sum4, ax8, ay8);
263- }
264-
265- const int np2 = (n & ~(ggml_f16_epr - 1 )); // round down to multiple of 8
266- for (int k = np; k < np2; k += ggml_f16_epr) {
267- svfloat16_t rx = GGML_F16x_VEC_LOAD (x + k, 0 );
268- svfloat16_t ry = GGML_F16x_VEC_LOAD (y + k, 0 );
269- sum1 = GGML_F16x_VEC_FMA (sum1, rx, ry);
270- }
271-
272- if (np2 < n) {
273- svbool_t pg = svwhilelt_b16 (np2, n);
274- svfloat16_t hx = svld1_f16 (pg, (const __fp16 *)(x + np2));
275- svfloat16_t hy = svld1_f16 (pg, (const __fp16 *)(y + np2));
276-
277- sum1 = svmad_f16_x (pg, hx, hy, sum1);
278- }
279- GGML_F16x_VEC_REDUCE (sumf, sum1, sum2, sum3, sum4);
280- #elif defined(__riscv_v_intrinsic)
152+ #if defined(__riscv_v_intrinsic)
281153 #if defined(__riscv_zvfh)
282154 int vl = __riscv_vsetvlmax_e32m2 ();
283155 vfloat32m1_t vs = __riscv_vfmv_v_f_f32m1 (0 .0f , 1 );
0 commit comments