Skip to content

Commit 15bec37

Browse files
committed
Remove boolean handling and unify LazyExpr/NDArray expr handling
1 parent d65ae78 commit 15bec37

File tree

2 files changed

+100
-142
lines changed

2 files changed

+100
-142
lines changed

src/blosc2/lazyexpr.py

Lines changed: 16 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,10 @@
4545
get_chunks_idx,
4646
get_intersecting_chunks,
4747
is_inside_new_expr,
48+
local_ufunc_map,
4849
process_key,
50+
ufunc_map,
51+
ufunc_map_1param,
4952
)
5053

5154
if not blosc2.IS_WASM:
@@ -150,6 +153,7 @@ def ne_evaluate(expression, local_dict=None, **kwargs):
150153
"log",
151154
"log10",
152155
"log1p",
156+
"log2",
153157
"conj",
154158
"real",
155159
"imag",
@@ -169,6 +173,16 @@ def ne_evaluate(expression, local_dict=None, **kwargs):
169173
"isnan",
170174
"isfinite",
171175
"isinf",
176+
"nextafter",
177+
"copysign",
178+
"hypot",
179+
"maximum",
180+
"minimum",
181+
"floor",
182+
"ceil",
183+
"trunc",
184+
"signbit",
185+
"round",
172186
]
173187

174188
# Gather all callable functions in numpy
@@ -2512,52 +2526,8 @@ def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
25122526
if method != "__call__":
25132527
return NotImplemented
25142528

2515-
ufunc_map = {
2516-
np.add: "+",
2517-
np.subtract: "-",
2518-
np.multiply: "*",
2519-
np.divide: "/",
2520-
np.true_divide: "/",
2521-
np.power: "**",
2522-
np.less: "<",
2523-
np.less_equal: "<=",
2524-
np.greater: ">",
2525-
np.greater_equal: ">=",
2526-
np.equal: "==",
2527-
np.not_equal: "!=",
2528-
np.bitwise_and: "&",
2529-
np.bitwise_or: "|",
2530-
np.bitwise_xor: "^",
2531-
}
2532-
2533-
ufunc_map_1param = {
2534-
np.sqrt: "sqrt",
2535-
np.sin: "sin",
2536-
np.cos: "cos",
2537-
np.tan: "tan",
2538-
np.arcsin: "arcsin",
2539-
np.arccos: "arccos",
2540-
np.arctan: "arctan",
2541-
np.sinh: "sinh",
2542-
np.cosh: "cosh",
2543-
np.tanh: "tanh",
2544-
np.arcsinh: "arcsinh",
2545-
np.arccosh: "arccosh",
2546-
np.arctanh: "arctanh",
2547-
np.exp: "exp",
2548-
np.expm1: "expm1",
2549-
np.log: "log",
2550-
np.log10: "log10",
2551-
np.log1p: "log1p",
2552-
np.abs: "abs",
2553-
np.conj: "conj",
2554-
np.real: "real",
2555-
np.imag: "imag",
2556-
np.bitwise_not: "~",
2557-
np.isnan: "isnan",
2558-
np.isfinite: "isfinite",
2559-
np.isinf: "isinf",
2560-
}
2529+
if ufunc in local_ufunc_map:
2530+
return local_ufunc_map[ufunc](*inputs)
25612531

25622532
if ufunc in ufunc_map:
25632533
value = inputs[0] if inputs[1] is self else inputs[1]

src/blosc2/ndarray.py

Lines changed: 84 additions & 96 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,90 @@
3636

