Commit e2b4c65

Use more specific Numba fastmath flags everywhere
1 parent 8cc489b commit e2b4c65

7 files changed: +56 −80 lines changed


doc/extending/creating_a_numba_jax_op.rst

Lines changed: 4 additions & 4 deletions
@@ -358,13 +358,13 @@ Here's an example for the `CumOp`\ `Op`:
     if mode == "add":
         if axis is None or ndim == 1:

-            @numba_basic.numba_njit(fastmath=config.numba__fastmath)
+            @numba_basic.numba_njit()
             def cumop(x):
                 return np.cumsum(x)

         else:

-            @numba_basic.numba_njit(boundscheck=False, fastmath=config.numba__fastmath)
+            @numba_basic.numba_njit(boundscheck=False)
             def cumop(x):
                 out_dtype = x.dtype
                 if x.shape[axis] < 2:
@@ -382,13 +382,13 @@ Here's an example for the `CumOp`\ `Op`:
     else:
         if axis is None or ndim == 1:

-            @numba_basic.numba_njit(fastmath=config.numba__fastmath)
+            @numba_basic.numba_njit()
             def cumop(x):
                 return np.cumprod(x)

         else:

-            @numba_basic.numba_njit(boundscheck=False, fastmath=config.numba__fastmath)
+            @numba_basic.numba_njit(boundscheck=False)
             def cumop(x):
                 out_dtype = x.dtype
                 if x.shape[axis] < 2:
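
For orientation, a minimal sketch (not part of the commit) of how this documented dispatch is exercised end to end; `pt.cumsum` and the "NUMBA" compilation mode are the public entry points, and pytensor plus numba are assumed installed:

    # Compile a cumsum graph with the Numba backend and check it against NumPy.
    import numpy as np
    import pytensor
    import pytensor.tensor as pt

    x = pt.vector("x")
    fn = pytensor.function([x], pt.cumsum(x), mode="NUMBA")

    # Cumulative sums under the trimmed flag set should match NumPy closely.
    np.testing.assert_allclose(fn(np.arange(5.0)), np.cumsum(np.arange(5.0)))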

pytensor/link/numba/dispatch/basic.py

Lines changed: 16 additions & 3 deletions
@@ -49,10 +49,23 @@ def global_numba_func(func):
     return func


-def numba_njit(*args, **kwargs):
+def numba_njit(*args, fastmath=None, **kwargs):
     kwargs.setdefault("cache", config.numba__cache)
     kwargs.setdefault("no_cpython_wrapper", True)
     kwargs.setdefault("no_cfunc_wrapper", True)
+    if fastmath is None:
+        if config.numba__fastmath:
+            # Opinionated default on fastmath flags
+            # https://llvm.org/docs/LangRef.html#fast-math-flags
+            fastmath = {
+                "arcp",  # Allow Reciprocal
+                "contract",  # Allow floating-point contraction
+                "afn",  # Approximate functions
+                "reassoc",
+                "nsz",  # no-signed zeros
+            }
+        else:
+            fastmath = False

     # Suppress cache warning for internal functions
     # We have to add an ansi escape code for optional bold text by numba
@@ -68,9 +81,9 @@ def numba_njit(*args, **kwargs):
     )

     if len(args) > 0 and callable(args[0]):
-        return numba.njit(*args[1:], **kwargs)(args[0])
+        return numba.njit(*args[1:], fastmath=fastmath, **kwargs)(args[0])

-    return numba.njit(*args, **kwargs)
+    return numba.njit(*args, fastmath=fastmath, **kwargs)


 def numba_vectorize(*args, **kwargs):
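
The mechanism the wrapper relies on: `numba.njit` accepts `fastmath` either as a boolean or as a set of individual LLVM flag names. A standalone sketch of the same idea (the function and constant names here are illustrative, not from the commit):

    import numba
    import numpy as np

    # The same five flags the wrapper enables; notably "nnan" and "ninf"
    # are left out, so NaN/Inf handling stays IEEE-compliant.
    SAFE_FASTMATH = {"arcp", "contract", "afn", "reassoc", "nsz"}

    @numba.njit(fastmath=SAFE_FASTMATH)
    def dot(a, b):
        acc = 0.0
        for i in range(a.shape[0]):
            # "reassoc" and "contract" let LLVM vectorize and fuse this loop.
            acc += a[i] * b[i]
        return acc

    a = np.random.rand(1000)
    b = np.random.rand(1000)
    # May differ from np.dot in the last bits due to reassociation.
    assert np.isclose(dot(a, b), np.dot(a, b))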

pytensor/link/numba/dispatch/blockwise.py

Lines changed: 0 additions & 1 deletion
@@ -32,7 +32,6 @@ def numba_funcify_Blockwise(op: BlockwiseWithCoreShape, node, **kwargs):
         core_op,
         node=core_node,
         parent_node=node,
-        fastmath=_jit_options["fastmath"],
         **kwargs,
     )
     core_op_fn = store_core_outputs(core_op_fn, nin=nin, nout=nout)

pytensor/link/numba/dispatch/elemwise.py

Lines changed: 5 additions & 53 deletions
@@ -1,15 +1,11 @@
-from collections.abc import Callable
 from functools import singledispatch
 from textwrap import dedent, indent
-from typing import Any

 import numba
 import numpy as np
 from numba.core.extending import overload
 from numpy.core.numeric import normalize_axis_index, normalize_axis_tuple

-from pytensor import config
-from pytensor.graph.basic import Apply
 from pytensor.graph.op import Op
 from pytensor.link.numba.dispatch import basic as numba_basic
 from pytensor.link.numba.dispatch.basic import (
@@ -124,42 +120,6 @@ def scalar_in_place_fn_ScalarMinimum(op, idx, res, arr):
     """


-def create_vectorize_func(
-    scalar_op_fn: Callable,
-    node: Apply,
-    use_signature: bool = False,
-    identity: Any | None = None,
-    **kwargs,
-) -> Callable:
-    r"""Create a vectorized Numba function from a `Apply`\s Python function."""
-
-    if len(node.outputs) > 1:
-        raise NotImplementedError(
-            "Multi-output Elemwise Ops are not supported by the Numba backend"
-        )
-
-    if use_signature:
-        signature = [create_numba_signature(node, force_scalar=True)]
-    else:
-        signature = []
-
-    target = (
-        getattr(node.tag, "numba__vectorize_target", None)
-        or config.numba__vectorize_target
-    )
-
-    numba_vectorized_fn = numba_basic.numba_vectorize(
-        signature, identity=identity, target=target, fastmath=config.numba__fastmath
-    )
-
-    py_scalar_func = getattr(scalar_op_fn, "py_func", scalar_op_fn)
-
-    elemwise_fn = numba_vectorized_fn(scalar_op_fn)
-    elemwise_fn.py_scalar_func = py_scalar_func
-
-    return elemwise_fn
-
-
 def create_multiaxis_reducer(
     scalar_op,
     identity,
@@ -320,7 +280,6 @@ def jit_compile_reducer(
     res = numba_basic.numba_njit(
         *args,
         boundscheck=False,
-        fastmath=config.numba__fastmath,
         **kwds,
     )(fn)

@@ -354,7 +313,6 @@ def numba_funcify_Elemwise(op, node, **kwargs):
         op.scalar_op,
         node=scalar_node,
         parent_node=node,
-        fastmath=_jit_options["fastmath"],
         **kwargs,
     )

@@ -442,13 +400,13 @@ def numba_funcify_Sum(op, node, **kwargs):

     if ndim_input == len(axes):
         # Slightly faster than `numba_funcify_CAReduce` for this case
-        @numba_njit(fastmath=config.numba__fastmath)
+        @numba_njit
         def impl_sum(array):
             return np.asarray(array.sum(), dtype=np_acc_dtype).astype(out_dtype)

     elif len(axes) == 0:
         # These cases should be removed by rewrites!
-        @numba_njit(fastmath=config.numba__fastmath)
+        @numba_njit
         def impl_sum(array):
             return np.asarray(array, dtype=out_dtype)

@@ -607,9 +565,7 @@ def numba_funcify_Softmax(op, node, **kwargs):
             add_as, 0.0, (axis,), x_at.ndim, x_dtype, keepdims=True
         )

-        jit_fn = numba_basic.numba_njit(
-            boundscheck=False, fastmath=config.numba__fastmath
-        )
+        jit_fn = numba_basic.numba_njit(boundscheck=False)
         reduce_max = jit_fn(reduce_max_py)
         reduce_sum = jit_fn(reduce_sum_py)
     else:
@@ -641,9 +597,7 @@ def numba_funcify_SoftmaxGrad(op, node, **kwargs):
             add_as, 0.0, (axis,), sm_at.ndim, sm_dtype, keepdims=True
         )

-        jit_fn = numba_basic.numba_njit(
-            boundscheck=False, fastmath=config.numba__fastmath
-        )
+        jit_fn = numba_basic.numba_njit(boundscheck=False)
         reduce_sum = jit_fn(reduce_sum_py)
     else:
         reduce_sum = np.sum
@@ -681,9 +635,7 @@ def numba_funcify_LogSoftmax(op, node, **kwargs):
             add_as, 0.0, (axis,), x_at.ndim, x_dtype, keepdims=True
         )

-        jit_fn = numba_basic.numba_njit(
-            boundscheck=False, fastmath=config.numba__fastmath
-        )
+        jit_fn = numba_basic.numba_njit(boundscheck=False)
         reduce_max = jit_fn(reduce_max_py)
         reduce_sum = jit_fn(reduce_sum_py)
     else:
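
The reducers and softmax helpers now inherit the restricted flag set from `numba_njit` instead of threading `config.numba__fastmath` through each call site. A hedged sketch (not from the commit) of what the surviving "reassoc" flag means for reductions:

    import numba
    import numpy as np

    @numba.njit(fastmath={"reassoc", "contract", "nsz"})
    def fast_sum(x):
        acc = 0.0
        for i in range(x.shape[0]):
            acc += x[i]  # LLVM may reorder this accumulation
        return acc

    x = np.random.rand(10_000)
    # Equal up to reassociation rounding, not necessarily bit-for-bit.
    assert np.isclose(fast_sum(x), np.sum(x))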

pytensor/link/numba/dispatch/extra_ops.py

Lines changed: 4 additions & 5 deletions
@@ -4,7 +4,6 @@
 import numba
 import numpy as np

-from pytensor import config
 from pytensor.graph import Apply
 from pytensor.link.numba.dispatch import basic as numba_basic
 from pytensor.link.numba.dispatch.basic import get_numba_type, numba_funcify
@@ -50,13 +49,13 @@ def numba_funcify_CumOp(op: CumOp, node: Apply, **kwargs):
     if mode == "add":
         if axis is None or ndim == 1:

-            @numba_basic.numba_njit(fastmath=config.numba__fastmath)
+            @numba_basic.numba_njit
             def cumop(x):
                 return np.cumsum(x)

         else:

-            @numba_basic.numba_njit(boundscheck=False, fastmath=config.numba__fastmath)
+            @numba_basic.numba_njit(boundscheck=False)
             def cumop(x):
                 out_dtype = x.dtype
                 if x.shape[axis] < 2:
@@ -74,13 +73,13 @@ def cumop(x):
     else:
         if axis is None or ndim == 1:

-            @numba_basic.numba_njit(fastmath=config.numba__fastmath)
+            @numba_basic.numba_njit
            def cumop(x):
                 return np.cumprod(x)

         else:

-            @numba_basic.numba_njit(boundscheck=False, fastmath=config.numba__fastmath)
+            @numba_basic.numba_njit(boundscheck=False)
             def cumop(x):
                 out_dtype = x.dtype
                 if x.shape[axis] < 2:

pytensor/link/numba/dispatch/scalar.py

Lines changed: 8 additions & 14 deletions
@@ -2,7 +2,6 @@

 import numpy as np

-from pytensor import config
 from pytensor.compile.ops import ViewOp
 from pytensor.graph.basic import Variable
 from pytensor.link.numba.dispatch import basic as numba_basic
@@ -137,7 +136,6 @@ def {scalar_op_fn_name}({', '.join(input_names)}):

     return numba_basic.numba_njit(
         signature,
-        fastmath=config.numba__fastmath,
         # Functions that call a function pointer can't be cached
         cache=False,
     )(scalar_op_fn)
@@ -177,19 +175,15 @@ def numba_funcify_Add(op, node, **kwargs):
     signature = create_numba_signature(node, force_scalar=True)
     nary_add_fn = binary_to_nary_func(node.inputs, "add", "+")

-    return numba_basic.numba_njit(signature, fastmath=config.numba__fastmath)(
-        nary_add_fn
-    )
+    return numba_basic.numba_njit(signature)(nary_add_fn)


 @numba_funcify.register(Mul)
 def numba_funcify_Mul(op, node, **kwargs):
     signature = create_numba_signature(node, force_scalar=True)
     nary_add_fn = binary_to_nary_func(node.inputs, "mul", "*")

-    return numba_basic.numba_njit(signature, fastmath=config.numba__fastmath)(
-        nary_add_fn
-    )
+    return numba_basic.numba_njit(signature)(nary_add_fn)


 @numba_funcify.register(Cast)
@@ -239,7 +233,7 @@ def numba_funcify_Composite(op, node, **kwargs):

     _ = kwargs.pop("storage_map", None)

-    composite_fn = numba_basic.numba_njit(signature, fastmath=config.numba__fastmath)(
+    composite_fn = numba_basic.numba_njit(signature)(
         numba_funcify(op.fgraph, squeeze_output=True, **kwargs)
     )
     return composite_fn
@@ -267,7 +261,7 @@ def numba_funcify_Reciprocal(op, node, **kwargs):
     return numba_basic.global_numba_func(reciprocal)


-@numba_basic.numba_njit(fastmath=config.numba__fastmath)
+@numba_basic.numba_njit
 def sigmoid(x):
     return 1 / (1 + np.exp(-x))

@@ -277,7 +271,7 @@ def numba_funcify_Sigmoid(op, node, **kwargs):
     return numba_basic.global_numba_func(sigmoid)


-@numba_basic.numba_njit(fastmath=config.numba__fastmath)
+@numba_basic.numba_njit
 def gammaln(x):
     return math.lgamma(x)

@@ -287,7 +281,7 @@ def numba_funcify_GammaLn(op, node, **kwargs):
     return numba_basic.global_numba_func(gammaln)


-@numba_basic.numba_njit(fastmath=config.numba__fastmath)
+@numba_basic.numba_njit
 def logp1mexp(x):
     if x < np.log(0.5):
         return np.log1p(-np.exp(x))
@@ -300,7 +294,7 @@ def numba_funcify_Log1mexp(op, node, **kwargs):
     return numba_basic.global_numba_func(logp1mexp)


-@numba_basic.numba_njit(fastmath=config.numba__fastmath)
+@numba_basic.numba_njit
 def erf(x):
     return math.erf(x)

@@ -310,7 +304,7 @@ def numba_funcify_Erf(op, **kwargs):
     return numba_basic.global_numba_func(erf)


-@numba_basic.numba_njit(fastmath=config.numba__fastmath)
+@numba_basic.numba_njit
 def erfc(x):
     return math.erfc(x)
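
The global scalar helpers above (`sigmoid`, `gammaln`, `erf`, ...) now pick up the shared defaults from `numba_basic.numba_njit`. A sketch (not from the commit) of the accuracy trade-off: "afn" permits approximate transcendentals and "arcp" approximate reciprocals, so comparisons should use a tolerance rather than exact equality:

    import numba
    import numpy as np

    @numba.njit(fastmath={"arcp", "contract", "afn", "reassoc", "nsz"})
    def sigmoid(x):
        return 1 / (1 + np.exp(-x))

    for x in np.linspace(-10.0, 10.0, 101):
        # Close to, but not necessarily identical with, the NumPy result.
        assert np.isclose(sigmoid(x), 1 / (1 + np.exp(-x)), rtol=1e-6)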

tests/link/numba/test_scalar.py

Lines changed: 19 additions & 0 deletions
@@ -9,6 +9,7 @@
 from pytensor.graph.basic import Constant
 from pytensor.graph.fg import FunctionGraph
 from pytensor.scalar.basic import Composite
+from pytensor.tensor import tensor
 from pytensor.tensor.elemwise import Elemwise
 from tests.link.numba.test_basic import compare_numba_and_py, set_test_value

@@ -140,3 +141,21 @@ def test_reciprocal(v, dtype):
         if not isinstance(i, SharedVariable | Constant)
     ],
 )
+
+
+@pytest.mark.parametrize("composite", (False, True))
+def test_isnan(composite):
+    # Testing with tensor just to make sure Elemwise does not revert the scalar behavior of fastmath
+    x = tensor(shape=(2,), dtype="float64")
+
+    if composite:
+        x_scalar = psb.float64()
+        scalar_out = ~psb.isnan(x_scalar)
+        out = Elemwise(Composite([x_scalar], [scalar_out]))(x)
+    else:
+        out = pt.isnan(x)
+
+    compare_numba_and_py(
+        ([x], [out]),
+        [np.array([1, 0], dtype="float64")],
+    )
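
This test guards the flag choice itself: blanket `fastmath=True` includes "nnan", which lets LLVM assume NaNs never occur, so a compiled NaN check can be constant-folded away. A sketch of that assumed failure mode (not from the commit; flag set copied from `basic.py` above):

    import numba
    import numpy as np

    @numba.njit(fastmath=True)  # includes "nnan"
    def isnan_fast(x):
        return np.isnan(x)

    @numba.njit(fastmath={"arcp", "contract", "afn", "reassoc", "nsz"})
    def isnan_safe(x):
        return np.isnan(x)

    print(isnan_fast(np.nan))  # may print False: the NaN check can be folded
    print(isnan_safe(np.nan))  # True: "nnan" is not enabled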
