Skip to content

Commit ac8daa1

Browse files
committed
Refactor numba compiling functionality out of dispatch/basic.py
1 parent 0b2dcbe commit ac8daa1

File tree

20 files changed

+435
-466
lines changed

20 files changed

+435
-466
lines changed

pytensor/link/numba/compile.py

Lines changed: 160 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,160 @@
1+
import warnings
2+
3+
import numba
4+
import numpy as np
5+
from numba import NumbaWarning
6+
from numba import njit as _njit
7+
from numba.core.extending import register_jitable
8+
9+
from pytensor import config
10+
from pytensor.graph import Apply, FunctionGraph, Type
11+
from pytensor.link.numba.dispatch.sparse import CSCMatrixType, CSRMatrixType
12+
from pytensor.scalar import ScalarType
13+
from pytensor.sparse import SparseTensorType
14+
from pytensor.tensor import TensorType
15+
16+
17+
def numba_njit(*args, fastmath=None, final_function: bool = False, **kwargs):
18+
if fastmath is None:
19+
if config.numba__fastmath:
20+
# Opinionated default on fastmath flags
21+
# https://llvm.org/docs/LangRef.html#fast-math-flags
22+
fastmath = {
23+
"arcp", # Allow Reciprocal
24+
"contract", # Allow floating-point contraction
25+
"afn", # Approximate functions
26+
"reassoc",
27+
"nsz", # no-signed zeros
28+
}
29+
else:
30+
fastmath = False
31+
32+
if final_function:
33+
kwargs.setdefault("cache", True)
34+
else:
35+
kwargs.setdefault("no_cpython_wrapper", True)
36+
kwargs.setdefault("no_cfunc_wrapper", True)
37+
38+
# Suppress cache warning for internal functions
39+
# We have to add an ansi escape code for optional bold text by numba
40+
warnings.filterwarnings(
41+
"ignore",
42+
message=(
43+
"(\x1b\\[1m)*" # ansi escape code for bold text
44+
"Cannot cache compiled function "
45+
'"(numba_funcified_fgraph|store_core_outputs|cholesky|solve|solve_triangular|cho_solve|lu_factor)" '
46+
"as it uses dynamic globals"
47+
),
48+
category=NumbaWarning,
49+
)
50+
51+
func = _njit if final_function else register_jitable
52+
if len(args) > 0 and callable(args[0]):
53+
return func(*args[1:], fastmath=fastmath, **kwargs)(args[0])
54+
else:
55+
return func(*args, fastmath=fastmath, **kwargs)
56+
57+
58+
def get_numba_type(
59+
pytensor_type: Type,
60+
layout: str = "A",
61+
force_scalar: bool = False,
62+
reduce_to_scalar: bool = False,
63+
) -> numba.types.Type:
64+
r"""Create a Numba type object for a :class:`Type`.
65+
66+
Parameters
67+
----------
68+
pytensor_type
69+
The :class:`Type` to convert.
70+
layout
71+
The :class:`numpy.ndarray` layout to use.
72+
force_scalar
73+
Ignore dimension information and return the corresponding Numba scalar types.
74+
reduce_to_scalar
75+
Return Numba scalars for zero dimensional :class:`TensorType`\s.
76+
"""
77+
78+
if isinstance(pytensor_type, TensorType):
79+
dtype = pytensor_type.numpy_dtype
80+
numba_dtype = numba.from_dtype(dtype)
81+
if force_scalar or (
82+
reduce_to_scalar and getattr(pytensor_type, "ndim", None) == 0
83+
):
84+
return numba_dtype
85+
return numba.types.Array(numba_dtype, pytensor_type.ndim, layout)
86+
elif isinstance(pytensor_type, ScalarType):
87+
dtype = np.dtype(pytensor_type.dtype)
88+
numba_dtype = numba.from_dtype(dtype)
89+
return numba_dtype
90+
elif isinstance(pytensor_type, SparseTensorType):
91+
dtype = pytensor_type.numpy_dtype
92+
numba_dtype = numba.from_dtype(dtype)
93+
if pytensor_type.format == "csr":
94+
return CSRMatrixType(numba_dtype)
95+
if pytensor_type.format == "csc":
96+
return CSCMatrixType(numba_dtype)
97+
98+
raise NotImplementedError()
99+
else:
100+
raise NotImplementedError(f"Numba type not implemented for {pytensor_type}")
101+
102+
103+
def create_numba_signature(
104+
node_or_fgraph: FunctionGraph | Apply,
105+
force_scalar: bool = False,
106+
reduce_to_scalar: bool = False,
107+
) -> numba.types.Type:
108+
"""Create a Numba type for the signature of an `Apply` node or `FunctionGraph`."""
109+
input_types = [
110+
get_numba_type(
111+
inp.type, force_scalar=force_scalar, reduce_to_scalar=reduce_to_scalar
112+
)
113+
for inp in node_or_fgraph.inputs
114+
]
115+
116+
output_types = [
117+
get_numba_type(
118+
out.type, force_scalar=force_scalar, reduce_to_scalar=reduce_to_scalar
119+
)
120+
for out in node_or_fgraph.outputs
121+
]
122+
123+
if len(output_types) > 1:
124+
return numba.types.Tuple(output_types)(*input_types)
125+
elif len(output_types) == 1:
126+
return output_types[0](*input_types)
127+
else:
128+
return numba.types.void(*input_types)
129+
130+
131+
def create_tuple_creator(f, n):
132+
"""Construct a compile-time ``tuple``-comprehension-like loop.
133+
134+
See https://github.com/numba/numba/issues/2771#issuecomment-414358902
135+
"""
136+
assert n > 0
137+
138+
f = numba_njit(f)
139+
140+
@numba_njit
141+
def creator(args):
142+
return (f(0, *args),)
143+
144+
for i in range(1, n):
145+
146+
@numba_njit
147+
def creator(args, creator=creator, i=i):
148+
return (*creator(args), f(i, *args))
149+
150+
return numba_njit(lambda *args: creator(args))
151+
152+
153+
def create_tuple_string(x):
154+
args = ", ".join(x + ([""] if len(x) == 1 else []))
155+
return f"({args})"
156+
157+
158+
def create_arg_string(x):
159+
args = ", ".join(x)
160+
return args

pytensor/link/numba/dispatch/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import pytensor.link.numba.dispatch.random
1010
import pytensor.link.numba.dispatch.scan
1111
import pytensor.link.numba.dispatch.scalar
12+
import pytensor.link.numba.dispatch.shape
1213
import pytensor.link.numba.dispatch.signal
1314
import pytensor.link.numba.dispatch.slinalg
1415
import pytensor.link.numba.dispatch.sparse

0 commit comments

Comments
 (0)