3737
# NumPy version and a convenient boolean flag
3838
NUMPY_GE_2_0 = np.__version__ >= "2.0"
39+
# handle different numpy versions
40+
if NUMPY_GE_2_0: # array-api compliant
41+
nplshift = np.bitwise_left_shift
42+
nprshift = np.bitwise_right_shift
43+
npbinvert = np.bitwise_invert
44+
else: # not array-api compliant
45+
nplshift = np.left_shift
46+
nprshift = np.right_shift
47+
npbinvert = np.bitwise_not
48+
49+
# These functions in ufunc_map in ufunc_map_1param are implemented in numexpr and so we call
50+
# those instead (since numexpr uses multithreading it is faster)
51+
ufunc_map = {
52+
np.add: "+",
53+
np.subtract: "-",
54+
np.multiply: "*",
55+
np.divide: "/",
56+
np.true_divide: "/",
57+
np.floor_divide: "//",
58+
np.power: "**",
59+
np.less: "<",
60+
np.less_equal: "<=",
61+
np.greater: ">",
62+
np.greater_equal: ">=",
63+
np.equal: "==",
64+
np.not_equal: "!=",
65+
np.bitwise_and: "&",
66+
np.bitwise_or: "|",
67+
np.bitwise_xor: "^",
68+
np.arctan2: "arctan2",
69+
nplshift: "<<", # nplshift selected above according to numpy version
70+
nprshift: ">>", # nprshift selected above according to numpy version
71+
np.remainder: "%",
72+
np.nextafter: "nextafter",
73+
np.copysign: "copysign",
74+
np.hypot: "hypot",
75+
np.maximum: "maximum",
76+
np.minimum: "minimum",
77+
}
78+
79+
# implemented in numexpr
80+
ufunc_map_1param = {
81+
np.sqrt: "sqrt",
82+
np.sin: "sin",
83+
np.cos: "cos",
84+
np.tan: "tan",
85+
np.arcsin: "arcsin",
86+
np.arccos: "arccos",
87+
np.arctan: "arctan",
88+
np.sinh: "sinh",
89+
np.cosh: "cosh",
90+
np.tanh: "tanh",
91+
np.arcsinh: "arcsinh",
92+
np.arccosh: "arccosh",
93+
np.arctanh: "arctanh",
94+
np.exp: "exp",
95+
np.expm1: "expm1",
96+
np.log: "log",
97+
np.log10: "log10",
98+
np.log1p: "log1p",
99+
np.log2: "log2",
100+
np.abs: "abs",
101+
np.conj: "conj",
102+
np.real: "real",
103+
np.imag: "imag",
104+
npbinvert: "~", # npbinvert selected above according to numpy version
105+
np.isnan: "isnan",
106+
np.isfinite: "isfinite",
107+
np.isinf: "isinf",
108+
np.floor: "floor",
109+
np.ceil: "ceil",
110+
np.trunc: "trunc",
111+
np.signbit: "signbit",
112+
np.round: "round",
113+
}
114+
115+
# implemented in python-blosc2
116+
local_ufunc_map = {
117+
np.logaddexp: blosc2.logaddexp,
118+
np.logical_not: blosc2.logical_not,
119+
np.logical_and: blosc2.logical_and,
120+
np.logical_or: blosc2.logical_or,
121+
np.logical_xor: blosc2.logical_xor,
122+
}
39123

40124

41125
@runtime_checkable
@@ -2928,17 +3012,6 @@ def chunkwise_logaddexp(inputs, output, offset):
29283012
return blosc2.lazyudf(chunkwise_logaddexp, (x1, x2), dtype=dtype, shape=x1.shape)
29293013

29303014

2931-
# handle different numpy versions
2932-
if NUMPY_GE_2_0: # array-api compliant
2933-
nplshift = np.bitwise_left_shift
2934-
nprshift = np.bitwise_right_shift
2935-
npbinvert = np.bitwise_invert
2936-
else: # not array-api compliant
2937-
nplshift = np.left_shift
2938-
nprshift = np.right_shift
2939-
npbinvert = np.bitwise_not
2940-
2941-
29423015
class Operand:
29433016
"""Base class for all operands in expressions."""
29443017

@@ -2989,92 +3062,12 @@ def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
29893062
if method != "__call__":
29903063
return NotImplemented
29913064

