Skip to content

Commit a85ac71

Browse files
authored
Merge pull request #100 from xianyi/develop
rebase
2 parents a897bc3 + 4c25910 commit a85ac71

File tree

8 files changed

+125
-5
lines changed

8 files changed

+125
-5
lines changed

Makefile.x86_64

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,10 @@ ifdef HAVE_SSSE3
1616
CCOMMON_OPT += -mssse3
1717
FCOMMON_OPT += -mssse3
1818
endif
19+
ifdef HAVE_SSE4_1
20+
CCOMMON_OPT += -msse4.1
21+
FCOMMON_OPT += -msse4.1
22+
endif
1923
endif
2024
endif
2125

kernel/Makefile

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ endif
4545

4646
ifdef TARGET_CORE
4747
ifeq ($(TARGET_CORE), $(filter $(TARGET_CORE),PRESCOTT CORE2 PENRYN DUNNINGTON ATOM NANO SANDYBRIDGE HASWELL NEHALEM ZEN BARCELONA BOBCAT BULLDOZER PILEDRIVER EXCAVATOR STEAMROLLER OPTERON_SSE3))
48-
override CFLAGS += -msse3 -mssse3
48+
override CFLAGS += -msse3 -mssse3 -msse4.1
4949
endif
5050
ifeq ($(TARGET_CORE), COOPERLAKE)
5151
override CFLAGS += -DBUILD_KERNEL -DTABLE_NAME=gotoblas_$(TARGET_CORE)

kernel/arm/sum.c

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,26 @@ FLOAT CNAME(BLASLONG n, FLOAT *x, BLASLONG inc_x)
4343
if (inc_x == 1)
4444
{
4545
#if V_SIMD
46+
#ifdef DOUBLE
        // Double-precision SIMD accumulation of x[0..n) into sumf.
        const int vstep = v_nlanes_f64;
        const int unrollx2 = n & (-vstep * 2);   // count handled by the 2x-unrolled loop
        const int unrollx = n & -vstep;          // count handled by one vector at a time
        v_f64 vsum0 = v_zero_f64();
        v_f64 vsum1 = v_zero_f64();
        // BUG FIX: the loads must advance with the loop counter; the
        // original read x[0 .. 2*vstep) on every iteration instead of
        // walking the array (the tail loop below already uses x + i).
        while (i < unrollx2)
        {
            vsum0 = v_add_f64(vsum0, v_loadu_f64(x + i));
            vsum1 = v_add_f64(vsum1, v_loadu_f64(x + i + vstep));
            i += vstep * 2;
        }
        vsum0 = v_add_f64(vsum0, vsum1);
        // tail: one vector at a time (remaining scalars handled after #endif)
        while (i < unrollx)
        {
            vsum0 = v_add_f64(vsum0, v_loadu_f64(x + i));
            i += vstep;
        }
        sumf = v_sum_f64(vsum0);
65+
#else
4666
const int vstep = v_nlanes_f32;
4767
const int unrollx4 = n & (-vstep * 4);
4868
const int unrollx = n & -vstep;
@@ -66,6 +86,7 @@ FLOAT CNAME(BLASLONG n, FLOAT *x, BLASLONG inc_x)
6686
i += vstep;
6787
}
6888
sumf = v_sum_f32(vsum0);
89+
#endif
6990
#else
7091
int n1 = n & -4;
7192
for (; i < n1; i += 4)

kernel/simd/intrin_avx.h

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,20 +4,27 @@
44
* Data Type
55
***************************/
66
typedef __m256 v_f32;
7+
typedef __m256d v_f64;
78
#define v_nlanes_f32 8
9+
#define v_nlanes_f64 4
810
/***************************
911
* Arithmetic
1012
***************************/
1113
#define v_add_f32 _mm256_add_ps
14+
#define v_add_f64 _mm256_add_pd
1215
#define v_mul_f32 _mm256_mul_ps
16+
#define v_mul_f64 _mm256_mul_pd
1317

