Skip to content

Commit 521b3af

Browse files
committed
SIMD: Fix impl of intrinsic npyv_ceil_f32 on armv7/neon
1 parent 6ccad06 commit 521b3af

File tree

1 file changed

+31
-10
lines changed
  • numpy/core/src/common/simd/neon

1 file changed

+31
-10
lines changed

numpy/core/src/common/simd/neon/math.h

Lines changed: 31 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -88,16 +88,16 @@ NPY_FINLINE npyv_f32 npyv_recip_f32(npyv_f32 a)
8888
#define npyv_max_f64 vmaxq_f64
8989
// Maximum, supports IEEE floating-point arithmetic (IEC 60559),
9090
// - If one of the two vectors contains NaN, the equivalent element of the other vector is set
91-
// - NaN is set only if both corresponding elements are NaN.
91+
// - NaN is set only if both corresponding elements are NaN.
9292
#ifdef NPY_HAVE_ASIMD
9393
#define npyv_maxp_f32 vmaxnmq_f32
9494
#else
9595
NPY_FINLINE npyv_f32 npyv_maxp_f32(npyv_f32 a, npyv_f32 b)
96-
{
96+
{
9797
npyv_u32 nn_a = vceqq_f32(a, a);
9898
npyv_u32 nn_b = vceqq_f32(b, b);
9999
return vmaxq_f32(vbslq_f32(nn_a, a, b), vbslq_f32(nn_b, b, a));
100-
}
100+
}
101101
#endif
102102
#if NPY_SIMD_F64
103103
#define npyv_maxp_f64 vmaxnmq_f64
@@ -123,16 +123,16 @@ NPY_FINLINE npyv_s64 npyv_max_s64(npyv_s64 a, npyv_s64 b)
123123
#define npyv_min_f64 vminq_f64
124124
// Minimum, supports IEEE floating-point arithmetic (IEC 60559),
125125
// - If one of the two vectors contains NaN, the equivalent element of the other vector is set
126-
// - NaN is set only if both corresponding elements are NaN.
126+
// - NaN is set only if both corresponding elements are NaN.
127127
#ifdef NPY_HAVE_ASIMD
128128
#define npyv_minp_f32 vminnmq_f32
129129
#else
130130
NPY_FINLINE npyv_f32 npyv_minp_f32(npyv_f32 a, npyv_f32 b)
131-
{
131+
{
132132
npyv_u32 nn_a = vceqq_f32(a, a);
133133
npyv_u32 nn_b = vceqq_f32(b, b);
134134
return vminq_f32(vbslq_f32(nn_a, a, b), vbslq_f32(nn_b, b, a));
135-
}
135+
}
136136
#endif
137137
#if NPY_SIMD_F64
138138
#define npyv_minp_f64 vminnmq_f64
@@ -159,10 +159,31 @@ NPY_FINLINE npyv_s64 npyv_min_s64(npyv_s64 a, npyv_s64 b)
159159
#else
160160
NPY_FINLINE npyv_f32 npyv_ceil_f32(npyv_f32 a)
161161
{
162-
npyv_f32 conv_trunc = vcvtq_f32_s32(vcvtq_s32_f32(a));
163-
npyv_f32 conv_trunc_add_one = npyv_add_f32(conv_trunc, vdupq_n_f32(1.0f));
164-
npyv_u32 mask = vcltq_f32(conv_trunc, a);
165-
return vbslq_f32(mask, conv_trunc, conv_trunc_add_one);
162+
const npyv_s32 szero = vreinterpretq_s32_f32(vdupq_n_f32(-0.0f));
163+
const npyv_u32 one = vreinterpretq_u32_f32(vdupq_n_f32(1.0f));
164+
const npyv_s32 max_int = vdupq_n_s32(0x7fffffff);
165+
/**
166+
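* On armv7, vcvtq_s32_f32 (float -> int32, saturating) handles special cases as follows:
167+
NaN returns 0
168+
+inf or +outrange returns 0x7fffffff (bit pattern of a NaN when viewed as f32)
169+
-inf or -outrange returns 0x80000000 (bit pattern of -0.0f when viewed as f32)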
170+
*/
171+
npyv_s32 roundi = vcvtq_s32_f32(a);
172+
npyv_f32 round = vcvtq_f32_s32(roundi);
173+
npyv_f32 ceil = vaddq_f32(round, vreinterpretq_f32_u32(
174+
vandq_u32(vcltq_f32(round, a), one))
175+
);
176+
// respect signed zero, e.g. -0.5 -> -0.0
177+
npyv_f32 rzero = vreinterpretq_f32_s32(vorrq_s32(
178+
vreinterpretq_s32_f32(ceil),
179+
vandq_s32(vreinterpretq_s32_f32(a), szero)
180+
));
181+
// if nan or overflow return a
182+
npyv_u32 nnan = npyv_notnan_f32(a);
183+
npyv_u32 overflow = vorrq_u32(
184+
vceqq_s32(roundi, szero), vceqq_s32(roundi, max_int)
185+
);
186+
return vbslq_f32(vbicq_u32(nnan, overflow), rzero, a);
166187
}
167188
#endif
168189
#if NPY_SIMD_F64

0 commit comments

Comments
 (0)