Skip to content

Commit c9775f2

Browse files
committed
Improve casting for ints to match Numpy and add docs
1 parent 4eab0e0 commit c9775f2

File tree

5 files changed

+88
-23
lines changed

5 files changed

+88
-23
lines changed

doc/user_guide.rst

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -188,10 +188,10 @@ Supported operators
188188

189189
*NumExpr* supports the set of operators listed below:
190190

191-
* Bitwise operators (and, or, not, xor): :code:`&, |, ~, ^`
191+
* Bitwise and logical operators (and, or, not, xor): :code:`&, |, ~, ^`
192192
* Comparison operators: :code:`<, <=, ==, !=, >=, >`
193193
* Unary arithmetic operators: :code:`-`
194-
* Binary arithmetic operators: :code:`+, -, *, /, **, %, <<, >>`
194+
* Binary arithmetic operators: :code:`+, -, *, /, //, **, %, <<, >>`
195195

196196

197197
Supported functions
@@ -203,22 +203,33 @@ The next are the current supported set:
203203
is true, number2 otherwise.
204204
* :code:`{isinf, isnan, isfinite}(float|complex): bool` -- returns element-wise True
205205
for ``inf`` or ``NaN``, ``NaN``, not ``inf`` respectively.
206+
* :code:`signbit(float|complex): bool` -- returns element-wise True if signbit is set
207+
False otherwise.
206208
* :code:`{sin,cos,tan}(float|complex): float|complex` -- trigonometric sine,
207209
cosine or tangent.
208210
* :code:`{arcsin,arccos,arctan}(float|complex): float|complex` -- trigonometric
209211
inverse sine, cosine or tangent.
210212
* :code:`arctan2(float1, float2): float` -- trigonometric inverse tangent of
211213
float1/float2.
214+
* :code:`hypot(float1, float2): float` -- Euclidean distance between float1, float2
215+
* :code:`nextafter(float1, float2): float` -- next representable floating-point value after
216+
float1 in direction of float2
217+
* :code:`copysign(float1, float2): float` -- return number with magnitude of float1 and
218+
sign of float2
219+
* :code:`{maximum,minimum}(float1, float2): float` -- return max/min of float1, float2
212220
* :code:`{sinh,cosh,tanh}(float|complex): float|complex` -- hyperbolic sine,
213221
cosine or tangent.
214222
* :code:`{arcsinh,arccosh,arctanh}(float|complex): float|complex` -- hyperbolic
215223
inverse sine, cosine or tangent.
216-
* :code:`{log,log10,log1p}(float|complex): float|complex` -- natural, base-10 and
224+
* :code:`{log,log10,log1p,log2}(float|complex): float|complex` -- natural, base-10 and
217225
log(1+x) logarithms.
218226
* :code:`{exp,expm1}(float|complex): float|complex` -- exponential and exponential
219227
minus one.
220228
* :code:`sqrt(float|complex): float|complex` -- square root.
221-
* :code:`abs(float|complex): float|complex` -- absolute value.
229+
* :code:`trunc(float): float` -- round towards zero
230+
* :code:`round(float|complex|int): float|complex|int` -- round to nearest integer (`rint`)
231+
* :code:`sign(float|complex|int): float|complex|int` -- return -1, 0, +1 depending on sign
232+
* :code:`abs(float|complex|int): float|complex|int` -- absolute value.
222233
* :code:`conj(complex): complex` -- conjugate value.
223234
* :code:`{real,imag}(complex): float` -- real or imaginary part of complex.
224235
* :code:`complex(float, float): complex` -- complex from real and imaginary

numexpr/bespoke_functions.hpp

Lines changed: 58 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,43 @@ inline float signf(float x){
2323
return 0; // handles +0.0 and -0.0
2424
}
2525

26+
// round function for ints
27+
inline int rinti(int x) {return x;}
28+
inline long rintl(long x) {return x;}
29+
// abs function for ints
30+
inline int fabsi(int x) {return x<0 ? -x: x;}
31+
inline long fabsl(long x) {return x<0 ? -x: x;}
32+
// fmod function for ints
33+
inline int fmodi(int x, int y) {return (int)fmodf((float)x, (float)y);}
34+
inline long fmodl(long x, long y) {return (long)fmodf((long)x, (long)y);}
2635

