Commit 7037849

Merge branch 'develop' into risc-v
2 parents c6c9c24 + 6dd71af commit 7037849

File tree

10 files changed: 218 additions, 46 deletions


cmake/cc.cmake

Lines changed: 3 additions & 0 deletions
@@ -124,6 +124,9 @@ if (NOT DYNAMIC_ARCH)
   if (HAVE_AVX)
     set (CCOMMON_OPT "${CCOMMON_OPT} -mavx")
   endif ()
+  if (HAVE_FMA3)
+    set (CCOMMON_OPT "${CCOMMON_OPT} -mfma")
+  endif ()
   if (HAVE_SSE)
     set (CCOMMON_OPT "${CCOMMON_OPT} -msse")
   endif ()
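
Note (not part of the commit): when HAVE_FMA3 causes -mfma to be appended, GCC and Clang normally define the __FMA__ macro, so the effect of the flag can be sanity-checked with a small standalone sketch like the following.

#include <stdio.h>

int main(void) {
#ifdef __FMA__
    puts("FMA code generation is enabled by the compiler flags");
#else
    puts("FMA code generation is not enabled");
#endif
    return 0;
}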

driver/others/blas_server_omp.c

Lines changed: 22 additions & 26 deletions
@@ -76,10 +76,28 @@ static atomic_bool blas_buffer_inuse[MAX_PARALLEL_NUMBER];
 static _Bool blas_buffer_inuse[MAX_PARALLEL_NUMBER];
 #endif
 
-void goto_set_num_threads(int num_threads) {
+static void adjust_thread_buffers() {
 
   int i=0, j=0;
 
+  //adjust buffer for each thread
+  for(i=0; i < MAX_PARALLEL_NUMBER; i++) {
+    for(j=0; j < blas_cpu_number; j++){
+      if(blas_thread_buffer[i][j] == NULL){
+        blas_thread_buffer[i][j] = blas_memory_alloc(2);
+      }
+    }
+    for(; j < MAX_CPU_NUMBER; j++){
+      if(blas_thread_buffer[i][j] != NULL){
+        blas_memory_free(blas_thread_buffer[i][j]);
+        blas_thread_buffer[i][j] = NULL;
+      }
+    }
+  }
+}
+
+void goto_set_num_threads(int num_threads) {
+
   if (num_threads < 1) num_threads = blas_num_threads;
 
   if (num_threads > MAX_CPU_NUMBER) num_threads = MAX_CPU_NUMBER;
@@ -92,20 +110,7 @@ void goto_set_num_threads(int num_threads) {
 
   omp_set_num_threads(blas_cpu_number);
 
-  //adjust buffer for each thread
-  for(i=0; i<MAX_PARALLEL_NUMBER; i++) {
-    for(j=0; j<blas_cpu_number; j++){
-      if(blas_thread_buffer[i][j]==NULL){
-        blas_thread_buffer[i][j]=blas_memory_alloc(2);
-      }
-    }
-    for(; j<MAX_CPU_NUMBER; j++){
-      if(blas_thread_buffer[i][j]!=NULL){
-        blas_memory_free(blas_thread_buffer[i][j]);
-        blas_thread_buffer[i][j]=NULL;
-      }
-    }
-  }
+  adjust_thread_buffers();
 #if defined(ARCH_MIPS64)
   //set parameters for different number of threads.
   blas_set_parameter();
@@ -119,20 +124,11 @@ void openblas_set_num_threads(int num_threads) {
 
 int blas_thread_init(void){
 
-  int i=0, j=0;
-
   blas_get_cpu_number();
 
-  blas_server_avail = 1;
+  adjust_thread_buffers();
 
-  for(i=0; i<MAX_PARALLEL_NUMBER; i++) {
-    for(j=0; j<blas_num_threads; j++){
-      blas_thread_buffer[i][j]=blas_memory_alloc(2);
-    }
-    for(; j<MAX_CPU_NUMBER; j++){
-      blas_thread_buffer[i][j]=NULL;
-    }
-  }
+  blas_server_avail = 1;
 
   return 0;
 }
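
The refactor moves the allocate-or-free loop into a single adjust_thread_buffers() helper that both goto_set_num_threads() and blas_thread_init() call. A minimal standalone sketch of the same pattern, with illustrative names and sizes in place of the real blas_memory_alloc/blas_memory_free plumbing:

#include <stdlib.h>

#define MAX_SLOTS 64

static void *buffers[MAX_SLOTS];   /* zero-initialized, like the real buffer table */

/* Slots below the active thread count are allocated on demand; the rest are freed. */
static void adjust_buffers(int active) {
    for (int j = 0; j < active; j++)
        if (buffers[j] == NULL)
            buffers[j] = malloc(4096);   /* stand-in for blas_memory_alloc(2) */
    for (int j = active; j < MAX_SLOTS; j++)
        if (buffers[j] != NULL) {
            free(buffers[j]);            /* stand-in for blas_memory_free() */
            buffers[j] = NULL;
        }
}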

kernel/arm/sum.c

Lines changed: 18 additions & 17 deletions
@@ -42,24 +42,27 @@ FLOAT CNAME(BLASLONG n, FLOAT *x, BLASLONG inc_x)
     n *= inc_x;
     if (inc_x == 1)
     {
-#if V_SIMD
+#if V_SIMD && (!defined(DOUBLE) || (defined(DOUBLE) && V_SIMD_F64 && V_SIMD > 128))
 #ifdef DOUBLE
         const int vstep = v_nlanes_f64;
-        const int unrollx2 = n & (-vstep * 2);
+        const int unrollx4 = n & (-vstep * 4);
         const int unrollx = n & -vstep;
         v_f64 vsum0 = v_zero_f64();
         v_f64 vsum1 = v_zero_f64();
-        while (i < unrollx2)
+        v_f64 vsum2 = v_zero_f64();
+        v_f64 vsum3 = v_zero_f64();
+        for (; i < unrollx4; i += vstep * 4)
         {
-            vsum0 = v_add_f64(vsum0, v_loadu_f64(x));
-            vsum1 = v_add_f64(vsum1, v_loadu_f64(x + vstep));
-            i += vstep * 2;
+            vsum0 = v_add_f64(vsum0, v_loadu_f64(x + i));
+            vsum1 = v_add_f64(vsum1, v_loadu_f64(x + i + vstep));
+            vsum2 = v_add_f64(vsum2, v_loadu_f64(x + i + vstep * 2));
+            vsum3 = v_add_f64(vsum3, v_loadu_f64(x + i + vstep * 3));
         }
-        vsum0 = v_add_f64(vsum0, vsum1);
-        while (i < unrollx)
+        vsum0 = v_add_f64(
+            v_add_f64(vsum0, vsum1), v_add_f64(vsum2, vsum3));
+        for (; i < unrollx; i += vstep)
         {
             vsum0 = v_add_f64(vsum0, v_loadu_f64(x + i));
-            i += vstep;
         }
         sumf = v_sum_f64(vsum0);
 #else
@@ -70,20 +73,18 @@ FLOAT CNAME(BLASLONG n, FLOAT *x, BLASLONG inc_x)
         v_f32 vsum1 = v_zero_f32();
         v_f32 vsum2 = v_zero_f32();
         v_f32 vsum3 = v_zero_f32();
-        while (i < unrollx4)
+        for (; i < unrollx4; i += vstep * 4)
         {
-            vsum0 = v_add_f32(vsum0, v_loadu_f32(x));
-            vsum1 = v_add_f32(vsum1, v_loadu_f32(x + vstep));
-            vsum2 = v_add_f32(vsum2, v_loadu_f32(x + vstep * 2));
-            vsum3 = v_add_f32(vsum3, v_loadu_f32(x + vstep * 3));
-            i += vstep * 4;
+            vsum0 = v_add_f32(vsum0, v_loadu_f32(x + i));
+            vsum1 = v_add_f32(vsum1, v_loadu_f32(x + i + vstep));
+            vsum2 = v_add_f32(vsum2, v_loadu_f32(x + i + vstep * 2));
+            vsum3 = v_add_f32(vsum3, v_loadu_f32(x + i + vstep * 3));
         }
         vsum0 = v_add_f32(
             v_add_f32(vsum0, vsum1), v_add_f32(vsum2, vsum3));
-        while (i < unrollx)
+        for (; i < unrollx; i += vstep)
         {
             vsum0 = v_add_f32(vsum0, v_loadu_f32(x + i));
-            i += vstep;
         }
         sumf = v_sum_f32(vsum0);
 #endif
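
In scalar terms, the reworked loop keeps four independent partial sums so the additions do not form one long dependency chain, then folds them once at the end. A rough standalone illustration (plain doubles instead of the v_f64 vectors):

/* Scalar sketch of the 4-way unrolled accumulation used by the new sum kernel. */
static double sum_unrolled(const double *x, long n) {
    double s0 = 0, s1 = 0, s2 = 0, s3 = 0;
    long i = 0, n4 = n & ~3L;          /* analogous to n & (-vstep * 4) */
    for (; i < n4; i += 4) {
        s0 += x[i];
        s1 += x[i + 1];
        s2 += x[i + 2];
        s3 += x[i + 3];
    }
    double s = (s0 + s1) + (s2 + s3);  /* fold the partial sums once */
    for (; i < n; i++)                 /* remainder, like the final vstep loop */
        s += x[i];
    return s;
}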

kernel/simd/intrin.h

Lines changed: 1 addition & 1 deletion
@@ -47,7 +47,7 @@ extern "C" {
 #endif
 
 /** AVX **/
-#ifdef HAVE_AVX
+#if defined(HAVE_AVX) || defined(HAVE_FMA3)
 #include <immintrin.h>
 #endif
 
kernel/simd/intrin_avx.h

Lines changed: 10 additions & 0 deletions
@@ -12,19 +12,29 @@ typedef __m256d v_f64;
 ***************************/
 #define v_add_f32 _mm256_add_ps
 #define v_add_f64 _mm256_add_pd
+#define v_sub_f32 _mm256_sub_ps
+#define v_sub_f64 _mm256_sub_pd
 #define v_mul_f32 _mm256_mul_ps
 #define v_mul_f64 _mm256_mul_pd
 
 #ifdef HAVE_FMA3
 // multiply and add, a*b + c
 #define v_muladd_f32 _mm256_fmadd_ps
 #define v_muladd_f64 _mm256_fmadd_pd
+// multiply and subtract, a*b - c
+#define v_mulsub_f32 _mm256_fmsub_ps
+#define v_mulsub_f64 _mm256_fmsub_pd
 #else
 // multiply and add, a*b + c
 BLAS_FINLINE v_f32 v_muladd_f32(v_f32 a, v_f32 b, v_f32 c)
 { return v_add_f32(v_mul_f32(a, b), c); }
 BLAS_FINLINE v_f64 v_muladd_f64(v_f64 a, v_f64 b, v_f64 c)
 { return v_add_f64(v_mul_f64(a, b), c); }
+// multiply and subtract, a*b - c
+BLAS_FINLINE v_f32 v_mulsub_f32(v_f32 a, v_f32 b, v_f32 c)
+{ return v_sub_f32(v_mul_f32(a, b), c); }
+BLAS_FINLINE v_f64 v_mulsub_f64(v_f64 a, v_f64 b, v_f64 c)
+{ return v_sub_f64(v_mul_f64(a, b), c); }
 #endif // !HAVE_FMA3
 
 // Horizontal add: Calculates the sum of all vector elements.
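
On an FMA3 build the new v_mulsub_f64 maps to _mm256_fmsub_pd, i.e. a*b - c computed with a single rounding. A standalone sketch (not part of the header) that exercises the underlying intrinsic directly, assuming an AVX/FMA-capable compiler (e.g. built with -mavx -mfma):

#include <immintrin.h>
#include <stdio.h>

int main(void) {
    double a[4] = {1.0, 2.0, 3.0, 4.0};
    double b[4] = {0.5, 0.5, 0.5, 0.5};
    double c[4] = {1.0, 1.0, 1.0, 1.0};
    double r[4];
    /* r = a*b - c elementwise, one lane per double in the 256-bit vector */
    __m256d vr = _mm256_fmsub_pd(_mm256_loadu_pd(a), _mm256_loadu_pd(b),
                                 _mm256_loadu_pd(c));
    _mm256_storeu_pd(r, vr);
    for (int i = 0; i < 4; i++)
        printf("%g (expected %g)\n", r[i], a[i] * b[i] - c[i]);
    return 0;
}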

kernel/simd/intrin_avx512.h

Lines changed: 5 additions & 0 deletions
@@ -12,11 +12,16 @@ typedef __m512d v_f64;
 ***************************/
 #define v_add_f32 _mm512_add_ps
 #define v_add_f64 _mm512_add_pd
+#define v_sub_f32 _mm512_sub_ps
+#define v_sub_f64 _mm512_sub_pd
 #define v_mul_f32 _mm512_mul_ps
 #define v_mul_f64 _mm512_mul_pd
 // multiply and add, a*b + c
 #define v_muladd_f32 _mm512_fmadd_ps
 #define v_muladd_f64 _mm512_fmadd_pd
+// multiply and subtract, a*b - c
+#define v_mulsub_f32 _mm512_fmsub_ps
+#define v_mulsub_f64 _mm512_fmsub_pd
 BLAS_FINLINE float v_sum_f32(v_f32 a)
 {
     __m512 h64 = _mm512_shuffle_f32x4(a, a, _MM_SHUFFLE(3, 2, 3, 2));

kernel/simd/intrin_neon.h

Lines changed: 10 additions & 0 deletions
@@ -18,6 +18,8 @@ typedef float32x4_t v_f32;
 ***************************/
 #define v_add_f32 vaddq_f32
 #define v_add_f64 vaddq_f64
+#define v_sub_f32 vsubq_f32
+#define v_sub_f64 vsubq_f64
 #define v_mul_f32 vmulq_f32
 #define v_mul_f64 vmulq_f64
 
@@ -26,16 +28,24 @@ typedef float32x4_t v_f32;
 // multiply and add, a*b + c
 BLAS_FINLINE v_f32 v_muladd_f32(v_f32 a, v_f32 b, v_f32 c)
 { return vfmaq_f32(c, a, b); }
+// multiply and subtract, a*b - c
+BLAS_FINLINE v_f32 v_mulsub_f32(v_f32 a, v_f32 b, v_f32 c)
+{ return vfmaq_f32(vnegq_f32(c), a, b); }
 #else
 // multiply and add, a*b + c
 BLAS_FINLINE v_f32 v_muladd_f32(v_f32 a, v_f32 b, v_f32 c)
 { return vmlaq_f32(c, a, b); }
+// multiply and subtract, a*b - c
+BLAS_FINLINE v_f32 v_mulsub_f32(v_f32 a, v_f32 b, v_f32 c)
+{ return vmlaq_f32(vnegq_f32(c), a, b); }
 #endif
 
 // FUSED F64
 #if V_SIMD_F64
 BLAS_FINLINE v_f64 v_muladd_f64(v_f64 a, v_f64 b, v_f64 c)
 { return vfmaq_f64(c, a, b); }
+BLAS_FINLINE v_f64 v_mulsub_f64(v_f64 a, v_f64 b, v_f64 c)
+{ return vfmaq_f64(vnegq_f64(c), a, b); }
 #endif
 
 // Horizontal add: Calculates the sum of all vector elements.
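
The NEON helpers build multiply-subtract from the fused multiply-add by negating c, relying on the identity a*b - c = fma(a, b, -c). A tiny standalone check of that identity with standard C fma() (illustrative only, link with -lm):

#include <assert.h>
#include <math.h>

int main(void) {
    double a = 3.0, b = 7.0, c = 5.0;
    /* Exact for these small integers, so both sides agree bit-for-bit. */
    assert(fma(a, b, -c) == a * b - c);
    return 0;
}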

kernel/simd/intrin_sse.h

Lines changed: 13 additions & 0 deletions
@@ -12,22 +12,35 @@ typedef __m128d v_f64;
 ***************************/
 #define v_add_f32 _mm_add_ps
 #define v_add_f64 _mm_add_pd
+#define v_sub_f32 _mm_sub_ps
+#define v_sub_f64 _mm_sub_pd
 #define v_mul_f32 _mm_mul_ps
 #define v_mul_f64 _mm_mul_pd
 #ifdef HAVE_FMA3
 // multiply and add, a*b + c
 #define v_muladd_f32 _mm_fmadd_ps
 #define v_muladd_f64 _mm_fmadd_pd
+// multiply and subtract, a*b - c
+#define v_mulsub_f32 _mm_fmsub_ps
+#define v_mulsub_f64 _mm_fmsub_pd
 #elif defined(HAVE_FMA4)
 // multiply and add, a*b + c
 #define v_muladd_f32 _mm_macc_ps
 #define v_muladd_f64 _mm_macc_pd
+// multiply and subtract, a*b - c
+#define v_mulsub_f32 _mm_msub_ps
+#define v_mulsub_f64 _mm_msub_pd
 #else
 // multiply and add, a*b + c
 BLAS_FINLINE v_f32 v_muladd_f32(v_f32 a, v_f32 b, v_f32 c)
 { return v_add_f32(v_mul_f32(a, b), c); }
 BLAS_FINLINE v_f64 v_muladd_f64(v_f64 a, v_f64 b, v_f64 c)
 { return v_add_f64(v_mul_f64(a, b), c); }
+// multiply and subtract, a*b - c
+BLAS_FINLINE v_f32 v_mulsub_f32(v_f32 a, v_f32 b, v_f32 c)
+{ return v_sub_f32(v_mul_f32(a, b), c); }
+BLAS_FINLINE v_f64 v_mulsub_f64(v_f64 a, v_f64 b, v_f64 c)
+{ return v_sub_f64(v_mul_f64(a, b), c); }
 #endif // HAVE_FMA3
 
 // Horizontal add: Calculates the sum of all vector elements.
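
One caveat worth noting: the fused paths (_mm_fmsub_*, _mm_msub_*) round once, while the #else fallback rounds the product and the subtraction separately, so results can differ by an ulp. A standalone sketch of the effect using standard C fma() (compile with -ffp-contract=off so the compiler does not fuse a*b - c itself, and link with -lm):

#include <math.h>
#include <stdio.h>

int main(void) {
    double a = 1.0 + ldexp(1.0, -27);   /* 1 + 2^-27 */
    double b = a;
    double c = 1.0 + ldexp(1.0, -26);   /* 1 + 2^-26 */
    printf("fused  : %g\n", fma(a, b, -c));  /* single rounding: 2^-54 */
    printf("unfused: %g\n", a * b - c);      /* product rounds first: 0 */
    return 0;
}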

kernel/x86_64/drot.c

Lines changed: 67 additions & 1 deletion
@@ -7,10 +7,76 @@
 #endif
 
 #ifndef HAVE_DROT_KERNEL
+#include "../simd/intrin.h"
 
 static void drot_kernel(BLASLONG n, FLOAT *x, FLOAT *y, FLOAT c, FLOAT s)
 {
     BLASLONG i = 0;
+#if V_SIMD_F64 && V_SIMD > 256
+    const int vstep = v_nlanes_f64;
+    const int unrollx4 = n & (-vstep * 4);
+    const int unrollx = n & -vstep;
+
+    v_f64 __c = v_setall_f64(c);
+    v_f64 __s = v_setall_f64(s);
+    v_f64 vx0, vx1, vx2, vx3;
+    v_f64 vy0, vy1, vy2, vy3;
+    v_f64 vt0, vt1, vt2, vt3;
+
+    for (; i < unrollx4; i += vstep * 4) {
+        vx0 = v_loadu_f64(x + i);
+        vx1 = v_loadu_f64(x + i + vstep);
+        vx2 = v_loadu_f64(x + i + vstep * 2);
+        vx3 = v_loadu_f64(x + i + vstep * 3);
+        vy0 = v_loadu_f64(y + i);
+        vy1 = v_loadu_f64(y + i + vstep);
+        vy2 = v_loadu_f64(y + i + vstep * 2);
+        vy3 = v_loadu_f64(y + i + vstep * 3);
+
+        vt0 = v_mul_f64(__s, vy0);
+        vt1 = v_mul_f64(__s, vy1);
+        vt2 = v_mul_f64(__s, vy2);
+        vt3 = v_mul_f64(__s, vy3);
+
+        vt0 = v_muladd_f64(__c, vx0, vt0);
+        vt1 = v_muladd_f64(__c, vx1, vt1);
+        vt2 = v_muladd_f64(__c, vx2, vt2);
+        vt3 = v_muladd_f64(__c, vx3, vt3);
+
+        v_storeu_f64(x + i, vt0);
+        v_storeu_f64(x + i + vstep, vt1);
+        v_storeu_f64(x + i + vstep * 2, vt2);
+        v_storeu_f64(x + i + vstep * 3, vt3);
+
+        vt0 = v_mul_f64(__s, vx0);
+        vt1 = v_mul_f64(__s, vx1);
+        vt2 = v_mul_f64(__s, vx2);
+        vt3 = v_mul_f64(__s, vx3);
+
+        vt0 = v_mulsub_f64(__c, vy0, vt0);
+        vt1 = v_mulsub_f64(__c, vy1, vt1);
+        vt2 = v_mulsub_f64(__c, vy2, vt2);
+        vt3 = v_mulsub_f64(__c, vy3, vt3);
+
+        v_storeu_f64(y + i, vt0);
+        v_storeu_f64(y + i + vstep, vt1);
+        v_storeu_f64(y + i + vstep * 2, vt2);
+        v_storeu_f64(y + i + vstep * 3, vt3);
+    }
+
+    for (; i < unrollx; i += vstep) {
+        vx0 = v_loadu_f64(x + i);
+        vy0 = v_loadu_f64(y + i);
+
+        vt0 = v_mul_f64(__s, vy0);
+        vt0 = v_muladd_f64(__c, vx0, vt0);
+        v_storeu_f64(x + i, vt0);
+
+        vt0 = v_mul_f64(__s, vx0);
+        vt0 = v_mulsub_f64(__c, vy0, vt0);
+        v_storeu_f64(y + i, vt0);
+    }
+#else
     FLOAT f0, f1, f2, f3;
     FLOAT x0, x1, x2, x3;
     FLOAT g0, g1, g2, g3;
@@ -53,7 +119,7 @@ static void drot_kernel(BLASLONG n, FLOAT *x, FLOAT *y, FLOAT c, FLOAT s)
         yp += 4;
        i += 4;
     }
-
+#endif
     while (i < n) {
         FLOAT temp = c*x[i] + s*y[i];
         y[i] = c*y[i] - s*x[i];
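
The vector path applies the same plane rotation as the scalar tail loop, x[i] <- c*x[i] + s*y[i] and y[i] <- c*y[i] - s*x[i], four vectors per iteration. A scalar reference useful for spot-checking the kernel (illustrative names; results may differ from the fused path by an ulp):

/* Scalar reference for drot_kernel: n, x, y, c, s have the same meaning as above. */
static void drot_reference(long n, double *x, double *y, double c, double s) {
    for (long i = 0; i < n; i++) {
        double temp = c * x[i] + s * y[i];
        y[i] = c * y[i] - s * x[i];
        x[i] = temp;
    }
}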