1418
#ifdef HAVE_FMA3
1519
// multiply and add, a*b + c
1620
#define v_muladd_f32 _mm256_fmadd_ps
21+
#define v_muladd_f64 _mm256_fmadd_pd
1722
#else
1823
// multiply and add, a*b + c
1924
BLAS_FINLINE v_f32 v_muladd_f32(v_f32 a, v_f32 b, v_f32 c)
2025
{ return v_add_f32(v_mul_f32(a, b), c); }
26+
BLAS_FINLINE v_f64 v_muladd_f64(v_f64 a, v_f64 b, v_f64 c)
27+
{ return v_add_f64(v_mul_f64(a, b), c); }
2128
#endif // !HAVE_FMA3
2229

2330
// Horizontal add: Calculates the sum of all vector elements.
@@ -31,11 +38,23 @@ BLAS_FINLINE float v_sum_f32(__m256 a)
3138
return _mm_cvtss_f32(sum);
3239
}
3340

41+
// Horizontal add: returns the sum of all four double elements of `a`.
BLAS_FINLINE double v_sum_f64(__m256d a)
{
    // pairwise add within each 128-bit lane: [a0+a1, a0+a1, a2+a3, a2+a3]
    __m256d pairs = _mm256_hadd_pd(a, a);
    // add the low and high 128-bit lanes; element 0 then holds the total
    __m128d total = _mm_add_pd(_mm256_castpd256_pd128(pairs),
                               _mm256_extractf128_pd(pairs, 1));
    return _mm_cvtsd_f64(total);
}
3449
/***************************
3550
* memory
3651
***************************/
3752
// unaligned load
3853
#define v_loadu_f32 _mm256_loadu_ps
54+
#define v_loadu_f64 _mm256_loadu_pd
3955
#define v_storeu_f32 _mm256_storeu_ps
56+
#define v_storeu_f64 _mm256_storeu_pd
4057
#define v_setall_f32(VAL) _mm256_set1_ps(VAL)
41-
#define v_zero_f32 _mm256_setzero_ps
58+
#define v_setall_f64(VAL) _mm256_set1_pd(VAL)
59+
#define v_zero_f32 _mm256_setzero_ps
60+
#define v_zero_f64 _mm256_setzero_pd

kernel/simd/intrin_avx512.h

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,15 +4,19 @@
44
* Data Type
55
***************************/
66
typedef __m512 v_f32;
7+
typedef __m512d v_f64;
78
#define v_nlanes_f32 16
9+
#define v_nlanes_f64 8
810
/***************************
911
* Arithmetic
1012
***************************/
1113
#define v_add_f32 _mm512_add_ps
14+
#define v_add_f64 _mm512_add_pd
1215
#define v_mul_f32 _mm512_mul_ps
16+
#define v_mul_f64 _mm512_mul_pd
1317
// multiply and add, a*b + c
1418
#define v_muladd_f32 _mm512_fmadd_ps
15-
19+
#define v_muladd_f64 _mm512_fmadd_pd
1620
BLAS_FINLINE float v_sum_f32(v_f32 a)
1721
{
1822
__m512 h64 = _mm512_shuffle_f32x4(a, a, _MM_SHUFFLE(3, 2, 3, 2));
@@ -25,11 +29,26 @@ BLAS_FINLINE float v_sum_f32(v_f32 a)
2529
__m512 sum4 = _mm512_add_ps(sum8, h4);
2630
return _mm_cvtss_f32(_mm512_castps512_ps128(sum4));
2731
}
32+
33+
// Horizontal add: returns the sum of all eight double elements of `a`.
BLAS_FINLINE double v_sum_f64(v_f64 a)
{
    // fold the upper 256 bits onto the lower 256 bits
    __m512d hi256 = _mm512_shuffle_f64x2(a, a, _MM_SHUFFLE(3, 2, 3, 2));
    __m512d s256  = _mm512_add_pd(a, hi256);
    // fold the upper 128-bit lane of each 256-bit half
    __m512d hi128 = _mm512_permutex_pd(s256, _MM_SHUFFLE(1, 0, 3, 2));
    __m512d s128  = _mm512_add_pd(s256, hi128);
    // fold the remaining adjacent pair
    __m512d hi64  = _mm512_permute_pd(s128, _MM_SHUFFLE(2, 3, 0, 1));
    __m512d s64   = _mm512_add_pd(s128, hi64);
    return _mm_cvtsd_f64(_mm512_castpd512_pd128(s64));
}
2843
/***************************
2944
* memory
3045
***************************/
3146
// unaligned load
3247
#define v_loadu_f32(PTR) _mm512_loadu_ps((const __m512*)(PTR))
48+
#define v_loadu_f64(PTR) _mm512_loadu_pd((const __m512d*)(PTR)) // was (const __m512*): wrong (single-precision) vector type for a pd load
3349
#define v_storeu_f32 _mm512_storeu_ps
50+
#define v_storeu_f64 _mm512_storeu_pd
3451
#define v_setall_f32(VAL) _mm512_set1_ps(VAL)
52+
#define v_setall_f64(VAL) _mm512_set1_pd(VAL)
3553
#define v_zero_f32 _mm512_setzero_ps
54+
#define v_zero_f64 _mm512_setzero_pd