2992-
# These functions in ufunc_map in ufunc_map_1param are implemented in numexpr and so we call
2993-
# those instead (since numexpr uses multithreading it is faster)
2994-
ufunc_map = {
2995-
np.add: "+",
2996-
np.subtract: "-",
2997-
np.multiply: "*",
2998-
np.divide: "/",
2999-
np.true_divide: "/",
3000-
np.floor_divide: "//",
3001-
np.power: "**",
3002-
np.less: "<",
3003-
np.less_equal: "<=",
3004-
np.greater: ">",
3005-
np.greater_equal: ">=",
3006-
np.equal: "==",
3007-
np.not_equal: "!=",
3008-
np.bitwise_and: "&",
3009-
np.bitwise_or: "|",
3010-
np.bitwise_xor: "^",
3011-
np.arctan2: "arctan2",
3012-
nplshift: "<<", # nplshift selected above according to numpy version
3013-
nprshift: ">>", # nprshift selected above according to numpy version
3014-
np.remainder: "%",
3015-
np.nextafter: "nextafter",
3016-
np.copysign: "copysign",
3017-
np.hypot: "hypot",
3018-
np.maximum: "maximum",
3019-
np.minimum: "minimum",
3020-
}
3021-
3022-
# implemented in numexpr
3023-
ufunc_map_1param = {
3024-
np.sqrt: "sqrt",
3025-
np.sin: "sin",
3026-
np.cos: "cos",
3027-
np.tan: "tan",
3028-
np.arcsin: "arcsin",
3029-
np.arccos: "arccos",
3030-
np.arctan: "arctan",
3031-
np.sinh: "sinh",
3032-
np.cosh: "cosh",
3033-
np.tanh: "tanh",
3034-
np.arcsinh: "arcsinh",
3035-
np.arccosh: "arccosh",
3036-
np.arctanh: "arctanh",
3037-
np.exp: "exp",
3038-
np.expm1: "expm1",
3039-
np.log: "log",
3040-
np.log10: "log10",
3041-
np.log1p: "log1p",
3042-
np.log2: "log2",
3043-
np.abs: "abs",
3044-
np.conj: "conj",
3045-
np.real: "real",
3046-
np.imag: "imag",
3047-
npbinvert: "~", # npbinvert selected above according to numpy version
3048-
np.isnan: "isnan",
3049-
np.isfinite: "isfinite",
3050-
np.isinf: "isinf",
3051-
np.floor: "floor",
3052-
np.ceil: "ceil",
3053-
np.trunc: "trunc",
3054-
np.signbit: "signbit",
3055-
np.round: "round",
3056-
}
3057-
3058-
# implemented in python-blosc2
3059-
local_ufunc_map = {
3060-
np.logaddexp: logaddexp,
3061-
np.logical_not: logical_not,
3062-
np.logical_and: logical_and,
3063-
np.logical_or: logical_or,
3064-
np.logical_xor: logical_xor,
3065-
}
30663065
if ufunc in local_ufunc_map:
30673066
return local_ufunc_map[ufunc](*inputs)
30683067

30693068
if ufunc in ufunc_map:
30703069
value = inputs[0] if inputs[1] is self else inputs[1]
30713070
_check_allowed_dtypes(value)
3072-
# catch special case of multiplying two bools (not implemented in numexpr)
3073-
if ufunc_map[ufunc] == "*" and blosc2.result_type(value, self) == blosc2.bool_:
3074-
return blosc2.LazyExpr(new_op=(value, "&", self))
3075-
# catch special case of adding two bools (not implemented in numexpr)
3076-
if ufunc_map[ufunc] == "+" and blosc2.result_type(value, self) == blosc2.bool_:
3077-
return blosc2.LazyExpr(new_op=(value, "|", self))
30783071
return blosc2.LazyExpr(new_op=(value, ufunc_map[ufunc], self))
30793072

30803073
if ufunc in ufunc_map_1param:
@@ -3086,8 +3079,6 @@ def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
30863079

30873080
def __add__(self, value: int | float | blosc2.Array, /) -> blosc2.LazyExpr:
30883081
_check_allowed_dtypes(value)
3089-
if blosc2.result_type(value, self) == blosc2.bool_:
3090-
return blosc2.LazyExpr(new_op=(value, "|", self))
30913082
return blosc2.LazyExpr(new_op=(self, "+", value))
30923083

30933084
def __iadd__(self, value: int | float | blosc2.Array, /) -> blosc2.LazyExpr:
@@ -3123,9 +3114,6 @@ def __rsub__(self, value: int | float | blosc2.Array, /) -> blosc2.LazyExpr:
31233114
@is_documented_by(multiply)
31243115
def __mul__(self, value: int | float | blosc2.Array, /) -> blosc2.LazyExpr:
31253116
_check_allowed_dtypes(value)
3126-
# catch special case of multiplying two bools (not implemented in numexpr)
3127-
if blosc2.result_type(value, self) == blosc2.bool_:
3128-
return blosc2.LazyExpr(new_op=(value, "&", self))
31293117
return blosc2.LazyExpr(new_op=(self, "*", value))
31303118

31313119
def __imul__(self, value: int | float | blosc2.Array, /) -> blosc2.LazyExpr:

0 commit comments

Comments
 (0)