2736
#ifdef USE_VML
37+
static void viRint(MKL_INT n, const int* x, int* dest)
38+
{
39+
memcpy(dest, x, n * sizeof(int)); // just copy x1 which is already int
40+
};
41+
42+
static void vlRint(MKL_INT n, const long* x, long* dest)
43+
{
44+
memcpy(dest, x, n * sizeof(long)); // just copy x1 which is already int
45+
};
46+
47+
static void viFabs(MKL_INT n, const int* x, int* dest)
48+
{
49+
MKL_INT j;
50+
for (j=0; j<n; j++) {
51+
dest[j] = x[j] < 0 ? -x[j]: x[j];
52+
};
53+
};
54+
55+
static void vlFabs(MKL_INT n, const long* x, long* dest)
56+
{
57+
MKL_INT j;
58+
for (j=0; j<n; j++) {
59+
dest[j] = x[j] < 0 ? -x[j]: x[j];
60+
};
61+
};
62+
2863
/* Fake vsConj function just for casting purposes inside numexpr */
2964
static void vsConj(MKL_INT n, const float* x1, float* dest)
3065
{
@@ -39,9 +74,31 @@ static void vsfmod(MKL_INT n, const float* x1, const float* x2, float* dest)
3974
{
4075
MKL_INT j;
4176
for(j=0; j < n; j++) {
42-
dest[j] = fmod(x1[j], x2[j]);
77+
dest[j] = fmodf(x1[j], x2[j]);
4378
};
4479
}
80+
static void vdfmod(MKL_INT n, const double* x1, const double* x2, double* dest)
81+
{
82+
MKL_INT j;
83+
for(j=0; j < n; j++) {
84+
dest[j] = fmod(x1[j], x2[j]);
85+
};
86+
};
87+
static void vifmod(MKL_INT n, const int* x1, const int* x2, int* dest)
88+
{
89+
MKL_INT j;
90+
for(j=0; j < n; j++) {
91+
dest[j] = fmodi(x1[j], x2[j]);
92+
};
93+
};
94+
static void vlfmod(MKL_INT n, const long* x1, const long* x2, long* dest)
95+
{
96+
MKL_INT j;
97+
for(j=0; j < n; j++) {
98+
dest[j] = fmodl(x1[j], x2[j]);
99+
};
100+
};
101+
45102
/* no isnan, isfinite, isinf or signbit in VML */
46103
static void vsIsfinite(MKL_INT n, const float* x1, bool* dest)
47104
{
@@ -134,15 +191,6 @@ static void vdConj(MKL_INT n, const double* x1, double* dest)
134191
};
135192
};
136193

137-
/* fmod not available in VML */
138-
static void vdfmod(MKL_INT n, const double* x1, const double* x2, double* dest)
139-
{
140-
MKL_INT j;
141-
for(j=0; j < n; j++) {
142-
dest[j] = fmod(x1[j], x2[j]);
143-
};
144-
};
145-
146194
/* various functions not available in VML */
147195
static void vzExpm1(MKL_INT n, const MKL_Complex16* x1, MKL_Complex16* dest)
148196
{

numexpr/expressions.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -186,11 +186,10 @@ def function(*args):
186186
return ConstantNode(func(*[x.value for x in args]))
187187
kind = commonKind(args)
188188
if kind in ('int', 'long'):
189-
# Exception for following NumPy casting rules
190-
#FIXME: this is not always desirable. The following
191-
# functions which return ints (for int inputs) on numpy
192-
# but not on numexpr: copy, abs, fmod, ones_like
193-
kind = 'double'
189+
if func.__name__ not in ('copy', 'abs', 'fmod', 'ones_like', 'round', 'sign'):
190+
# except for these special functions (which return ints for int inputs in NumPy)
191+
# just do a cast to double
192+
kind = 'double'
194193
else:
195194
# Apply regular casting rules
196195
if minkind and kind_rank.index(minkind) > kind_rank.index(kind):

numexpr/functions.hpp

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -212,7 +212,9 @@ FUNC_BC(FUNC_BC_LAST, NULL, NULL, NULL)
212212
#define ELIDE_FUNC_II
213213
#define FUNC_II(...)
214214
#endif
215-
FUNC_II(FUNC_SIGN_II, "sign_ii", signi, viSign)
215+
FUNC_II(FUNC_SIGN_II, "sign_ii", signi, viSign)
216+
FUNC_II(FUNC_ROUND_II, "round_ii", rinti, viRint)
217+
FUNC_II(FUNC_ABS_II, "absolute_ii", fabsi, viFabs)
216218
FUNC_II(FUNC_II_LAST, NULL, NULL, NULL)
217219
#ifdef ELIDE_FUNC_II
218220
#undef ELIDE_FUNC_II
@@ -223,7 +225,9 @@ FUNC_II(FUNC_II_LAST, NULL, NULL, NULL)
223225
#define ELIDE_FUNC_LL
224226
#define FUNC_LL(...)
225227
#endif
226-
FUNC_LL(FUNC_SIGN_LL, "sign_LL", signl, vlSign)
228+
FUNC_LL(FUNC_SIGN_LL, "sign_ll", signl, vlSign)
229+
FUNC_LL(FUNC_ROUND_LL, "round_ll", rintl, vlRint)
230+
FUNC_LL(FUNC_ABS_LL, "absolute_ll", fabsl, vlFabs)
227231
FUNC_LL(FUNC_LL_LAST, NULL, NULL, NULL)
228232
#ifdef ELIDE_FUNC_LL
229233
#undef ELIDE_FUNC_LL

numexpr/tests/test_numexpr.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -487,14 +487,17 @@ def test_maximum_minimum(self):
487487
assert_array_equal(evaluate("maximum(x,y)"), maximum(x,y))
488488
assert_array_equal(evaluate("minimum(x,y)"), minimum(x,y))
489489

490-
def test_sign(self):
491-
for dtype in [float, double, int, np.int64, complex]:
490+
def test_sign_round(self):
491+
for dtype in [float, double, np.int32, np.int64, complex]:
492492
x = arange(10, dtype=dtype)
493493
y = 2 * arange(10, dtype=dtype)[::-1]
494494
r = x-y
495495
if not np.issubdtype(dtype, np.integer):
496496
r[-1] = np.nan
497+
assert evaluate("round(r)").dtype == round(r).dtype
498+
assert evaluate("sign(r)").dtype == sign(r).dtype
497499
assert_array_equal(evaluate("sign(r)"), sign(r))
500+
assert_array_equal(evaluate("round(r)"), round(r))
498501

499502
def test_rational_expr(self):
500503
a = arange(1e6)

0 commit comments

Comments
 (0)