
Commit 3468613

Saner defaults
1 parent c9ca9fa commit 3468613

File tree: 11 files changed (+378, -243 lines)


pytensor/link/numba/cache.py

Lines changed: 169 additions & 2 deletions
@@ -1,11 +1,18 @@
+import warnings
 import weakref
+from collections.abc import Callable
+from functools import singledispatch, wraps
 from hashlib import sha256
 from pathlib import Path
+from pickle import dumps
+from tempfile import NamedTemporaryFile
+from typing import Any
 
 from numba.core.caching import CacheImpl, _CacheLocator
 
 from pytensor import config
 from pytensor.graph.basic import Apply
+from pytensor.link.numba.compile import numba_funcify, numba_njit
 
 
 NUMBA_PYTENSOR_CACHE_ENABLED = True
@@ -19,8 +26,6 @@ def __init__(self, py_func, py_file, hash):
         self._py_func = py_func
         self._py_file = py_file
         self._hash = hash
-        # src_hash = hash(pytensor_loader._module_sources[self._py_file])
-        # self._hash = hash((src_hash, py_file, pytensor.__version__))
 
     def ensure_cache_path(self):
         pass
@@ -74,3 +79,165 @@ def cache_node_key(node: Apply, extra_key="") -> str:
             ),
         ).encode()
     ).hexdigest()
+
+
+@singledispatch
+def numba_funcify_default_op_cache_key(
+    op, node=None, **kwargs
+) -> Callable | tuple[Callable, Any]:
+    """Funcify an Op and implement a default cache key.
+
+    The default cache key is based on the op class and its properties.
+    It does not take into account the node inputs or other context.
+    Note that numba will use the array dtypes, rank and layout as part of the cache key,
+    but not the static shape or constant values.
+    If the funcify implementation exploits this information, then this method should not be used.
+    Instead dispatch directly on `numba_funcify_and_cache_key` (or just numba_funcify)
+    which won't use any cache key.
+    """
+    # Default cache key of None which means "don't try to directly cache this function"
+    raise NotImplementedError()
+
+
+def register_funcify_default_op_cache_key(op_type):
+    """Register a funcify implementation for both cache and non-cache versions."""
+
+    def decorator(dispatch_func):
+        # Register with the cache key dispatcher
+        numba_funcify_default_op_cache_key.register(op_type)(dispatch_func)
+
+        # Create a wrapper for the non-cache dispatcher
+        @wraps(dispatch_func)
+        def dispatch_func_wrapper(*args, **kwargs):
+            func, key = dispatch_func(*args, **kwargs)
+            # Discard the key for the non-cache version
+            return func
+
+        # Register the wrapper with the non-cache dispatcher
+        numba_funcify.register(op_type)(dispatch_func_wrapper)
+
+        return dispatch_func
+
+    return decorator
+
+
+@singledispatch
+def numba_funcify_and_cache_key(op, node=None, **kwargs) -> tuple[Callable, str | None]:
+    # Default cache key of None which means "don't try to directly cache this function"
+    if hasattr(op, "_props"):
+        try:
+            func_and_salt = numba_funcify_default_op_cache_key(op, node=node, **kwargs)
+        except NotImplementedError:
+            pass
+        else:
+            if isinstance(func_and_salt, tuple):
+                func, salt = func_and_salt
+            else:
+                func, salt = func_and_salt, "0"
+            props_dict = op._props_dict()
+            if not props_dict:
+                # Simple op, just use the type string as key
+                key_bytes = str((type(op), salt)).encode()
+            else:
+                # Simple props, can use string representation of props as key
+                simple_types = (str, bool, int, type(None), float)
+                container_types = (tuple, frozenset)
+                if all(
+                    isinstance(v, simple_types)
+                    or (
+                        isinstance(v, container_types)
+                        and all(isinstance(i, simple_types) for i in v)
+                    )
+                    for v in props_dict.values()
+                ):
+                    key_bytes = str(
+                        (type(op), tuple(props_dict.items()), salt)
+                    ).encode()
+                else:
+                    # Complex props, use pickle to serialize them
+                    key_bytes = dumps((str(type(op)), tuple(props_dict.items()), salt))
+            return func, sha256(key_bytes).hexdigest()
+
+    # Fallback
+    return numba_funcify(op, node=node, **kwargs), None
+
+
+def register_funcify_and_cache_key(op_type):
+    """Register a funcify implementation for both cache and non-cache versions."""
+
+    def decorator(dispatch_func):
+        # Register with the cache key dispatcher
+        numba_funcify_and_cache_key.register(op_type)(dispatch_func)
+
+        # Create a wrapper for the non-cache dispatcher
+        @wraps(dispatch_func)
+        def dispatch_func_wrapper(*args, **kwargs):
+            func, key = dispatch_func(*args, **kwargs)
+            # Discard the key for the non-cache version
+            return func
+
+        # Register the wrapper with the non-cache dispatcher
+        numba_funcify.register(op_type)(dispatch_func_wrapper)
+
+        return dispatch_func_wrapper
+
+    return decorator
+
+
+def numba_njit_and_cache(op, *args, **kwargs):
+    jitable_func, key = numba_funcify_and_cache_key(op, *args, **kwargs)
+
+    if key is not None:
+        # To force numba to use our cache, we must compile the function so that any closure
+        # becomes a global variable...
+        op_name = op.__class__.__name__
+        cached_func = compile_numba_function_src(
+            src=f"def {op_name}(*args): return jitable_func(*args)",
+            function_name=op_name,
+            global_env=globals() | {"jitable_func": jitable_func},
+            cache_key=key,
+        )
+        return numba_njit(cached_func, final_function=True, cache=True), key
+    else:
+        if config.numba__cache and config.compiler_verbose:
+            warnings.warn(
+                f"Custom numba cache disabled for {op} of type {type(op)}. "
+                f"Even if the function is cached by numba, larger graphs using this function cannot be cached.\n"
+                "To enable custom caching, register a numba_funcify_and_cache_key implementation for this Op, with a proper cache key."
+            )
+
+        return numba_njit(
+            lambda *args: jitable_func(*args), final_function=True, cache=False
+        ), None
+
+
+def compile_numba_function_src(
+    src: str,
+    function_name: str,
+    global_env: dict[Any, Any] | None = None,
+    local_env: dict[Any, Any] | None = None,
+    store_to_disk: bool = False,
+    cache_key: str | None = None,
+) -> Callable:
+    if store_to_disk:
+        with NamedTemporaryFile(delete=False) as f:
+            filename = f.name
+            f.write(src.encode())
+    else:
+        filename = "<string>"
+
+    if global_env is None:
+        global_env = {}
+
+    if local_env is None:
+        local_env = {}
+
+    mod_code = compile(src, filename, mode="exec")
+    exec(mod_code, global_env, local_env)
+
+    res = local_env[function_name]
+    res.__source__ = src  # type: ignore
+
+    if cache_key is not None:
+        CACHED_SRC_FUNCTIONS[res] = cache_key
+    return res
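
For orientation, the snippet below is a hypothetical usage sketch of the helpers added above, not code from this commit: the Op class, its kernel, and the key string are invented, and it assumes numba_njit can be used as a bare decorator as elsewhere in this backend. Only register_funcify_and_cache_key, numba_funcify_and_cache_key, and numba_njit_and_cache come from cache.py.

# Hypothetical example: giving a custom Op a cacheable numba implementation
# via the registration helper added in this commit.
from pytensor.link.numba.cache import (
    numba_njit_and_cache,
    register_funcify_and_cache_key,
)
from pytensor.link.numba.compile import numba_njit


class MyAddOne:  # stand-in for a real pytensor Op subclass
    __props__ = ()


@register_funcify_and_cache_key(MyAddOne)
def numba_funcify_MyAddOne(op, node=None, **kwargs):
    @numba_njit
    def add_one(x):
        return x + 1

    # The second element is the cache key; returning None instead would opt
    # this Op out of the custom on-disk cache.
    return add_one, "my-add-one-v1"


# numba_njit_and_cache(MyAddOne(), node=None) would then wrap add_one in
# generated source keyed by "my-add-one-v1", so numba can cache it to disk.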

pytensor/link/numba/compile.py

Lines changed: 13 additions & 33 deletions
@@ -1,7 +1,6 @@
 import warnings
 from collections.abc import Callable
-from tempfile import NamedTemporaryFile
-from typing import Any
+from functools import singledispatch
 
 import numba
 import numpy as np
@@ -11,8 +10,6 @@
 
 from pytensor import config
 from pytensor.graph import Apply, FunctionGraph, Type
-from pytensor.link.numba.cache import CACHED_SRC_FUNCTIONS
-from pytensor.link.numba.dispatch.sparse import CSCMatrixType, CSRMatrixType
 from pytensor.scalar import ScalarType
 from pytensor.sparse import SparseTensorType
 from pytensor.tensor import TensorType
@@ -59,36 +56,17 @@ def numba_njit(*args, fastmath=None, final_function: bool = False, **kwargs):
         return func(*args, fastmath=fastmath, **kwargs)
 
 
-def compile_and_cache_numba_function_src(
-    src: str,
-    function_name: str,
-    global_env: dict[Any, Any] | None = None,
-    local_env: dict[Any, Any] | None = None,
-    store_to_disk: bool = False,
-    cache_key: str | None = None,
-) -> Callable:
-    if store_to_disk:
-        with NamedTemporaryFile(delete=False) as f:
-            filename = f.name
-            f.write(src.encode())
-    else:
-        filename = "<string>"
-
-    if global_env is None:
-        global_env = {}
-
-    if local_env is None:
-        local_env = {}
-
-    mod_code = compile(src, filename, mode="exec")
-    exec(mod_code, global_env, local_env)
+@singledispatch
+def numba_funcify(
+    typ, node=None, storage_map=None, **kwargs
+) -> Callable | tuple[Callable, str | int | None]:
+    """Generate a numba function for a given op and apply node (or Fgraph).
 
-    res = local_env[function_name]
-    res.__source__ = src  # type: ignore
-
-    if cache_key is not None:
-        CACHED_SRC_FUNCTIONS[res] = cache_key
-    return res
+    The resulting function will usually use the `no_cpython_wrapper`
+    argument in numba, so it cannot be called directly from Python,
+    but only from other jit functions.
+    """
+    raise NotImplementedError(f"Numba funcify not implemented for type {typ}")
 
 
 def get_numba_type(
@@ -124,6 +102,8 @@ def get_numba_type(
         numba_dtype = numba.from_dtype(dtype)
         return numba_dtype
     elif isinstance(pytensor_type, SparseTensorType):
+        from pytensor.link.numba.dispatch.sparse import CSCMatrixType, CSRMatrixType
+
        dtype = pytensor_type.numpy_dtype
        numba_dtype = numba.from_dtype(dtype)
        if pytensor_type.format == "csr":
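
As a rough, hypothetical sketch of how the numba_funcify dispatcher defined above is meant to be extended (MyOp and its kernel are invented; only numba_funcify and numba_njit are from this module):

# Hypothetical sketch: registering an Op-specific implementation on the
# numba_funcify singledispatch function.
from pytensor.link.numba.compile import numba_funcify, numba_njit


class MyOp:  # stand-in for a pytensor Op subclass
    __props__ = ()


@numba_funcify.register(MyOp)
def numba_funcify_MyOp(op, node=None, storage_map=None, **kwargs):
    # The returned function is intended to be called from other jit functions,
    # matching the no_cpython_wrapper convention mentioned in the docstring.
    @numba_njit
    def my_op(x):
        return 2 * x

    return my_op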