kernel/simd/intrin_neon.h

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,18 @@
88
* Data Type
99
***************************/
1010
typedef float32x4_t v_f32;
11+
#if V_SIMD_F64
12+
typedef float64x2_t v_f64;
13+
#endif
1114
#define v_nlanes_f32 4
15+
#define v_nlanes_f64 2
1216
/***************************
1317
* Arithmetic
1418
***************************/
1519
#define v_add_f32 vaddq_f32
20+
#define v_add_f64 vaddq_f64
1621
#define v_mul_f32 vmulq_f32
22+
#define v_mul_f64 vmulq_f64
1723

1824
// FUSED F32
1925
#ifdef HAVE_VFPV4 // FMA
@@ -26,17 +32,37 @@ typedef float32x4_t v_f32;
2632
{ return vmlaq_f32(c, a, b); }
2733
#endif
2834

35+
// FUSED F64
36+
#if V_SIMD_F64
37+
BLAS_FINLINE v_f64 v_muladd_f64(v_f64 a, v_f64 b, v_f64 c)
38+
{ return vfmaq_f64(c, a, b); }
39+
#endif
40+
2941
// Horizontal add: Calculates the sum of all vector elements.
3042
BLAS_FINLINE float v_sum_f32(float32x4_t a)
3143
{
3244
float32x2_t r = vadd_f32(vget_high_f32(a), vget_low_f32(a));
3345
return vget_lane_f32(vpadd_f32(r, r), 0);
3446
}
47+
48+
#if V_SIMD_F64
49+
BLAS_FINLINE double v_sum_f64(float64x2_t a)
50+
{
51+
return vget_lane_f64(vget_low_f64(a) + vget_high_f64(a), 0);
52+
}
53+
#endif
54+
3555
/***************************
3656
* memory
3757
***************************/
3858
// unaligned load
3959
#define v_loadu_f32(a) vld1q_f32((const float*)a)
4060
#define v_storeu_f32 vst1q_f32
4161
#define v_setall_f32(VAL) vdupq_n_f32(VAL)
42-
#define v_zero_f32() vdupq_n_f32(0.0f)
62+
#define v_zero_f32() vdupq_n_f32(0.0f)
63+
#if V_SIMD_F64
64+
#define v_loadu_f64(a) vld1q_f64((const double*)a)
65+
#define v_storeu_f64 vst1q_f64
66+
#define v_setall_f64 vdupq_n_f64
67+
#define v_zero_f64() vdupq_n_f64(0.0)
68+
#endif

