@@ -17,34 +17,34 @@ namespace {
17
17
18
18
template <int N>
19
19
CM_NODEBUG CM_INLINE mask<N> check_is_nan_or_inf (vector<double , N> q) {
20
- vector<uint32_t , 2 * N> q_split = q.template format <uint32_t >();
20
+ vector<uint32_t , 2 *N> q_split = q.template format <uint32_t >();
21
21
vector<uint32_t , N> q_hi = q_split.template select <N, 2 >(1 );
22
22
return (q_hi >= exp_32bitmask);
23
23
}
24
24
25
25
template <int N>
26
26
CM_NODEBUG CM_INLINE vector<uint32_t , N> get_exp (vector<double , N> x) {
27
- vector<uint32_t , 2 * N> x_split = x.template format <uint32_t >();
27
+ vector<uint32_t , 2 *N> x_split = x.template format <uint32_t >();
28
28
vector<uint32_t , N> x_hi = x_split.template select <N, 2 >(1 );
29
29
return (x_hi >> exp_shift) & exp_mask;
30
30
}
31
31
32
32
template <int N>
33
33
CM_NODEBUG CM_INLINE vector<uint32_t , N> get_sign (vector<double , N> x) {
34
- vector<uint32_t , 2 * N> x_split = x.template format <uint32_t >();
34
+ vector<uint32_t , 2 *N> x_split = x.template format <uint32_t >();
35
35
vector<uint32_t , N> x_hi = x_split.template select <N, 2 >(1 );
36
36
return x_hi & sign_32bit;
37
37
}
38
38
39
39
template <int N> CM_NODEBUG CM_INLINE mask<N> is_denormal (vector<double , N> x) {
40
- vector<uint32_t , 2 * N> x_int = x.template format <uint32_t >();
40
+ vector<uint32_t , 2 *N> x_int = x.template format <uint32_t >();
41
41
vector<uint32_t , N> x_hi = x_int.template select <N, 2 >(1 );
42
42
return x_hi < min_sign_exp;
43
43
}
44
44
45
45
template <int N>
46
46
CM_NODEBUG CM_INLINE vector<uint32_t , N> sep_exp (vector<double , N> x) {
47
- vector<uint32_t , 2 * N> x_int = x.template format <uint32_t >();
47
+ vector<uint32_t , 2 *N> x_int = x.template format <uint32_t >();
48
48
vector<uint32_t , N> x_hi = x_int.template select <N, 2 >(1 );
49
49
vector<uint32_t , N> res = (x_hi >> exp_shift) - exp_bias;
50
50
return res >> 1 ;
@@ -84,8 +84,9 @@ CM_NODEBUG CM_INLINE vector<double, N> rsqrt_float(vector<double, N> x) {
84
84
}
85
85
86
86
template <int N>
87
- CM_NODEBUG CM_INLINE vector<double , N> uint64_sub_hi (vector<double , N> x, vector<uint32_t , N> hi) {
88
- vector<uint32_t , 2 * N> ex_mx_int = 0 ;
87
+ CM_NODEBUG CM_INLINE vector<double , N> uint64_sub_hi (vector<double , N> x,
88
+ vector<uint32_t , N> hi) {
89
+ vector<uint32_t , 2 *N> ex_mx_int = 0 ;
89
90
ex_mx_int.template select <N, 2 >(1 ) = hi;
90
91
vector<uint64_t , N> ex_u64 = ex_mx_int.template format <uint64_t >();
91
92
vector<uint64_t , N> mx_u64 = x.template format <uint64_t >();
@@ -163,9 +164,10 @@ CM_NODEBUG CM_INLINE vector<double, N> sqrt_special(vector<double, N> a) {
163
164
}
164
165
165
166
template <int N>
166
- CM_NODEBUG CM_INLINE vector<double , N> calc_sqrt (vector<double , N> x, mask<N> special) {
167
+ CM_NODEBUG CM_INLINE vector<double , N> calc_sqrt (vector<double , N> x,
168
+ mask<N> special) {
167
169
// Now start the SQRT computation
168
- // Use math.rsqtm (emulated here)
170
+ // Use math.rsqtm (emulated here)
169
171
vector<double , N> y0 = math_rsqt_dp (x);
170
172
// predicate is set for 0, neg a, Inf, NaN inputs
171
173
y0.merge (sqrt_special (x), special);
@@ -174,7 +176,8 @@ CM_NODEBUG CM_INLINE vector<double, N> calc_sqrt(vector<double, N> x, mask<N> sp
174
176
}
175
177
176
178
template <int N>
177
- CM_NODEBUG CM_INLINE vector<double , N> invert_calc (vector<double , N> a, vector<double , N> y0) {
179
+ CM_NODEBUG CM_INLINE vector<double , N> invert_calc (vector<double , N> a,
180
+ vector<double , N> y0) {
178
181
// IEEE SQRT computes H0 = 0.5*y0 (can be skipped)
179
182
// Step 3: S0 = a*y0
180
183
vector<double , N> S0 = a * y0;
@@ -235,15 +238,17 @@ __vc_builtin_rsqrt_f64__rte_(double a) {
235
238
return __impl_rsqrt_f64 (va)[0 ];
236
239
}
237
240
238
- #define FREM (WIDTH ) \
241
+ #define RSQRT (WIDTH ) \
239
242
CM_NODEBUG CM_NOINLINE extern " C" cl_vector<double , WIDTH> \
240
- __vc_builtin_rsqrt_v##WIDTH##f64__rte_(cl_vector<double , WIDTH> a) { \
243
+ __vc_builtin_rsqrt_v##WIDTH##f64__rte_(cl_vector<double , WIDTH> a) { \
241
244
vector<double , WIDTH> va{a}; \
242
- auto r = __impl_rsqrt_f64 (va); \
245
+ auto r = __impl_rsqrt_f64 (va); \
243
246
return r.cl_vector (); \
244
247
}
245
248
246
- FREM (1 )
247
- FREM(2 )
248
- FREM(4 )
249
- FREM(8 )
249
+ RSQRT (1 )
250
+ RSQRT(2 )
251
+ RSQRT(4 )
252
+ RSQRT(8 )
253
+ RSQRT(16 )
254
+ RSQRT(32 )
0 commit comments