@@ -404,6 +404,72 @@ void ggml_vec_swiglu_f32(const int n, float * y, const float * x, const float *
404
404
}
405
405
}
406
406
407
// Center a vector around `mean` and compute the mean of the squared deviations.
//
//   y[i] = x[i] - mean                      for i in [0, n)
//   return (sum_i (x[i] - mean)^2) / n
//
// n    - number of elements in x and y; the result is divided by n, so callers
//        are expected to pass n > 0
// y    - output, receives the centered values (NOT the squared values)
// x    - input values
// mean - value subtracted from every element of x
//
// The squared deviations are accumulated in ggml_float. Each SIMD iteration
// reduces its lanes horizontally in single precision before adding to the
// accumulator.
ggml_float ggml_vec_cvar_f32(const int n, float * y, const float * x, const float mean) {
    int i = 0;
    ggml_float sum = 0;
    // TODO: optimize to process the remaining elements in groups using the smaller vector sizes from AVX2 and SSE
    // ref: https://github.com/ggml-org/llama.cpp/pull/15953#pullrequestreview-3310928344
#if defined(__AVX512F__) && defined(__AVX512DQ__)
    for (; i + 15 < n; i += 16) {
        __m512 val = _mm512_sub_ps(_mm512_loadu_ps(x + i),
                                   _mm512_set1_ps(mean));
        _mm512_storeu_ps(y + i, val);
        sum += (ggml_float)_mm512_reduce_add_ps(_mm512_mul_ps(val, val));
    }
#elif defined(__AVX2__) && defined(__FMA__)
    for (; i + 7 < n; i += 8) {
        __m256 val = _mm256_sub_ps(_mm256_loadu_ps(x + i),
                                   _mm256_set1_ps(mean));
        _mm256_storeu_ps(y + i, val);
        val = _mm256_mul_ps(val, val);
        // horizontal add: 8 lanes -> 4 -> 2 -> 1
        __m128 val2 = _mm_add_ps(_mm256_extractf128_ps(val, 1),
                                 _mm256_castps256_ps128(val));
        val2 = _mm_add_ps(val2, _mm_movehl_ps(val2, val2));
        val2 = _mm_add_ss(val2, _mm_movehdup_ps(val2));
        sum += (ggml_float)_mm_cvtss_f32(val2);
    }
#elif defined(__SSE2__)
    for (; i + 3 < n; i += 4) {
        __m128 val = _mm_sub_ps(_mm_loadu_ps(x + i),
                                _mm_set1_ps(mean));
        _mm_storeu_ps(y + i, val);
        val = _mm_mul_ps(val, val);
#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
        val = _mm_add_ps(val, _mm_movehl_ps(val, val));
        val = _mm_add_ss(val, _mm_movehdup_ps(val));
#else
        // plain SSE2 fallback: _mm_movehdup_ps is not used here, so swap
        // adjacent lanes with a shuffle instead
        __m128 tmp = _mm_shuffle_ps(val, val, _MM_SHUFFLE(2, 3, 0, 1));
        val = _mm_add_ps(val, tmp);
        tmp = _mm_movehl_ps(tmp, val);
        val = _mm_add_ss(val, tmp);
#endif // __AVX__ || __AVX2__ || __AVX512F__
        sum += (ggml_float)_mm_cvtss_f32(val);
    }
#elif defined(__ARM_NEON) && defined(__aarch64__)
    for (; i + 3 < n; i += 4) {
        float32x4_t val = vsubq_f32(vld1q_f32(x + i),
                                    vdupq_n_f32(mean));
        vst1q_f32(y + i, val);
        val = vmulq_f32(val, val);
        sum += (ggml_float)vaddvq_f32(val);
    }
#elif defined(__VXE__) || defined(__VXE2__)
    for (; i + 3 < n; i += 4) {
        float32x4_t val = vec_sub(vec_xl(0, x + i), vec_splats(mean));
        vec_xst(val, 0, y + i);
        val = vec_mul(val, val);
        sum += (ggml_float)vec_hsum_f32x4(val);
    }
#endif
    // scalar tail (and the whole vector when no SIMD path is compiled in)
    for (; i < n; ++i) {
        float val = x[i] - mean;
        // store the centered value BEFORE squaring, matching the SIMD paths
        y[i] = val;
        val *= val;
        sum += (ggml_float)val;
    }
    return sum/n;
}
472
+
407
473
ggml_float ggml_vec_soft_max_f32 (const int n, float * y, const float * x, float max) {
408
474
int i = 0 ;
409
475
ggml_float sum = 0 ;
0 commit comments