kernel/simd/intrin_sse.h

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,22 +4,30 @@
44
* Data Type
55
***************************/
66
typedef __m128 v_f32;
7+
typedef __m128d v_f64;
78
#define v_nlanes_f32 4
9+
#define v_nlanes_f64 2
810
/***************************
911
* Arithmetic
1012
***************************/
1113
#define v_add_f32 _mm_add_ps
14+
#define v_add_f64 _mm_add_pd
1215
#define v_mul_f32 _mm_mul_ps
16+
#define v_mul_f64 _mm_mul_pd
1317
#ifdef HAVE_FMA3
1418
// multiply and add, a*b + c
1519
#define v_muladd_f32 _mm_fmadd_ps
20+
#define v_muladd_f64 _mm_fmadd_pd
1621
#elif defined(HAVE_FMA4)
1722
// multiply and add, a*b + c
1823
#define v_muladd_f32 _mm_macc_ps
24+
#define v_muladd_f64 _mm_macc_pd
1925
#else
2026
// multiply and add, a*b + c
2127
BLAS_FINLINE v_f32 v_muladd_f32(v_f32 a, v_f32 b, v_f32 c)
2228
{ return v_add_f32(v_mul_f32(a, b), c); }
29+
BLAS_FINLINE v_f64 v_muladd_f64(v_f64 a, v_f64 b, v_f64 c)
30+
{ return v_add_f64(v_mul_f64(a, b), c); }
2331
#endif // HAVE_FMA3
2432

2533
// Horizontal add: Calculates the sum of all vector elements.
@@ -36,11 +44,24 @@ BLAS_FINLINE float v_sum_f32(__m128 a)
3644
return _mm_cvtss_f32(t4);
3745
#endif
3846
}
47+
48+
// Horizontal add: returns the sum of both double elements of `a`.
BLAS_FINLINE double v_sum_f64(__m128d a)
{
#ifdef HAVE_SSE3
    // single horizontal-add instruction
    return _mm_cvtsd_f64(_mm_hadd_pd(a, a));
#else
    // bring the high element down, then add
    __m128d hi = _mm_unpackhi_pd(a, a);
    return _mm_cvtsd_f64(_mm_add_pd(a, hi));
#endif
}
3956
/***************************
4057
* memory
4158
***************************/
4259
// unaligned load
4360
#define v_loadu_f32 _mm_loadu_ps
61+
#define v_loadu_f64 _mm_loadu_pd
4462
#define v_storeu_f32 _mm_storeu_ps
63+
#define v_storeu_f64 _mm_storeu_pd
4564
#define v_setall_f32(VAL) _mm_set1_ps(VAL)
46-
#define v_zero_f32 _mm_setzero_ps
65+
#define v_setall_f64(VAL) _mm_set1_pd(VAL)
66+
#define v_zero_f32 _mm_setzero_ps
67+
#define v_zero_f64 _mm_setzero_pd

kernel/x86_64/daxpy.c

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,13 +53,23 @@ static void daxpy_kernel_8(BLASLONG n, FLOAT *x, FLOAT *y, FLOAT *alpha)
5353
BLASLONG register i = 0;
5454
FLOAT a = *alpha;
5555
#if V_SIMD
56+
#ifdef DOUBLE
    // double-precision path: y[i] += alpha * x[i], one vector per step.
    // NOTE(review): there is no scalar tail loop here, so this assumes n
    // is a multiple of v_nlanes_f64 — presumably guaranteed by the
    // caller that sizes the kernel; TODO confirm.
    // Renamed from `__alpha`: identifiers beginning with a double
    // underscore are reserved for the implementation in ISO C.
    v_f64 valpha, vtmp;
    valpha = v_setall_f64(*alpha);
    const int vstep = v_nlanes_f64;
    for (; i < n; i += vstep) {
        vtmp = v_muladd_f64(valpha, v_loadu_f64(x + i), v_loadu_f64(y + i));
        v_storeu_f64(y + i, vtmp);
    }
64+
#else
5665
v_f32 __alpha, tmp;
5766
__alpha = v_setall_f32(*alpha);
5867
const int vstep = v_nlanes_f32;
5968
for (; i < n; i += vstep) {
6069
tmp = v_muladd_f32(__alpha, v_loadu_f32( x + i ), v_loadu_f32(y + i));
6170
v_storeu_f32(y + i, tmp);
6271
}
72+
#endif
6373
#else
6474
while(i < n)
6575
{

0 commit comments

Comments
 (0)