Skip to content

Commit c59ba2e

Browse files
committed
Override NUMBA fastmath flag in IsNan dispatch
1 parent 8cc489b commit c59ba2e

File tree

2 files changed

+15
-1
lines changed

2 files changed

+15
-1
lines changed

pytensor/link/numba/dispatch/scalar.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
Clip,
2424
Composite,
2525
Identity,
26+
IsNan,
2627
Mul,
2728
Reciprocal,
2829
ScalarOp,
@@ -137,7 +138,8 @@ def {scalar_op_fn_name}({', '.join(input_names)}):
137138

138139
return numba_basic.numba_njit(
139140
signature,
140-
fastmath=config.numba__fastmath,
141+
# numba always returns False if fastmath=True # https://github.com/numba/numba/issues/9383
142+
fastmath=False if isinstance(op, IsNan) else config.numba__fastmath,
141143
# Functions that call a function pointer can't be cached
142144
cache=False,
143145
)(scalar_op_fn)

tests/link/numba/test_scalar.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from pytensor.graph.basic import Constant
1010
from pytensor.graph.fg import FunctionGraph
1111
from pytensor.scalar.basic import Composite
12+
from pytensor.tensor import tensor
1213
from pytensor.tensor.elemwise import Elemwise
1314
from tests.link.numba.test_basic import compare_numba_and_py, set_test_value
1415

@@ -140,3 +141,14 @@ def test_reciprocal(v, dtype):
140141
if not isinstance(i, SharedVariable | Constant)
141142
],
142143
)
144+
145+
146+
@pytest.mark.parametrize("dtype", ("complex64", "float64", "float32"))
147+
def test_isnan(dtype):
148+
# Testing with tensor just to make sure Elemwise does not revert the scalar behavior of fastmath
149+
x = tensor(shape=(2,), dtype=dtype)
150+
out = pt.isnan(x)
151+
compare_numba_and_py(
152+
([x], [out]),
153+
[np.array([1, 0], dtype=dtype)],
154+
)

0 commit comments

Comments
 (0)