@@ -198,32 +198,93 @@ void ggml_vec_dot_f16(int n, float * GGML_RESTRICT s, size_t bs, ggml_fp16_t * G
198198 ggml_float sumf = 0.0 ;
199199
200200#if defined(GGML_SIMD)
201- const int np = (n & ~(GGML_F16_STEP - 1 ));
202-
203- GGML_F16_VEC sum[GGML_F16_ARR] = { GGML_F16_VEC_ZERO };
201+ #if defined(__ARM_FEATURE_SVE)
202+ const int sve_register_length = svcntb () * 8 ;
203+ const int ggml_f16_epr = sve_register_length / 16 ; // running when 16
204+ const int ggml_f16_step = 8 * ggml_f16_epr; // choose 8 SVE registers
205+
206+ const int np= (n & ~(ggml_f16_step - 1 ));
207+ svfloat16_t sum1 = svdup_n_f16 (0 .0f );
208+ svfloat16_t sum2 = svdup_n_f16 (0 .0f );
209+ svfloat16_t sum3 = svdup_n_f16 (0 .0f );
210+ svfloat16_t sum4 = svdup_n_f16 (0 .0f );
211+
212+ svfloat16_t ax1, ax2, ax3, ax4, ax5, ax6, ax7, ax8;
213+ svfloat16_t ay1, ay2, ay3, ay4, ay5, ay6, ay7, ay8;
214+ for (int i = 0 ; i < np; i += ggml_f16_step) {
215+ ax1 = GGML_F16x_VEC_LOAD (x + i + 0 * ggml_f16_epr, 0 );
216+ ay1 = GGML_F16x_VEC_LOAD (y + i + 0 * ggml_f16_epr, 0 );
217+ sum1 = GGML_F16x_VEC_FMA (ax1, ay1, sum1);
218+
219+ ax2 = GGML_F16x_VEC_LOAD (x + i + 1 * ggml_f16_epr, 1 );
220+ ay2 = GGML_F16x_VEC_LOAD (y + i + 1 * ggml_f16_epr, 1 );
221+ sum2 = GGML_F16x_VEC_FMA (ax2, ay2, sum2);
222+
223+ ax3 = GGML_F16x_VEC_LOAD (x + i + 2 * ggml_f16_epr, 2 );
224+ ay3 = GGML_F16x_VEC_LOAD (y + i + 2 * ggml_f16_epr, 2 );
225+ sum3 = GGML_F16x_VEC_FMA (ax3, ay3, sum3);
226+
227+ ax4 = GGML_F16x_VEC_LOAD (x + i + 3 * ggml_f16_epr, 3 );
228+ ay4 = GGML_F16x_VEC_LOAD (y + i + 3 * ggml_f16_epr, 3 );
229+ sum4 = GGML_F16x_VEC_FMA (ax4, ay4, sum4);
230+
231+ ax5 = GGML_F16x_VEC_LOAD (x + i + 4 * ggml_f16_epr, 4 );
232+ ay5 = GGML_F16x_VEC_LOAD (y + i + 4 * ggml_f16_epr, 4 );
233+ sum1 = GGML_F16x_VEC_FMA (ax5, ay5, sum1);
234+
235+ ax6 = GGML_F16x_VEC_LOAD (x + i + 5 * ggml_f16_epr, 5 );
236+ ay6 = GGML_F16x_VEC_LOAD (y + i + 5 * ggml_f16_epr, 5 );
237+ sum2 = GGML_F16x_VEC_FMA (ax6, ay6, sum2);
238+
239+ ax7 = GGML_F16x_VEC_LOAD (x + i + 6 * ggml_f16_epr, 6 );
240+ ay7 = GGML_F16x_VEC_LOAD (y + i + 6 * ggml_f16_epr, 6 );
241+ sum3 = GGML_F16x_VEC_FMA (ax7, ay7, sum3);
242+
243+ ax8 = GGML_F16x_VEC_LOAD (x + i + 7 * ggml_f16_epr, 7 );
244+ ay8 = GGML_F16x_VEC_LOAD (y + i + 7 * ggml_f16_epr, 7 );
245+ sum4 = GGML_F16x_VEC_FMA (ax8, ay8, sum4);
246+ }
204247
205- GGML_F16_VEC ax[GGML_F16_ARR];
206- GGML_F16_VEC ay[GGML_F16_ARR];
248+ const int np2 = (n & ~(ggml_f16_epr - 1 )); // round down to multiple of 8
249+ for (int k = np; k < np2; k += ggml_f16_epr) {
250+ svfloat16_t rx = GGML_F16x_VEC_LOAD (x + k, 0 );
251+ svfloat16_t ry = GGML_F16x_VEC_LOAD (y + k, 0 );
252+ sum1 = GGML_F16x_VEC_FMA (rx, ry, sum1);
253+ }
207254
208- for ( int i = 0 ; i < np; i += GGML_F16_STEP ) {
209- for ( int j = 0 ; j < GGML_F16_ARR; j++) {
210- ax[j] = GGML_F16_VEC_LOAD (x + i + j*GGML_F16_EPR, j );
211- ay[j] = GGML_F16_VEC_LOAD (y + i + j*GGML_F16_EPR, j );
255+ if (np2 < n ) {
256+ svbool_t pg = svwhilelt_b16 (np2,n);
257+ svfloat16_t hx = svld1_f16 (pg, ( const __fp16 *)(x + np2) );
258+ svfloat16_t hy = svld1_f16 (pg, ( const __fp16 *)(y + np2) );
212259
213- sum[j] = GGML_F16_VEC_FMA (sum[j], ax[j], ay[j] );
260+ sum1 = svmad_f16_x (pg, hx, hy, sum1 );
214261 }
215- }
262+ GGML_F16x_VEC_REDUCE (sumf, sum1, sum2, sum3, sum4);
263+ #else
264+ const int np = (n & ~(GGML_F16_STEP - 1 ));
216265
217- // reduce sum0..sum3 to sum0
218- GGML_F16_VEC_REDUCE (sumf, sum);
266+ GGML_F16_VEC sum[GGML_F16_ARR] = { GGML_F16_VEC_ZERO };
219267
220- // leftovers
221- for (int i = np; i < n; ++i) {
222- sumf += (ggml_float)(GGML_CPU_FP16_TO_FP32 (x[i])*GGML_CPU_FP16_TO_FP32 (y[i]));
223- }
268+ GGML_F16_VEC ax[GGML_F16_ARR];
269+ GGML_F16_VEC ay[GGML_F16_ARR];
270+
271+ for (int i = 0 ; i < np; i += GGML_F16_STEP) {
272+ for (int j = 0 ; j < GGML_F16_ARR; j++) {
273+ ax[j] = GGML_F16_VEC_LOAD (x + i + j*GGML_F16_EPR, j);
274+ ay[j] = GGML_F16_VEC_LOAD (y + i + j*GGML_F16_EPR, j);
224275
225- // if you hit this, you are likely running outside the FP range
226- assert (!isnan (sumf) && !isinf (sumf));
276+ sum[j] = GGML_F16_VEC_FMA (sum[j], ax[j], ay[j]);
277+ }
278+ }
279+
280+ // reduce sum0..sum3 to sum0
281+ GGML_F16_VEC_REDUCE (sumf, sum);
282+
283+ // leftovers
284+ for (int i = np; i < n; ++i) {
285+ sumf += (ggml_float)(GGML_CPU_FP16_TO_FP32 (x[i])*GGML_CPU_FP16_TO_FP32 (y[i]));
286+ }
287+ #endif
227288#else
228289 for (int i = 0 ; i < n; ++i) {
229290 sumf += (ggml_float)(GGML_CPU_FP16_TO_FP32 (x[i])*GGML_CPU_FP16_TO_FP32 (y[i]));