Skip to content

Commit 1b0f17e

Browse files
committed
align to 64, using SSE when input size is small
1 parent 448152c commit 1b0f17e

File tree

6 files changed

+392
-224
lines changed

6 files changed

+392
-224
lines changed

kernel/x86_64/dasum.c

Lines changed: 63 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -1,96 +1,82 @@
11
#include "common.h"
2-
#include <math.h>
32

4-
#define ABS fabs
3+
#ifndef ABS_K
4+
#define ABS_K(a) ((a) > 0 ? (a) : (-(a)))
5+
#endif
56

67
#if defined(SKYLAKEX)
78
#include "dasum_microk_skylakex-2.c"
89
#elif defined(HASWELL)
910
#include "dasum_microk_haswell-2.c"
1011
#endif
1112

12-
#ifndef HAVE_KERNEL_16
13-
static FLOAT dasum_kernel_16(BLASLONG n, FLOAT *x1)
13+
#ifndef HAVE_DASUM_KERNEL
14+
/*
 * Portable fallback kernel: sum of ABS_K() over n contiguous elements.
 *
 *   n   number of elements to read starting at x1
 *   x1  first element
 *
 * The leading (n & -8) elements are consumed eight per iteration into four
 * independent partial sums (elements 0..3, then 4..7, of each group); the
 * remaining 0..7 elements are accumulated into a fifth partial sum.  The
 * five partial sums are added left to right to form the result.
 */
static FLOAT dasum_kernel(BLASLONG n, FLOAT *x1)
{
    const BLASLONG unrolled_end = n & -8;
    FLOAT acc0 = 0.0;
    FLOAT acc1 = 0.0;
    FLOAT acc2 = 0.0;
    FLOAT acc3 = 0.0;
    FLOAT tail = 0.0;
    BLASLONG k;

    for (k = 0; k < unrolled_end; k += 8) {
        const FLOAT *group = x1 + k;

        /* first half of the group */
        acc0 += ABS_K(group[0]);
        acc1 += ABS_K(group[1]);
        acc2 += ABS_K(group[2]);
        acc3 += ABS_K(group[3]);

        /* second half of the group, same accumulators */
        acc0 += ABS_K(group[4]);
        acc1 += ABS_K(group[5]);
        acc2 += ABS_K(group[6]);
        acc3 += ABS_K(group[7]);
    }

    /* leftover elements that do not fill a group of eight */
    for (; k < n; k++) {
        tail += ABS_K(x1[k]);
    }

    return acc0 + acc1 + acc2 + acc3 + tail;
}
5459

5560
#endif
5661

5762
/*
 * Entry point: sum of ABS_K(x[k * inc_x]) for k = 0 .. n-1.
 *
 *   n      number of elements to sum
 *   x      first element
 *   inc_x  distance, in elements, between consecutive inputs
 *
 * Returns 0.0 when n <= 0 or inc_x <= 0.  Unit-stride input is handed to
 * dasum_kernel(); any other positive stride is summed element by element.
 */
FLOAT CNAME(BLASLONG n, FLOAT *x, BLASLONG inc_x)
{
    FLOAT total = 0.0;
    BLASLONG pos;
    BLASLONG limit;

    if (n <= 0 || inc_x <= 0) {
        return total;
    }

    if (inc_x == 1) {
        return dasum_kernel(n, x);
    }

    /* one past the index of the last strided element */
    limit = n * inc_x;

    for (pos = 0; pos < limit; pos += inc_x) {
        total += ABS_K(x[pos]);
    }

    return total;
}
9682

kernel/x86_64/dasum_microk_haswell-2.c

Lines changed: 71 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,35 +1,86 @@
11
#if (( defined(__GNUC__) && __GNUC__ > 6 ) || (defined(__clang__) && __clang_major__ >= 6)) && defined(__AVX2__)
22

3-
#define HAVE_KERNEL_16 1
3+
#define HAVE_DASUM_KERNEL
44

55
#include <immintrin.h>
6+
#include <stdint.h>
67

7-
static FLOAT dasum_kernel_16(BLASLONG n, FLOAT *x1)
8+
#ifndef ABS_K
9+
#define ABS_K(a) ((a) > 0 ? (a) : (-(a)))
10+
#endif
11+
12+
/*
 * AVX2 kernel: sum of |x1[0]| .. |x1[n-1]| over contiguous doubles.
 *
 *   n   number of elements to read starting at x1
 *   x1  first element
 *
 * Stages:
 *   1. When n >= 256 and x1 is 8-byte aligned, 0..3 leading elements are
 *      summed in scalar code so that x1 lands on a 32-byte boundary.
 *   2. The leading (n & ~255) elements are read with aligned 256-bit loads,
 *      16 doubles per iteration, into four vector accumulators.
 *   3. The remaining whole groups of eight are read with unaligned 128-bit
 *      loads into four vector accumulators.
 *   4. The last 0..7 elements are summed in scalar code.
 *
 * Absolute values in the vector stages are formed by masking off bit 63.
 * Loads use the double-precision intrinsics so that the FLOAT pointer is
 * passed with a matching type rather than as an integer-vector pointer.
 */
