Skip to content

Commit 1b1a757

Browse files
committed
Optimize the performance of dot by using universal intrinsics in X86/ARM
1 parent 0d98ce2 commit 1b1a757

File tree

7 files changed

+177
-38
lines changed

7 files changed

+177
-38
lines changed

kernel/generic/dot.c

Lines changed: 41 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -47,27 +47,59 @@ FLOAT CNAME(BLASLONG n, FLOAT *x, BLASLONG inc_x, FLOAT *y, BLASLONG inc_y)
4747

4848
if ( (inc_x == 1) && (inc_y == 1) )
4949
{
50-
5150
int n1 = n & -4;
52-
53-
while(i < n1)
51+
#if V_SIMD && !defined(DSDOT)
52+
const int vstep = v_nlanes_f32;
53+
const int unrollx4 = n & (-vstep * 4);
54+
const int unrollx = n & -vstep;
55+
v_f32 vsum0 = v_zero_f32();
56+
v_f32 vsum1 = v_zero_f32();
57+
v_f32 vsum2 = v_zero_f32();
58+
v_f32 vsum3 = v_zero_f32();
59+
while(i < unrollx4)
60+
{
61+
vsum0 = v_muladd_f32(
62+
v_loadu_f32(x + i), v_loadu_f32(y + i), vsum0
63+
);
64+
vsum1 = v_muladd_f32(
65+
v_loadu_f32(x + i + vstep), v_loadu_f32(y + i + vstep), vsum1
66+
);
67+
vsum2 = v_muladd_f32(
68+
v_loadu_f32(x + i + vstep*2), v_loadu_f32(y + i + vstep*2), vsum2
69+
);
70+
vsum3 = v_muladd_f32(
71+
v_loadu_f32(x + i + vstep*3), v_loadu_f32(y + i + vstep*3), vsum3
72+
);
73+
i += vstep*4;
74+
}
75+
vsum0 = v_add_f32(
76+
v_add_f32(vsum0, vsum1), v_add_f32(vsum2 , vsum3)
77+
);
78+
while(i < unrollx)
79+
{
80+
vsum0 = v_muladd_f32(
81+
v_loadu_f32(x + i), v_loadu_f32(y + i), vsum0
82+
);
83+
i += vstep;
84+
}
85+
dot = v_sum_f32(vsum0);
86+
#elif defined(DSDOT)
87+
for (; i < n1; i += 4)
5488
{
55-
56-
#if defined(DSDOT)
5789
dot += (double) y[i] * (double) x[i]
5890
+ (double) y[i+1] * (double) x[i+1]
5991
+ (double) y[i+2] * (double) x[i+2]
6092
+ (double) y[i+3] * (double) x[i+3] ;
93+
}
6194
#else
95+
for (; i < n1; i += 4)
96+
{
6297
dot += y[i] * x[i]
6398
+ y[i+1] * x[i+1]
6499
+ y[i+2] * x[i+2]
65100
+ y[i+3] * x[i+3] ;
66-
#endif
67-
i+=4 ;
68-
69101
}
70-
102+
#endif
71103
while(i < n)
72104
{
73105

kernel/simd/intrin.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,11 @@ extern "C" {
5151
#include <immintrin.h>
5252
#endif
5353

54+
/** NEON **/
55+
#ifdef HAVE_NEON
56+
#include <arm_neon.h>
57+
#endif
58+
5459
// distribute
5560
#if defined(HAVE_AVX512VL) || defined(HAVE_AVX512BF16)
5661
#include "intrin_avx512.h"
@@ -60,6 +65,10 @@ extern "C" {
6065
#include "intrin_sse.h"
6166
#endif
6267

68+
#ifdef HAVE_NEON
69+
#include "intrin_neon.h"
70+
#endif
71+
6372
#ifndef V_SIMD
6473
#define V_SIMD 0
6574
#define V_SIMD_F64 0

kernel/simd/intrin_avx.h

Lines changed: 22 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,13 @@
11
#define V_SIMD 256
22
#define V_SIMD_F64 1
3-
/*
4-
Data Type
5-
*/
3+
/***************************
4+
* Data Type
5+
***************************/
66
typedef __m256 v_f32;
77
#define v_nlanes_f32 8
8-
/*
9-
arithmetic
10-
*/
8+
/***************************
9+
* Arithmetic
10+
***************************/
1111
#define v_add_f32 _mm256_add_ps
1212
#define v_mul_f32 _mm256_mul_ps
1313

@@ -20,10 +20,22 @@ arithmetic
2020
{ return v_add_f32(v_mul_f32(a, b), c); }
2121
#endif // !HAVE_FMA3
2222

23-
/*
24-
memory
25-
*/
23+
// Horizontal add: Calculates the sum of all vector elements.
24+
BLAS_FINLINE float v_sum_f32(__m256 a)
25+
{
26+
__m256 sum_halves = _mm256_hadd_ps(a, a);
27+
sum_halves = _mm256_hadd_ps(sum_halves, sum_halves);
28+
__m128 lo = _mm256_castps256_ps128(sum_halves);
29+
__m128 hi = _mm256_extractf128_ps(sum_halves, 1);
30+
__m128 sum = _mm_add_ps(lo, hi);
31+
return _mm_cvtss_f32(sum);
32+
}
33+
34+
/***************************
35+
* memory
36+
***************************/
2637
// unaligned load
2738
#define v_loadu_f32 _mm256_loadu_ps
2839
#define v_storeu_f32 _mm256_storeu_ps
29-
#define v_setall_f32(VAL) _mm256_set1_ps(VAL)
40+
#define v_setall_f32(VAL) _mm256_set1_ps(VAL)
41+
#define v_zero_f32 _mm256_setzero_ps

kernel/simd/intrin_avx512.h

Lines changed: 23 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,35 @@
11
#define V_SIMD 512
22
#define V_SIMD_F64 1
3-
/*
4-
Data Type
5-
*/
3+
/***************************
4+
* Data Type
5+
***************************/
66
typedef __m512 v_f32;
77
#define v_nlanes_f32 16
8-
/*
9-
arithmetic
10-
*/
8+
/***************************
9+
* Arithmetic
10+
***************************/
1111
#define v_add_f32 _mm512_add_ps
1212
#define v_mul_f32 _mm512_mul_ps
1313
// multiply and add, a*b + c
1414
#define v_muladd_f32 _mm512_fmadd_ps
15-
/*
16-
memory
17-
*/
15+
16+
BLAS_FINLINE float v_sum_f32(v_f32 a)
17+
{
18+
__m512 h64 = _mm512_shuffle_f32x4(a, a, _MM_SHUFFLE(3, 2, 3, 2));
19+
__m512 sum32 = _mm512_add_ps(a, h64);
20+
__m512 h32 = _mm512_shuffle_f32x4(sum32, sum32, _MM_SHUFFLE(1, 0, 3, 2));
21+
__m512 sum16 = _mm512_add_ps(sum32, h32);
22+
__m512 h16 = _mm512_permute_ps(sum16, _MM_SHUFFLE(1, 0, 3, 2));
23+
__m512 sum8 = _mm512_add_ps(sum16, h16);
24+
__m512 h4 = _mm512_permute_ps(sum8, _MM_SHUFFLE(2, 3, 0, 1));
25+
__m512 sum4 = _mm512_add_ps(sum8, h4);
26+
return _mm_cvtss_f32(_mm512_castps512_ps128(sum4));
27+
}
28+
/***************************
29+
* memory
30+
***************************/
1831
// unaligned load
1932
#define v_loadu_f32(PTR) _mm512_loadu_ps((const __m512*)(PTR))
2033
#define v_storeu_f32 _mm512_storeu_ps
2134
#define v_setall_f32(VAL) _mm512_set1_ps(VAL)
35+
#define v_zero_f32 _mm512_setzero_ps

kernel/simd/intrin_neon.h

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
#define V_SIMD 128
2+
#ifdef __aarch64__
3+
#define V_SIMD_F64 1
4+
#else
5+
#define V_SIMD_F64 0
6+
#endif
7+
/***************************
8+
* Data Type
9+
***************************/
10+
typedef float32x4_t v_f32;
11+
#define v_nlanes_f32 4
12+
/***************************
13+
* Arithmetic
14+
***************************/
15+
#define v_add_f32 vaddq_f32
16+
#define v_mul_f32 vmulq_f32
17+
18+
// FUSED F32
19+
#ifdef HAVE_VFPV4 // FMA
20+
// multiply and add, a*b + c
21+
BLAS_FINLINE v_f32 v_muladd_f32(v_f32 a, v_f32 b, v_f32 c)
22+
{ return vfmaq_f32(c, a, b); }
23+
#else
24+
// multiply and add, a*b + c
25+
BLAS_FINLINE v_f32 v_muladd_f32(v_f32 a, v_f32 b, v_f32 c)
26+
{ return vmlaq_f32(c, a, b); }
27+
#endif
28+
29+
// Horizontal add: Calculates the sum of all vector elements.
30+
BLAS_FINLINE float v_sum_f32(float32x4_t a)
31+
{
32+
float32x2_t r = vadd_f32(vget_high_f32(a), vget_low_f32(a));
33+
return vget_lane_f32(vpadd_f32(r, r), 0);
34+
}
35+
/***************************
36+
* memory
37+
***************************/
38+
// unaligned load
39+
#define v_loadu_f32(a) vld1q_f32((const float*)a)
40+
#define v_storeu_f32 vst1q_f32
41+
#define v_setall_f32(VAL) vdupq_n_f32(VAL)
42+
#define v_zero_f32() vdupq_n_f32(0.0f)

kernel/simd/intrin_sse.h

Lines changed: 26 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,13 @@
11
#define V_SIMD 128
22
#define V_SIMD_F64 1
3-
/*
4-
Data Type
5-
*/
3+
/***************************
4+
* Data Type
5+
***************************/
66
typedef __m128 v_f32;
77
#define v_nlanes_f32 4
8-
/*
9-
arithmetic
10-
*/
8+
/***************************
9+
* Arithmetic
10+
***************************/
1111
#define v_add_f32 _mm_add_ps
1212
#define v_mul_f32 _mm_mul_ps
1313
#ifdef HAVE_FMA3
@@ -21,10 +21,26 @@ arithmetic
2121
BLAS_FINLINE v_f32 v_muladd_f32(v_f32 a, v_f32 b, v_f32 c)
2222
{ return v_add_f32(v_mul_f32(a, b), c); }
2323
#endif // HAVE_FMA3
24-
/*
25-
memory
26-
*/
24+
25+
// Horizontal add: Calculates the sum of all vector elements.
26+
BLAS_FINLINE float v_sum_f32(__m128 a)
27+
{
28+
#ifdef HAVE_SSE3
29+
__m128 sum_halves = _mm_hadd_ps(a, a);
30+
return _mm_cvtss_f32(_mm_hadd_ps(sum_halves, sum_halves));
31+
#else
32+
__m128 t1 = _mm_movehl_ps(a, a);
33+
__m128 t2 = _mm_add_ps(a, t1);
34+
__m128 t3 = _mm_shuffle_ps(t2, t2, 1);
35+
__m128 t4 = _mm_add_ss(t2, t3);
36+
return _mm_cvtss_f32(t4);
37+
#endif
38+
}
39+
/***************************
40+
* memory
41+
***************************/
2742
// unaligned load
2843
#define v_loadu_f32 _mm_loadu_ps
2944
#define v_storeu_f32 _mm_storeu_ps
30-
#define v_setall_f32(VAL) _mm_set1_ps(VAL)
45+
#define v_setall_f32(VAL) _mm_set1_ps(VAL)
46+
#define v_zero_f32 _mm_setzero_ps

utest/test_dsdot.c

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,3 +47,17 @@ CTEST(dsdot,dsdot_n_1)
4747
ASSERT_DBL_NEAR_TOL(res2, res1, DOUBLE_EPS);
4848

4949
}
50+
51+
CTEST(dsdot,dsdot_n_2)
52+
{
53+
float x[] = {0.1F, 0.2F, 0.3F, 0.4F, 0.5F, 0.6F, 0.7F, 0.8F};
54+
float y[] = {0.1F, 0.2F, 0.3F, 0.4F, 0.5F, 0.6F, 0.7F, 0.8F};
55+
blasint incx=1;
56+
blasint incy=1;
57+
blasint n=8;
58+
59+
double res1=0.0f, res2= 2.0400000444054616;
60+
61+
res1=BLASFUNC(dsdot)(&n, &x, &incx, &y, &incy);
62+
ASSERT_DBL_NEAR_TOL(res2, res1, DOUBLE_EPS);
63+
}

0 commit comments

Comments
 (0)