Skip to content

Commit 2ea42ab

Browse files
committed
Use more specific Numba fastmath flags everywhere
1 parent c59ba2e commit 2ea42ab

File tree

7 files changed

+48
-87
lines changed

7 files changed

+48
-87
lines changed

doc/extending/creating_a_numba_jax_op.rst

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -358,13 +358,13 @@ Here's an example for the `CumOp`\ `Op`:
358358
if mode == "add":
359359
if axis is None or ndim == 1:
360360
361-
@numba_basic.numba_njit(fastmath=config.numba__fastmath)
361+
@numba_basic.numba_njit()
362362
def cumop(x):
363363
return np.cumsum(x)
364364
365365
else:
366366
367-
@numba_basic.numba_njit(boundscheck=False, fastmath=config.numba__fastmath)
367+
@numba_basic.numba_njit(boundscheck=False)
368368
def cumop(x):
369369
out_dtype = x.dtype
370370
if x.shape[axis] < 2:
@@ -382,13 +382,13 @@ Here's an example for the `CumOp`\ `Op`:
382382
else:
383383
if axis is None or ndim == 1:
384384
385-
@numba_basic.numba_njit(fastmath=config.numba__fastmath)
385+
@numba_basic.numba_njit()
386386
def cumop(x):
387387
return np.cumprod(x)
388388
389389
else:
390390
391-
@numba_basic.numba_njit(boundscheck=False, fastmath=config.numba__fastmath)
391+
@numba_basic.numba_njit(boundscheck=False)
392392
def cumop(x):
393393
out_dtype = x.dtype
394394
if x.shape[axis] < 2:

pytensor/link/numba/dispatch/basic.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -49,10 +49,22 @@ def global_numba_func(func):
4949
return func
5050

5151

52-
def numba_njit(*args, **kwargs):
52+
def numba_njit(*args, fastmath=None, **kwargs):
5353
kwargs.setdefault("cache", config.numba__cache)
5454
kwargs.setdefault("no_cpython_wrapper", True)
5555
kwargs.setdefault("no_cfunc_wrapper", True)
56+
if fastmath is None and config.numba__fastmath:
57+
# Opinionated default on fastmath flags
58+
# https://llvm.org/docs/LangRef.html#fast-math-flags
59+
fastmath = {
60+
"arcp", # Allow Reciprocal
61+
"contract", # Allow floating-point contraction
62+
"afn", # Approximate functions
63+
"reassoc",
64+
"nsz", # no-signed zeros
65+
}
66+
else:
67+
fastmath = False
5668

5769
# Suppress cache warning for internal functions
5870
# We have to add an ansi escape code for optional bold text by numba
@@ -68,9 +80,9 @@ def numba_njit(*args, **kwargs):
6880
)
6981

7082
if len(args) > 0 and callable(args[0]):
71-
return numba.njit(*args[1:], **kwargs)(args[0])
83+
return numba.njit(*args[1:], fastmath=fastmath, **kwargs)(args[0])
7284

73-
return numba.njit(*args, **kwargs)
85+
return numba.njit(*args, fastmath=fastmath, **kwargs)
7486

7587

7688
def numba_vectorize(*args, **kwargs):

pytensor/link/numba/dispatch/blockwise.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,6 @@ def numba_funcify_Blockwise(op: BlockwiseWithCoreShape, node, **kwargs):
3232
core_op,
3333
node=core_node,
3434
parent_node=node,
35-
fastmath=_jit_options["fastmath"],
3635
**kwargs,
3736
)
3837
core_op_fn = store_core_outputs(core_op_fn, nin=nin, nout=nout)

pytensor/link/numba/dispatch/elemwise.py

Lines changed: 5 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,11 @@
1-
from collections.abc import Callable
21
from functools import singledispatch
32
from textwrap import dedent, indent
4-
from typing import Any
53

64
import numba
75
import numpy as np
86
from numba.core.extending import overload
97
from numpy.core.numeric import normalize_axis_index, normalize_axis_tuple
108

