3
3
/***************************
 * Data Type
 ***************************/
6
// Vector type and lane count used by the v_*_f32 wrappers.
// When DOUBLE is defined the wrappers operate on packed doubles, so the
// 128-bit register holds 2 lanes; otherwise it holds 4 packed floats.
#ifdef DOUBLE
typedef __m128d v_f32;
#define v_nlanes_f32 2
#else
typedef __m128 v_f32;
#define v_nlanes_f32 4
#endif
8
13
/***************************
 * Arithmetic
 ***************************/
16
// Element-wise add and multiply.
// DOUBLE selects the packed-double intrinsics, otherwise packed-float.
#ifdef DOUBLE
#define v_add_f32 _mm_add_pd
#define v_mul_f32 _mm_mul_pd
#else
#define v_add_f32 _mm_add_ps
#define v_mul_f32 _mm_mul_ps
#endif
13
23
#ifdef HAVE_FMA3
14
24
// multiply and add, a*b + c
15
- #define v_muladd_f32 _mm_fmadd_ps
25
+ #ifdef DOUBLE
26
+ #define v_muladd_f32 _mm_fmadd_pd
27
+ #else
28
+ #define v_muladd_f32 _mm_fmadd_ps
29
+ #endif
16
30
#elif defined(HAVE_FMA4 )
17
31
// multiply and add, a*b + c
18
- #define v_muladd_f32 _mm_macc_ps
32
+ #ifdef DOUBLE
33
+ #define v_muladd_f32 _mm_macc_pd
34
+ #else
35
+ #define v_muladd_f32 _mm_macc_ps
36
+ #endif
19
37
#else
20
38
// multiply and add, a*b + c
21
39
BLAS_FINLINE v_f32 v_muladd_f32 (v_f32 a , v_f32 b , v_f32 c )
22
40
{ return v_add_f32 (v_mul_f32 (a , b ), c ); }
23
41
#endif // HAVE_FMA3
24
42
43
// Horizontal add: Calculates the sum of all vector elements.
44
+ #ifdef DOUBLE
45
+ BLAS_FINLINE double v_sum_f32 (__m128d a )
46
+ {
47
+ #ifdef HAVE_SSE3
48
+ __m128d sum_halves = _mm_hadd_pd (a , a );
49
+ return _mm_cvtsd_f64 (_mm_hadd_pd (sum_halves , sum_halves ));
50
+ #else
51
+ __m128d t1 = _mm_movehl_pd (a , a );
52
+ __m128d t2 = _mm_add_pd (a , t1 );
53
+ __m128d t3 = _mm_shuffle_pd (t2 , t2 , 1 );
54
+ __m128d t4 = _mm_add_ss (t2 , t3 );
55
+ return _mm_cvtsd_f64 (t4 );
56
+ #endif
57
+ }
58
+ #else
25
59
// Horizontal add: Calculates the sum of all vector elements.
26
60
BLAS_FINLINE float v_sum_f32 (__m128 a )
27
61
{
@@ -36,11 +70,19 @@ BLAS_FINLINE float v_sum_f32(__m128 a)
36
70
return _mm_cvtss_f32 (t4 );
37
71
#endif
38
72
}
73
+ #endif
39
74
/***************************
 * memory
 ***************************/
42
77
// unaligned load
78
// Unaligned load/store, broadcast of a scalar, and all-zero vector.
// DOUBLE selects the packed-double intrinsics, otherwise packed-float.
// Note: a function-like macro must have no whitespace between its name and
// the opening parenthesis, otherwise it becomes an object-like macro.
#ifdef DOUBLE
#define v_loadu_f32 _mm_loadu_pd
#define v_storeu_f32 _mm_storeu_pd
#define v_setall_f32(VAL) _mm_set1_pd(VAL)
#define v_zero_f32 _mm_setzero_pd
#else
#define v_loadu_f32 _mm_loadu_ps
#define v_storeu_f32 _mm_storeu_ps
#define v_setall_f32(VAL) _mm_set1_ps(VAL)
#define v_zero_f32 _mm_setzero_ps
#endif
0 commit comments