Skip to content

Commit ab62b9f

Browse files
author
Raghuveer Devulapalli
committed
MAINT: Use hwy gather/scatter and hwy macros
1 parent 7d6cc6d commit ab62b9f

File tree

1 file changed

+11
-40
lines changed

1 file changed

+11
-40
lines changed

numpy/_core/src/umath/loops_trigonometric.dispatch.cpp

Lines changed: 11 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
#include "numpy/npy_math.h"
21
#include "simd/simd.h"
32
#include "loops_utils.h"
43
#include "loops.h"
@@ -31,7 +30,7 @@ namespace hn = hwy::HWY_NAMESPACE;
3130
* elements or when there's no native FUSED support instead of fallback to libc
3231
*/
3332

34-
#if NPY_SIMD_FMA3 // native support
33+
#if HWY_NATIVE_FMA // native support
3534
typedef enum
3635
{
3736
SIMD_COMPUTE_SIN,
@@ -44,7 +43,7 @@ using vec_f32 = hn::Vec<decltype(f32)>;
4443
using vec_s32 = hn::Vec<decltype(s32)>;
4544
using opmask_t = hn::Mask<decltype(f32)>;
4645

47-
NPY_FINLINE HWY_ATTR vec_f32
46+
HWY_INLINE HWY_ATTR vec_f32
4847
simd_range_reduction_f32(vec_f32& x, vec_f32& y, const vec_f32& c1, const vec_f32& c2, const vec_f32& c3)
4948
{
5049
vec_f32 reduced_x = hn::MulAdd(y, c1, x);
@@ -53,7 +52,7 @@ simd_range_reduction_f32(vec_f32& x, vec_f32& y, const vec_f32& c1, const vec_f3
5352
return reduced_x;
5453
}
5554

56-
NPY_FINLINE HWY_ATTR vec_f32
55+
HWY_INLINE HWY_ATTR vec_f32
5756
simd_cosine_poly_f32(vec_f32& x2)
5857
{
5958
const vec_f32 invf8 = hn::Set(f32, 0x1.98e616p-16f);
@@ -73,7 +72,7 @@ simd_cosine_poly_f32(vec_f32& x2)
7372
* Maximum ULP across all 32-bit floats = 0.647
7473
* Polynomial approximation based on unpublished work by T. Myklebust
7574
*/
76-
NPY_FINLINE HWY_ATTR vec_f32
75+
HWY_INLINE HWY_ATTR vec_f32
7776
simd_sine_poly_f32(vec_f32& x, vec_f32& x2)
7877
{
7978
const vec_f32 invf9 = hn::Set(f32, 0x1.7d3bbcp-19f);
@@ -89,26 +88,6 @@ simd_sine_poly_f32(vec_f32& x, vec_f32& x2)
8988
return r;
9089
}
9190

92-
NPY_FINLINE HWY_ATTR vec_f32
93-
GatherIndexN(const float* src, npy_intp ssrc, npy_intp len)
94-
{
95-
float temp[hn::Lanes(f32)] = { 0.0f };
96-
for (auto ii = 0; ii < std::min(len, (npy_intp)hn::Lanes(f32)); ++ii) {
97-
temp[ii] = src[ii * ssrc];
98-
}
99-
return hn::LoadU(f32, temp);
100-
}
101-
102-
NPY_FINLINE HWY_ATTR void
103-
ScatterIndexN(vec_f32 vec, float* dst, npy_intp sdst, npy_intp len)
104-
{
105-
float temp[hn::Lanes(f32)];
106-
hn::StoreU(vec, f32, temp);
107-
for (auto ii = 0; ii < std::min(len, (npy_intp)hn::Lanes(f32)); ++ii) {
108-
dst[ii * sdst] = temp[ii];
109-
}
110-
}
111-
11291
static void HWY_ATTR SIMD_MSVC_NOINLINE
11392
simd_sincos_f32(const float *src, npy_intp ssrc, float *dst, npy_intp sdst,
11493
npy_intp len, SIMD_TRIG_OP trig_op)
@@ -130,23 +109,15 @@ simd_sincos_f32(const float *src, npy_intp ssrc, float *dst, npy_intp sdst,
130109
const vec_f32 max_cody = hn::Set(f32, max_codi);
131110

132111
const int lanes = hn::Lanes(f32);
133-
//npy_intp load_index[lanes/2];
134-
//for (auto i = 0; i < lanes; ++i) {
135-
// load_index[i] = i * ssrc;
136-
//}
137-
//vec_s32 vec_lindex = hn::LoadU(s32, load_index);
138-
//npy_intp store_index[lanes/2];
139-
//for (auto i = 0; i < lanes; ++i) {
140-
// store_index[i] = i * sdst;
141-
//}
142-
//vec_s32 vec_sindex = hn::LoadU(s32, store_index);
112+
const vec_s32 src_index = hn::Mul(hn::Iota(s32, 0), hn::Set(s32, ssrc));
113+
const vec_s32 dst_index = hn::Mul(hn::Iota(s32, 0), hn::Set(s32, sdst));
143114

144115
for (; len > 0; len -= lanes, src += ssrc*lanes, dst += sdst*lanes) {
145116
vec_f32 x_in;
146117
if (ssrc == 1) {
147118
x_in = hn::LoadN(f32, src, len);
148119
} else {
149-
x_in = GatherIndexN(src, ssrc, len);
120+
x_in = hn::GatherIndexN(f32, src, src_index, len);
150121
}
151122
opmask_t nnan_mask = hn::Not(hn::IsNaN(x_in));
152123
// Eliminate NaN to avoid FP invalid exception
@@ -191,7 +162,7 @@ simd_sincos_f32(const float *src, npy_intp ssrc, float *dst, npy_intp sdst,
191162
if (sdst == 1) {
192163
hn::StoreN(cos, f32, dst, len);
193164
} else {
194-
ScatterIndexN(cos, dst, sdst, len);
165+
hn::ScatterIndexN(cos, f32, dst, dst_index, len);
195166
}
196167
}
197168
if (!hn::AllTrue(f32, simd_mask)) {
@@ -221,7 +192,7 @@ simd_sincos_f32(const float *src, npy_intp ssrc, float *dst, npy_intp sdst,
221192
npyv_cleanup();
222193
}
223194
}
224-
#endif // NPY_SIMD_FMA3
195+
#endif // HWY_NATIVE_FMA
225196

226197
/* Disable SIMD code sin/cos f64 and revert to libm: see
227198
* https://mail.python.org/archives/list/[email protected]/thread/C6EYZZSR4EWGVKHAZXLE7IBILRMNVK7L/
@@ -242,7 +213,7 @@ DISPATCH_DOUBLE_FUNC(cos)
242213
NPY_NO_EXPORT void NPY_CPU_DISPATCH_CURFX(FLOAT_sin)
243214
(char **args, npy_intp const *dimensions, npy_intp const *steps, void *NPY_UNUSED(data))
244215
{
245-
#if NPY_SIMD_F32 && NPY_SIMD_FMA3
216+
#if HWY_NATIVE_FMA
246217
const npy_float *src = (npy_float*)args[0];
247218
npy_float *dst = (npy_float*)args[1];
248219

@@ -271,7 +242,7 @@ NPY_NO_EXPORT void NPY_CPU_DISPATCH_CURFX(FLOAT_sin)
271242
NPY_NO_EXPORT void NPY_CPU_DISPATCH_CURFX(FLOAT_cos)
272243
(char **args, npy_intp const *dimensions, npy_intp const *steps, void *NPY_UNUSED(data))
273244
{
274-
#if NPY_SIMD_F32 && NPY_SIMD_FMA3
245+
#if HWY_NATIVE_FMA
275246
const npy_float *src = (npy_float*)args[0];
276247
npy_float *dst = (npy_float*)args[1];
277248

0 commit comments

Comments
 (0)