static FLOAT dasum_kernel(BLASLONG n, FLOAT *x1)
{
    BLASLONG i = 0;
    FLOAT sumf = 0.0;

    /*
     * Skipping whole doubles can only reach a 32-byte boundary when the
     * start address is a multiple of 8.  Otherwise the aligned-load stage
     * is bypassed and the unaligned 128-bit stage handles the whole input.
     */
    const int can_align = (((uintptr_t)x1 & (uintptr_t)0x7) == 0);

    if (n >= 256 && can_align) {
        BLASLONG align_256 = ((32 - ((uintptr_t)x1 & (uintptr_t)0x1f)) >> 3) & 0x3;

        for (i = 0; i < align_256; i++) {
            sumf += ABS_K(x1[i]);
        }

        n -= align_256;
        x1 += align_256;
    }

    BLASLONG tail_index_SSE = n & (~7);
    BLASLONG tail_index_AVX2 = can_align ? (n & (~255)) : 0;

    /* non-zero exactly when the (possibly reduced) n is still >= 256 */
    if (tail_index_AVX2 > 0) {
        const __m256d abs_mask =
            _mm256_castsi256_pd(_mm256_set1_epi64x(0x7fffffffffffffffLL));
        __m256d accum_0 = _mm256_setzero_pd();
        __m256d accum_1 = _mm256_setzero_pd();
        __m256d accum_2 = _mm256_setzero_pd();
        __m256d accum_3 = _mm256_setzero_pd();

        for (i = 0; i < tail_index_AVX2; i += 16) {
            accum_0 = _mm256_add_pd(accum_0, _mm256_and_pd(_mm256_load_pd(&x1[i +  0]), abs_mask));
            accum_1 = _mm256_add_pd(accum_1, _mm256_and_pd(_mm256_load_pd(&x1[i +  4]), abs_mask));
            accum_2 = _mm256_add_pd(accum_2, _mm256_and_pd(_mm256_load_pd(&x1[i +  8]), abs_mask));
            accum_3 = _mm256_add_pd(accum_3, _mm256_and_pd(_mm256_load_pd(&x1[i + 12]), abs_mask));
        }

        /* ((a0 + a1) + a2) + a3, then fold 256 -> 128 -> scalar */
        accum_0 = _mm256_add_pd(_mm256_add_pd(_mm256_add_pd(accum_0, accum_1), accum_2), accum_3);

        __m128d half_accum0 = _mm_add_pd(_mm256_extractf128_pd(accum_0, 0),
                                         _mm256_extractf128_pd(accum_0, 1));
        half_accum0 = _mm_hadd_pd(half_accum0, half_accum0);

        sumf += _mm_cvtsd_f64(half_accum0);
    }

    if (n >= 8) {
        const __m128d abs_mask2 =
            _mm_castsi128_pd(_mm_set1_epi64x(0x7fffffffffffffffLL));
        __m128d accum_20 = _mm_setzero_pd();
        __m128d accum_21 = _mm_setzero_pd();
        __m128d accum_22 = _mm_setzero_pd();
        __m128d accum_23 = _mm_setzero_pd();

        for (i = tail_index_AVX2; i < tail_index_SSE; i += 8) {
            accum_20 = _mm_add_pd(accum_20, _mm_and_pd(_mm_loadu_pd(&x1[i + 0]), abs_mask2));
            accum_21 = _mm_add_pd(accum_21, _mm_and_pd(_mm_loadu_pd(&x1[i + 2]), abs_mask2));
            accum_22 = _mm_add_pd(accum_22, _mm_and_pd(_mm_loadu_pd(&x1[i + 4]), abs_mask2));
            accum_23 = _mm_add_pd(accum_23, _mm_and_pd(_mm_loadu_pd(&x1[i + 6]), abs_mask2));
        }

        accum_20 = _mm_add_pd(_mm_add_pd(_mm_add_pd(accum_20, accum_21), accum_22), accum_23);

        __m128d half_accum20 = _mm_hadd_pd(accum_20, accum_20);

        sumf += _mm_cvtsd_f64(half_accum20);
    }

    /* scalar remainder */
    for (i = tail_index_SSE; i < n; ++i) {
        sumf += ABS_K(x1[i]);
    }

    return sumf;
}
3586
#endif

kernel/x86_64/dasum_microk_skylakex-2.c

Lines changed: 66 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,27 +1,80 @@
11
/* need a new enough GCC for avx512 support */
2-
#if (( defined(__GNUC__) && __GNUC__ > 6 && defined(__AVX2__)) || (defined(__clang__) && __clang_major__ >= 6))
2+
#if ((defined(__GNUC__) && __GNUC__ > 6) || (defined(__clang__) && __clang_major__ >= 9)) && defined(__AVX512CD__)

