@@ -191,8 +191,8 @@ EXPORT int32_t sqr7u(int8_t* a, int8_t* b, size_t dims) {
191191
192192// --- single precision floats
193193
194- // Horizontal add of all 8 elements in a __m256 register
195- static inline float horizontal_sum_avx2 ( __m256 v ) {
194+ // Horizontally add 8 float32 elements in a __m256 register
195+ static inline float hsum_f32_8 ( const __m256 v ) {
196196 // First, add the low and high 128-bit lanes
197197 __m128 low = _mm256_castps256_ps128 (v ); // lower 128 bits
198198 __m128 high = _mm256_extractf128_ps (v , 1 ); // upper 128 bits
@@ -261,9 +261,9 @@ EXPORT float cosf32(const float *a, const float *b, size_t elementCount) {
261261 __m256 norm_a_total = _mm256_add_ps (_mm256_add_ps (norm_a0 , norm_a1 ), _mm256_add_ps (norm_a2 , norm_a3 ));
262262 __m256 norm_b_total = _mm256_add_ps (_mm256_add_ps (norm_b0 , norm_b1 ), _mm256_add_ps (norm_b2 , norm_b3 ));
263263
264- float dot_result = horizontal_sum_avx2 (dot_total );
265- float norm_a_result = horizontal_sum_avx2 (norm_a_total );
266- float norm_b_result = horizontal_sum_avx2 (norm_b_total );
264+ float dot_result = hsum_f32_8 (dot_total );
265+ float norm_a_result = hsum_f32_8 (norm_a_total );
266+ float norm_b_result = hsum_f32_8 (norm_b_total );
267267
268268 // Handle remaining tail with scalar loop
269269 for (; i < elementCount ; ++ i ) {
@@ -302,7 +302,7 @@ EXPORT float dotf32(const float *a, const float *b, size_t elementCount) {
302302
303303 // Combine all partial sums
304304 __m256 total_sum = _mm256_add_ps (_mm256_add_ps (acc0 , acc1 ), _mm256_add_ps (acc2 , acc3 ));
305- float result = horizontal_sum_avx2 (total_sum );
305+ float result = hsum_f32_8 (total_sum );
306306
307307 for (; i < elementCount ; ++ i ) {
308308 result += a [i ] * b [i ];
@@ -337,7 +337,7 @@ EXPORT float sqrf32(const float *a, const float *b, size_t elementCount) {
337337
338338 // reduce all partial sums
339339 __m256 total_sum = _mm256_add_ps (_mm256_add_ps (sum0 , sum1 ), _mm256_add_ps (sum2 , sum3 ));
340- float result = horizontal_sum_avx2 (total_sum );
340+ float result = hsum_f32_8 (total_sum );
341341
342342 for (; i < elementCount ; ++ i ) {
343343 float diff = a [i ] - b [i ];
0 commit comments