Skip to content

Commit be075d5

Browse files
Merge pull request #2983 from Qiyu8/optimize-srot: "Optimize the performance of rot by using universal intrinsics."
Merge commit be075d5 (2 parents: d341a0f + b00a0de).

File tree

7 files changed: +175 −3 lines

kernel/simd/intrin.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ extern "C" {
4747
#endif
4848

4949
/** AVX **/
50-
#ifdef HAVE_AVX
50+
#if defined(HAVE_AVX) || defined(HAVE_FMA3)
5151
#include <immintrin.h>
5252
#endif
5353

kernel/simd/intrin_avx.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,19 +12,29 @@ typedef __m256d v_f64;
1212
***************************/
1313
#define v_add_f32 _mm256_add_ps
1414
#define v_add_f64 _mm256_add_pd
15+
#define v_sub_f32 _mm256_sub_ps
16+
#define v_sub_f64 _mm256_sub_pd
1517
#define v_mul_f32 _mm256_mul_ps
1618
#define v_mul_f64 _mm256_mul_pd
1719

1820
#ifdef HAVE_FMA3
1921
// multiply and add, a*b + c
2022
#define v_muladd_f32 _mm256_fmadd_ps
2123
#define v_muladd_f64 _mm256_fmadd_pd
24+
// multiply and subtract, a*b - c
25+
#define v_mulsub_f32 _mm256_fmsub_ps
26+
#define v_mulsub_f64 _mm256_fmsub_pd
2227
#else
2328
// multiply and add, a*b + c
2429
BLAS_FINLINE v_f32 v_muladd_f32(v_f32 a, v_f32 b, v_f32 c)
2530
{ return v_add_f32(v_mul_f32(a, b), c); }
2631
BLAS_FINLINE v_f64 v_muladd_f64(v_f64 a, v_f64 b, v_f64 c)
2732
{ return v_add_f64(v_mul_f64(a, b), c); }
33+
// multiply and subtract, a*b - c
34+
BLAS_FINLINE v_f32 v_mulsub_f32(v_f32 a, v_f32 b, v_f32 c)
35+
{ return v_sub_f32(v_mul_f32(a, b), c); }
36+
BLAS_FINLINE v_f64 v_mulsub_f64(v_f64 a, v_f64 b, v_f64 c)
37+
{ return v_sub_f64(v_mul_f64(a, b), c); }
2838
#endif // !HAVE_FMA3
2939

3040
// Horizontal add: Calculates the sum of all vector elements.

kernel/simd/intrin_avx512.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,16 @@ typedef __m512d v_f64;
1212
***************************/
1313
#define v_add_f32 _mm512_add_ps
1414
#define v_add_f64 _mm512_add_pd
15+
#define v_sub_f32 _mm512_sub_ps
16+
#define v_sub_f64 _mm512_sub_pd
1517
#define v_mul_f32 _mm512_mul_ps
1618
#define v_mul_f64 _mm512_mul_pd
1719
// multiply and add, a*b + c
1820
#define v_muladd_f32 _mm512_fmadd_ps
1921
#define v_muladd_f64 _mm512_fmadd_pd
22+
// multiply and subtract, a*b - c
23+
#define v_mulsub_f32 _mm512_fmsub_ps
24+
#define v_mulsub_f64 _mm512_fmsub_pd
2025
BLAS_FINLINE float v_sum_f32(v_f32 a)
2126
{
2227
__m512 h64 = _mm512_shuffle_f32x4(a, a, _MM_SHUFFLE(3, 2, 3, 2));

kernel/simd/intrin_neon.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@ typedef float32x4_t v_f32;
1818
***************************/
1919
#define v_add_f32 vaddq_f32
2020
#define v_add_f64 vaddq_f64
21+
#define v_sub_f32 vsubq_f32
22+
#define v_sub_f64 vsubq_f64
2123
#define v_mul_f32 vmulq_f32
2224
#define v_mul_f64 vmulq_f64
2325

@@ -26,16 +28,24 @@ typedef float32x4_t v_f32;
2628
// multiply and add, a*b + c
2729
BLAS_FINLINE v_f32 v_muladd_f32(v_f32 a, v_f32 b, v_f32 c)
2830
{ return vfmaq_f32(c, a, b); }
31+
// multiply and subtract, a*b - c
32+
BLAS_FINLINE v_f32 v_mulsub_f32(v_f32 a, v_f32 b, v_f32 c)
33+
{ return vfmaq_f32(vnegq_f32(c), a, b); }
2934
#else
3035
// multiply and add, a*b + c
3136
BLAS_FINLINE v_f32 v_muladd_f32(v_f32 a, v_f32 b, v_f32 c)
3237
{ return vmlaq_f32(c, a, b); }
38+
// multiply and subtract, a*b - c
39+
BLAS_FINLINE v_f32 v_mulsub_f32(v_f32 a, v_f32 b, v_f32 c)
40+
{ return vmlaq_f32(vnegq_f32(c), a, b); }
3341
#endif
3442

3543
// FUSED F64
3644
#if V_SIMD_F64
3745
BLAS_FINLINE v_f64 v_muladd_f64(v_f64 a, v_f64 b, v_f64 c)
3846
{ return vfmaq_f64(c, a, b); }
47+
BLAS_FINLINE v_f64 v_mulsub_f64(v_f64 a, v_f64 b, v_f64 c)
48+
{ return vfmaq_f64(vnegq_f64(c), a, b); }
3949
#endif
4050

4151
// Horizontal add: Calculates the sum of all vector elements.

kernel/simd/intrin_sse.h

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,22 +12,35 @@ typedef __m128d v_f64;
1212
***************************/
1313
#define v_add_f32 _mm_add_ps
1414
#define v_add_f64 _mm_add_pd
15+
#define v_sub_f32 _mm_sub_ps
16+
#define v_sub_f64 _mm_sub_pd
1517
#define v_mul_f32 _mm_mul_ps
1618
#define v_mul_f64 _mm_mul_pd
1719
#ifdef HAVE_FMA3
1820
// multiply and add, a*b + c
1921
#define v_muladd_f32 _mm_fmadd_ps
2022
#define v_muladd_f64 _mm_fmadd_pd
23+
// multiply and subtract, a*b - c
24+
#define v_mulsub_f32 _mm_fmsub_ps
25+
#define v_mulsub_f64 _mm_fmsub_pd
2126
#elif defined(HAVE_FMA4)
2227
// multiply and add, a*b + c
2328
#define v_muladd_f32 _mm_macc_ps
2429
#define v_muladd_f64 _mm_macc_pd
30+
// multiply and subtract, a*b - c
31+
#define v_mulsub_f32 _mm_msub_ps
32+
#define v_mulsub_f64 _mm_msub_pd
2533
#else
2634
// multiply and add, a*b + c
2735
BLAS_FINLINE v_f32 v_muladd_f32(v_f32 a, v_f32 b, v_f32 c)
2836
{ return v_add_f32(v_mul_f32(a, b), c); }
2937
BLAS_FINLINE v_f64 v_muladd_f64(v_f64 a, v_f64 b, v_f64 c)
3038
{ return v_add_f64(v_mul_f64(a, b), c); }
39+
// multiply and subtract, a*b - c
40+
BLAS_FINLINE v_f32 v_mulsub_f32(v_f32 a, v_f32 b, v_f32 c)
41+
{ return v_sub_f32(v_mul_f32(a, b), c); }
42+
BLAS_FINLINE v_f64 v_mulsub_f64(v_f64 a, v_f64 b, v_f64 c)
43+
{ return v_sub_f64(v_mul_f64(a, b), c); }
3144
#endif // HAVE_FMA3
3245

3346
// Horizontal add: Calculates the sum of all vector elements.

kernel/x86_64/drot.c

Lines changed: 67 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,76 @@
77
#endif
88

99
#ifndef HAVE_DROT_KERNEL
10+
#include "../simd/intrin.h"
1011

1112
static void drot_kernel(BLASLONG n, FLOAT *x, FLOAT *y, FLOAT c, FLOAT s)
1213
{
1314
BLASLONG i = 0;
15+
#if V_SIMD_F64 && V_SIMD > 256
16+
const int vstep = v_nlanes_f64;
17+
const int unrollx4 = n & (-vstep * 4);
18+
const int unrollx = n & -vstep;
19+
20+
v_f64 __c = v_setall_f64(c);
21+
v_f64 __s = v_setall_f64(s);
22+
v_f64 vx0, vx1, vx2, vx3;
23+
v_f64 vy0, vy1, vy2, vy3;
24+
v_f64 vt0, vt1, vt2, vt3;
25+
26+
for (; i < unrollx4; i += vstep * 4) {
27+
vx0 = v_loadu_f64(x + i);
28+
vx1 = v_loadu_f64(x + i + vstep);
29+
vx2 = v_loadu_f64(x + i + vstep * 2);
30+
vx3 = v_loadu_f64(x + i + vstep * 3);
31+
vy0 = v_loadu_f64(y + i);
32+
vy1 = v_loadu_f64(y + i + vstep);
33+
vy2 = v_loadu_f64(y + i + vstep * 2);
34+
vy3 = v_loadu_f64(y + i + vstep * 3);
35+
36+
vt0 = v_mul_f64(__s, vy0);
37+
vt1 = v_mul_f64(__s, vy1);
38+
vt2 = v_mul_f64(__s, vy2);
39+
vt3 = v_mul_f64(__s, vy3);
40+
41+
vt0 = v_muladd_f64(__c, vx0, vt0);
42+
vt1 = v_muladd_f64(__c, vx1, vt1);
43+
vt2 = v_muladd_f64(__c, vx2, vt2);
44+
vt3 = v_muladd_f64(__c, vx3, vt3);
45+
46+
v_storeu_f64(x + i, vt0);
47+
v_storeu_f64(x + i + vstep, vt1);
48+
v_storeu_f64(x + i + vstep * 2, vt2);
49+
v_storeu_f64(x + i + vstep * 3, vt3);
50+
51+
vt0 = v_mul_f64(__s, vx0);
52+
vt1 = v_mul_f64(__s, vx1);
53+
vt2 = v_mul_f64(__s, vx2);
54+
vt3 = v_mul_f64(__s, vx3);
55+
56+
vt0 = v_mulsub_f64(__c, vy0, vt0);
57+
vt1 = v_mulsub_f64(__c, vy1, vt1);
58+
vt2 = v_mulsub_f64(__c, vy2, vt2);
59+
vt3 = v_mulsub_f64(__c, vy3, vt3);
60+
61+
v_storeu_f64(y + i, vt0);
62+
v_storeu_f64(y + i + vstep, vt1);
63+
v_storeu_f64(y + i + vstep * 2, vt2);
64+
v_storeu_f64(y + i + vstep * 3, vt3);
65+
}
66+
67+
for (; i < unrollx; i += vstep) {
68+
vx0 = v_loadu_f64(x + i);
69+
vy0 = v_loadu_f64(y + i);
70+
71+
vt0 = v_mul_f64(__s, vy0);
72+
vt0 = v_muladd_f64(__c, vx0, vt0);
73+
v_storeu_f64(x + i, vt0);
74+
75+
vt0 = v_mul_f64(__s, vx0);
76+
vt0 = v_mulsub_f64(__c, vy0, vt0);
77+
v_storeu_f64(y + i, vt0);
78+
}
79+
#else
1480
FLOAT f0, f1, f2, f3;
1581
FLOAT x0, x1, x2, x3;
1682
FLOAT g0, g1, g2, g3;
@@ -53,7 +119,7 @@ static void drot_kernel(BLASLONG n, FLOAT *x, FLOAT *y, FLOAT c, FLOAT s)
53119
yp += 4;
54120
i += 4;
55121
}
56-
122+
#endif
57123
while (i < n) {
58124
FLOAT temp = c*x[i] + s*y[i];
59125
y[i] = c*y[i] - s*x[i];

kernel/x86_64/srot.c

Lines changed: 69 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,78 @@
77
#endif
88

99
#ifndef HAVE_SROT_KERNEL
10+
#include"../simd/intrin.h"
1011

1112
static void srot_kernel(BLASLONG n, FLOAT *x, FLOAT *y, FLOAT c, FLOAT s)
1213
{
1314
BLASLONG i = 0;
15+
16+
#if V_SIMD && (defined(HAVE_FMA3) || V_SIMD > 128)
17+
const int vstep = v_nlanes_f32;
18+
const int unrollx4 = n & (-vstep * 4);
19+
const int unrollx = n & -vstep;
20+
21+
v_f32 __c = v_setall_f32(c);
22+
v_f32 __s = v_setall_f32(s);
23+
v_f32 vx0, vx1, vx2, vx3;
24+
v_f32 vy0, vy1, vy2, vy3;
25+
v_f32 vt0, vt1, vt2, vt3;
26+
27+
for (; i < unrollx4; i += vstep * 4) {
28+
vx0 = v_loadu_f32(x + i);
29+
vx1 = v_loadu_f32(x + i + vstep);
30+
vx2 = v_loadu_f32(x + i + vstep * 2);
31+
vx3 = v_loadu_f32(x + i + vstep * 3);
32+
vy0 = v_loadu_f32(y + i);
33+
vy1 = v_loadu_f32(y + i + vstep);
34+
vy2 = v_loadu_f32(y + i + vstep * 2);
35+
vy3 = v_loadu_f32(y + i + vstep * 3);
36+
37+
vt0 = v_mul_f32(__s, vy0);
38+
vt1 = v_mul_f32(__s, vy1);
39+
vt2 = v_mul_f32(__s, vy2);
40+
vt3 = v_mul_f32(__s, vy3);
41+
42+
vt0 = v_muladd_f32(__c, vx0, vt0);
43+
vt1 = v_muladd_f32(__c, vx1, vt1);
44+
vt2 = v_muladd_f32(__c, vx2, vt2);
45+
vt3 = v_muladd_f32(__c, vx3, vt3);
46+
47+
v_storeu_f32(x + i, vt0);
48+
v_storeu_f32(x + i + vstep, vt1);
49+
v_storeu_f32(x + i + vstep * 2, vt2);
50+
v_storeu_f32(x + i + vstep * 3, vt3);
51+
52+
vt0 = v_mul_f32(__s, vx0);
53+
vt1 = v_mul_f32(__s, vx1);
54+
vt2 = v_mul_f32(__s, vx2);
55+
vt3 = v_mul_f32(__s, vx3);
56+
57+
vt0 = v_mulsub_f32(__c, vy0, vt0);
58+
vt1 = v_mulsub_f32(__c, vy1, vt1);
59+
vt2 = v_mulsub_f32(__c, vy2, vt2);
60+
vt3 = v_mulsub_f32(__c, vy3, vt3);
61+
62+
v_storeu_f32(y + i, vt0);
63+
v_storeu_f32(y + i + vstep, vt1);
64+
v_storeu_f32(y + i + vstep * 2, vt2);
65+
v_storeu_f32(y + i + vstep * 3, vt3);
66+
67+
}
68+
69+
for (; i < unrollx; i += vstep) {
70+
vx0 = v_loadu_f32(x + i);
71+
vy0 = v_loadu_f32(y + i);
72+
73+
vt0 = v_mul_f32(__s, vy0);
74+
vt0 = v_muladd_f32(__c, vx0, vt0);
75+
v_storeu_f32(x + i, vt0);
76+
77+
vt0 = v_mul_f32(__s, vx0);
78+
vt0 = v_mulsub_f32(__c, vy0, vt0);
79+
v_storeu_f32(y + i, vt0);
80+
}
81+
#else
1482
FLOAT f0, f1, f2, f3;
1583
FLOAT x0, x1, x2, x3;
1684
FLOAT g0, g1, g2, g3;
@@ -20,7 +88,6 @@ static void srot_kernel(BLASLONG n, FLOAT *x, FLOAT *y, FLOAT c, FLOAT s)
2088
FLOAT* yp = y;
2189

2290
BLASLONG n1 = n & (~7);
23-
2491
while (i < n1) {
2592
x0 = xp[0];
2693
y0 = yp[0];
@@ -53,6 +120,7 @@ static void srot_kernel(BLASLONG n, FLOAT *x, FLOAT *y, FLOAT c, FLOAT s)
53120
yp += 4;
54121
i += 4;
55122
}
123+
#endif
56124

57125
while (i < n) {
58126
FLOAT temp = c*x[i] + s*y[i];

0 commit comments

Comments (0)