#define HAVE_DASUM_KERNEL 1

#include <immintrin.h>

#include <stdint.h>

#ifndef ABS_K
#define ABS_K(a) ((a) > 0 ? (a) : (-(a)))
#endif

/*
 * AVX-512 kernel: sum of |x1[0]| .. |x1[n-1]| over contiguous doubles.
 *
 *   n   number of elements to read starting at x1
 *   x1  first element
 *
 * The enclosing #if requires __AVX512CD__ for both GCC and clang, so the
 * kernel is only compiled when AVX-512 code generation is enabled.
 *
 * Stages:
 *   1. When n >= 256 and x1 is 8-byte aligned, 0..7 leading elements are
 *      summed in scalar code so that x1 lands on a 64-byte boundary.
 *   2. The leading (n & ~255) elements are read with aligned 512-bit loads,
 *      32 doubles per iteration, into four vector accumulators.
 *   3. The remaining whole groups of eight are read with unaligned 128-bit
 *      loads; absolute values are formed by masking off bit 63.
 *   4. The last 0..7 elements are summed in scalar code.
 */
static FLOAT dasum_kernel(BLASLONG n, FLOAT *x1)
{
    BLASLONG i = 0;
    FLOAT sumf = 0.0;

    /*
     * Skipping whole doubles can only reach a 64-byte boundary when the
     * start address is a multiple of 8.  Otherwise the aligned-load stage
     * is bypassed and the unaligned 128-bit stage handles the whole input.
     */
    const int can_align = (((uintptr_t)x1 & (uintptr_t)0x7) == 0);

    if (n >= 256 && can_align) {
        BLASLONG align_512 = ((64 - ((uintptr_t)x1 & (uintptr_t)0x3f)) >> 3) & 0x7;

        for (i = 0; i < align_512; i++) {
            sumf += ABS_K(x1[i]);
        }

        n -= align_512;
        x1 += align_512;
    }

    BLASLONG tail_index_SSE = n & (~7);
    BLASLONG tail_index_AVX512 = can_align ? (n & (~255)) : 0;

    /* non-zero exactly when the (possibly reduced) n is still >= 256 */
    if (tail_index_AVX512 > 0) {
        __m512d accum_0 = _mm512_setzero_pd();
        __m512d accum_1 = _mm512_setzero_pd();
        __m512d accum_2 = _mm512_setzero_pd();
        __m512d accum_3 = _mm512_setzero_pd();

        for (i = 0; i < tail_index_AVX512; i += 32) {
            accum_0 = _mm512_add_pd(accum_0, _mm512_abs_pd(_mm512_load_pd(&x1[i +  0])));
            accum_1 = _mm512_add_pd(accum_1, _mm512_abs_pd(_mm512_load_pd(&x1[i +  8])));
            accum_2 = _mm512_add_pd(accum_2, _mm512_abs_pd(_mm512_load_pd(&x1[i + 16])));
            accum_3 = _mm512_add_pd(accum_3, _mm512_abs_pd(_mm512_load_pd(&x1[i + 24])));
        }

        /* ((a0 + a1) + a2) + a3, then horizontal reduction */
        accum_0 = _mm512_add_pd(_mm512_add_pd(_mm512_add_pd(accum_0, accum_1), accum_2), accum_3);
        sumf += _mm512_reduce_add_pd(accum_0);
    }

    if (n >= 8) {
        const __m128d abs_mask2 =
            _mm_castsi128_pd(_mm_set1_epi64x(0x7fffffffffffffffLL));
        __m128d accum_20 = _mm_setzero_pd();
        __m128d accum_21 = _mm_setzero_pd();
        __m128d accum_22 = _mm_setzero_pd();
        __m128d accum_23 = _mm_setzero_pd();

        for (i = tail_index_AVX512; i < tail_index_SSE; i += 8) {
            accum_20 = _mm_add_pd(accum_20, _mm_and_pd(_mm_loadu_pd(&x1[i + 0]), abs_mask2));
            accum_21 = _mm_add_pd(accum_21, _mm_and_pd(_mm_loadu_pd(&x1[i + 2]), abs_mask2));
            accum_22 = _mm_add_pd(accum_22, _mm_and_pd(_mm_loadu_pd(&x1[i + 4]), abs_mask2));
            accum_23 = _mm_add_pd(accum_23, _mm_and_pd(_mm_loadu_pd(&x1[i + 6]), abs_mask2));
        }

        accum_20 = _mm_add_pd(_mm_add_pd(_mm_add_pd(accum_20, accum_21), accum_22), accum_23);

        __m128d half_accum20 = _mm_hadd_pd(accum_20, accum_20);

        sumf += _mm_cvtsd_f64(half_accum20);
    }

    /* scalar remainder */
    for (i = tail_index_SSE; i < n; ++i) {
        sumf += ABS_K(x1[i]);
    }

    return sumf;
}
#endif
27-
#endif

0 commit comments

Comments
 (0)