11-
from pytensor import config
12-
from pytensor.graph.basic import Apply
139
from pytensor.graph.op import Op
1410
from pytensor.link.numba.dispatch import basic as numba_basic
1511
from pytensor.link.numba.dispatch.basic import (
@@ -124,42 +120,6 @@ def scalar_in_place_fn_ScalarMinimum(op, idx, res, arr):
124120
"""
125121

126122

127-
def create_vectorize_func(
128-
scalar_op_fn: Callable,
129-
node: Apply,
130-
use_signature: bool = False,
131-
identity: Any | None = None,
132-
**kwargs,
133-
) -> Callable:
134-
r"""Create a vectorized Numba function from a `Apply`\s Python function."""
135-
136-
if len(node.outputs) > 1:
137-
raise NotImplementedError(
138-
"Multi-output Elemwise Ops are not supported by the Numba backend"
139-
)
140-
141-
if use_signature:
142-
signature = [create_numba_signature(node, force_scalar=True)]
143-
else:
144-
signature = []
145-
146-
target = (
147-
getattr(node.tag, "numba__vectorize_target", None)
148-
or config.numba__vectorize_target
149-
)
150-
151-
numba_vectorized_fn = numba_basic.numba_vectorize(
152-
signature, identity=identity, target=target, fastmath=config.numba__fastmath
153-
)
154-
155-
py_scalar_func = getattr(scalar_op_fn, "py_func", scalar_op_fn)
156-
157-
elemwise_fn = numba_vectorized_fn(scalar_op_fn)
158-
elemwise_fn.py_scalar_func = py_scalar_func
159-
160-
return elemwise_fn
161-
162-
163123
def create_multiaxis_reducer(
164124
scalar_op,
165125
identity,
@@ -320,7 +280,6 @@ def jit_compile_reducer(
320280
res = numba_basic.numba_njit(
321281
*args,
322282
boundscheck=False,
323-
fastmath=config.numba__fastmath,
324283
**kwds,
325284
)(fn)
326285

@@ -354,7 +313,6 @@ def numba_funcify_Elemwise(op, node, **kwargs):
354313
op.scalar_op,
355314
node=scalar_node,
356315
parent_node=node,
357-
fastmath=_jit_options["fastmath"],
358316
**kwargs,
359317
)
360318

@@ -442,13 +400,13 @@ def numba_funcify_Sum(op, node, **kwargs):
442400

443401
if ndim_input == len(axes):
444402
# Slightly faster than `numba_funcify_CAReduce` for this case
445-
@numba_njit(fastmath=config.numba__fastmath)
403+
@numba_njit
446404
def impl_sum(array):
447405
return np.asarray(array.sum(), dtype=np_acc_dtype).astype(out_dtype)
448406

449407
elif len(axes) == 0:
450408
# These cases should be removed by rewrites!
451-
@numba_njit(fastmath=config.numba__fastmath)
409+
@numba_njit
452410
def impl_sum(array):
453411
return np.asarray(array, dtype=out_dtype)
454412

@@ -607,9 +565,7 @@ def numba_funcify_Softmax(op, node, **kwargs):
607565
add_as, 0.0, (axis,), x_at.ndim, x_dtype, keepdims=True
608566
)
609567

610-
jit_fn = numba_basic.numba_njit(
611-
boundscheck=False, fastmath=config.numba__fastmath
612-
)
568+
jit_fn = numba_basic.numba_njit(boundscheck=False)
613569
reduce_max = jit_fn(reduce_max_py)
614570
reduce_sum = jit_fn(reduce_sum_py)
615571
else:
@@ -641,9 +597,7 @@ def numba_funcify_SoftmaxGrad(op, node, **kwargs):
641597
add_as, 0.0, (axis,), sm_at.ndim, sm_dtype, keepdims=True
642598
)
643599

644-
jit_fn = numba_basic.numba_njit(
645-
boundscheck=False, fastmath=config.numba__fastmath
646-
)
600+
jit_fn = numba_basic.numba_njit(boundscheck=False)
647601
reduce_sum = jit_fn(reduce_sum_py)
648602
else:
649603
reduce_sum = np.sum
@@ -681,9 +635,7 @@ def numba_funcify_LogSoftmax(op, node, **kwargs):
681635
add_as, 0.0, (axis,), x_at.ndim, x_dtype, keepdims=True
682636
)
683637

684-
jit_fn = numba_basic.numba_njit(
685-
boundscheck=False, fastmath=config.numba__fastmath
686-
)
638+
jit_fn = numba_basic.numba_njit(boundscheck=False)
687639
reduce_max = jit_fn(reduce_max_py)
688640
reduce_sum = jit_fn(reduce_sum_py)
689641
else:

pytensor/link/numba/dispatch/extra_ops.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
import numba
55
import numpy as np
66

7-
from pytensor import config
87
from pytensor.graph import Apply
98
from pytensor.link.numba.dispatch import basic as numba_basic
109
from pytensor.link.numba.dispatch.basic import get_numba_type, numba_funcify
@@ -50,13 +49,13 @@ def numba_funcify_CumOp(op: CumOp, node: Apply, **kwargs):
5049
if mode == "add":
5150
if axis is None or ndim == 1:
5251

53-
@numba_basic.numba_njit(fastmath=config.numba__fastmath)
52+
@numba_basic.numba_njit
5453
def cumop(x):
5554
return np.cumsum(x)
5655

5756
else:
5857

59-
@numba_basic.numba_njit(boundscheck=False, fastmath=config.numba__fastmath)
58+
@numba_basic.numba_njit(boundscheck=False)
6059
def cumop(x):
6160
out_dtype = x.dtype
6261
if x.shape[axis] < 2:
@@ -74,13 +73,13 @@ def cumop(x):
7473
else:
7574
if axis is None or ndim == 1:
7675

77-
@numba_basic.numba_njit(fastmath=config.numba__fastmath)
76+
@numba_basic.numba_njit
7877
def cumop(x):
7978
return np.cumprod(x)
8079

8180
else:
8281

83-
@numba_basic.numba_njit(boundscheck=False, fastmath=config.numba__fastmath)
82+
@numba_basic.numba_njit(boundscheck=False)
8483
def cumop(x):
8584
out_dtype = x.dtype
8685
if x.shape[axis] < 2:

pytensor/link/numba/dispatch/scalar.py

Lines changed: 8 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22

33
import numpy as np
44

5-
from pytensor import config
65
from pytensor.compile.ops import ViewOp
76
from pytensor.graph.basic import Variable
87
from pytensor.link.numba.dispatch import basic as numba_basic
@@ -23,7 +22,6 @@
2322
Clip,
2423
Composite,
2524
Identity,
26-
IsNan,
2725
Mul,
2826
Reciprocal,
2927
ScalarOp,
@@ -138,8 +136,6 @@ def {scalar_op_fn_name}({', '.join(input_names)}):
138136

139137
return numba_basic.numba_njit(
140138
signature,
141-
# numba always returns False if fastmath=True # https://github.com/numba/numba/issues/9383
142-
fastmath=False if isinstance(op, IsNan) else config.numba__fastmath,
143139
# Functions that call a function pointer can't be cached
144140
cache=False,
145141
)(scalar_op_fn)
@@ -179,19 +175,15 @@ def numba_funcify_Add(op, node, **kwargs):
179175
signature = create_numba_signature(node, force_scalar=True)
180176
nary_add_fn = binary_to_nary_func(node.inputs, "add", "+")
181177

182-
return numba_basic.numba_njit(signature, fastmath=config.numba__fastmath)(
183-
nary_add_fn
184-
)
178+
return numba_basic.numba_njit(signature)(nary_add_fn)
185179

186180

187181
@numba_funcify.register(Mul)
188182
def numba_funcify_Mul(op, node, **kwargs):
189183
signature = create_numba_signature(node, force_scalar=True)
190184
nary_add_fn = binary_to_nary_func(node.inputs, "mul", "*")
191185

192-
return numba_basic.numba_njit(signature, fastmath=config.numba__fastmath)(
193-
nary_add_fn
194-
)
186+
return numba_basic.numba_njit(signature)(nary_add_fn)
195187

196188

197189
@numba_funcify.register(Cast)
@@ -241,7 +233,7 @@ def numba_funcify_Composite(op, node, **kwargs):
241233

242234
_ = kwargs.pop("storage_map", None)
243235

244-
composite_fn = numba_basic.numba_njit(signature, fastmath=config.numba__fastmath)(
236+
composite_fn = numba_basic.numba_njit(signature)(
245237
numba_funcify(op.fgraph, squeeze_output=True, **kwargs)
246238
)
247239
return composite_fn
@@ -269,7 +261,7 @@ def numba_funcify_Reciprocal(op, node, **kwargs):
269261
return numba_basic.global_numba_func(reciprocal)
270262

271263

272-
@numba_basic.numba_njit(fastmath=config.numba__fastmath)
264+
@numba_basic.numba_njit
273265
def sigmoid(x):
274266
return 1 / (1 + np.exp(-x))
275267

@@ -279,7 +271,7 @@ def numba_funcify_Sigmoid(op, node, **kwargs):
279271
return numba_basic.global_numba_func(sigmoid)
280272

281273

282-
@numba_basic.numba_njit(fastmath=config.numba__fastmath)
274+
@numba_basic.numba_njit
283275
def gammaln(x):
284276
return math.lgamma(x)
285277

@@ -289,7 +281,7 @@ def numba_funcify_GammaLn(op, node, **kwargs):
289281
return numba_basic.global_numba_func(gammaln)
290282

291283

292-
@numba_basic.numba_njit(fastmath=config.numba__fastmath)
284+
@numba_basic.numba_njit
293285
def logp1mexp(x):
294286
if x < np.log(0.5):
295287
return np.log1p(-np.exp(x))
@@ -302,7 +294,7 @@ def numba_funcify_Log1mexp(op, node, **kwargs):
302294
return numba_basic.global_numba_func(logp1mexp)
303295

304296

305-
@numba_basic.numba_njit(fastmath=config.numba__fastmath)
297+
@numba_basic.numba_njit
306298
def erf(x):
307299
return math.erf(x)
308300

@@ -312,7 +304,7 @@ def numba_funcify_Erf(op, **kwargs):
312304
return numba_basic.global_numba_func(erf)
313305

314306

315-
@numba_basic.numba_njit(fastmath=config.numba__fastmath)
307+
@numba_basic.numba_njit
316308
def erfc(x):
317309
return math.erfc(x)
318310

tests/link/numba/test_scalar.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -143,12 +143,19 @@ def test_reciprocal(v, dtype):
143143
)
144144

145145

146-
@pytest.mark.parametrize("dtype", ("complex64", "float64", "float32"))
147-
def test_isnan(dtype):
146+
@pytest.mark.parametrize("composite", (False, True))
147+
def test_isnan(composite):
148148
# Testing with tensor just to make sure Elemwise does not revert the scalar behavior of fastmath
149-
x = tensor(shape=(2,), dtype=dtype)
150-
out = pt.isnan(x)
149+
x = tensor(shape=(2,), dtype="float64")
150+
151+
if composite:
152+
x_scalar = psb.float64()
153+
scalar_out = ~psb.isnan(x_scalar)
154+
out = Elemwise(Composite([x_scalar], [scalar_out]))(x)
155+
else:
156+
out = pt.isnan(x)
157+
151158
compare_numba_and_py(
152159
([x], [out]),
153-
[np.array([1, 0], dtype=dtype)],
160+
[np.array([1, 0], dtype="float64")],
154161
)

0 commit comments

Comments
 (0)