
Commit c9ca9fa

Control caching of numba functions
1 parent ac8daa1 commit c9ca9fa


18 files changed: +386 -120 lines changed


pytensor/compile/mode.py

Lines changed: 13 additions & 0 deletions
@@ -50,6 +50,7 @@
     "jax": JAXLinker(),
     "pytorch": PytorchLinker(),
     "numba": NumbaLinker(),
+    "numba_vm": NumbaLinker(vm=True),
 }


@@ -351,6 +352,11 @@ def __setstate__(self, state):
            optimizer = predefined_optimizers[optimizer]
        if isinstance(optimizer, RewriteDatabaseQuery):
            self.provided_optimizer = optimizer
+
+        # Force numba-required rewrites if using NumbaLinker
+        if isinstance(linker, NumbaLinker):
+            optimizer = optimizer.including("numba")
+
        self._optimizer = optimizer
        self.call_time = 0
        self.fn_time = 0
@@ -475,6 +481,11 @@ def clone(self, link_kwargs=None, optimizer="", **kwargs):
     ),
 )

+NUMBA_VM = Mode(
+    "numba_vm",
+    NUMBA._optimizer,
+)
+
 JAX = Mode(
     "jax",
     RewriteDatabaseQuery(
@@ -515,6 +526,7 @@ def clone(self, link_kwargs=None, optimizer="", **kwargs):
     "C_VM": C_VM,
     "JAX": JAX,
     "NUMBA": NUMBA,
+    "NUMBA_VM": NUMBA_VM,
     "PYTORCH": PYTORCH,
 }

@@ -579,6 +591,7 @@ def register_mode(name, mode):
     Add a `Mode` which can be referred to by `name` in `function`.

     """
+    # TODO: Remove me
     if name in predefined_modes:
         raise ValueError(f"Mode name already taken: {name}")
     predefined_modes[name] = mode
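
The new entries make the VM-based Numba linker selectable by name, either as the "numba_vm" linker or through the NUMBA_VM predefined mode. A minimal sketch of picking the mode when compiling a function (the small graph and the use of pytensor.function below are illustrative, not part of this commit):

import pytensor
import pytensor.tensor as pt

x = pt.vector("x")
y = pt.exp(x).sum()

# "NUMBA" keeps the previous behaviour; "NUMBA_VM" resolves to the newly
# registered mode and routes compilation through NumbaLinker(vm=True).
fn = pytensor.function([x], y, mode="NUMBA_VM")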

pytensor/configdefaults.py

Lines changed: 1 addition & 0 deletions
@@ -379,6 +379,7 @@ def add_compile_configvars():
             "cvm_nogc",
             "jax",
             "numba",
+            "numba_vm",
         ]
     else:
         # g++ is not present or the user disabled it,
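
Since "numba_vm" is now an accepted value for the linker config variable, it should also be selectable at startup through PyTensor flags rather than an explicit Mode object; a hedged sketch (the flag mechanism is the usual config machinery, not something added in this commit):

# e.g. launched with: PYTENSOR_FLAGS="linker=numba_vm" python my_script.py
import pytensor

# The value is validated against the option list above, which now includes "numba_vm".
print(pytensor.config.linker)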

pytensor/link/numba/cache.py

Lines changed: 76 additions & 0 deletions
@@ -0,0 +1,76 @@
+import weakref
+from hashlib import sha256
+from pathlib import Path
+
+from numba.core.caching import CacheImpl, _CacheLocator
+
+from pytensor import config
+from pytensor.graph.basic import Apply
+
+
+NUMBA_PYTENSOR_CACHE_ENABLED = True
+NUMBA_CACHE_PATH = config.base_compiledir / "numba"
+NUMBA_CACHE_PATH.mkdir(exist_ok=True)
+CACHED_SRC_FUNCTIONS = weakref.WeakKeyDictionary()
+
+
+class NumbaPyTensorCacheLocator(_CacheLocator):
+    def __init__(self, py_func, py_file, hash):
+        self._py_func = py_func
+        self._py_file = py_file
+        self._hash = hash
+        # src_hash = hash(pytensor_loader._module_sources[self._py_file])
+        # self._hash = hash((src_hash, py_file, pytensor.__version__))
+
+    def ensure_cache_path(self):
+        pass
+
+    def get_cache_path(self):
+        """
+        Return the directory the function is cached in.
+        """
+        return NUMBA_CACHE_PATH
+
+    def get_source_stamp(self):
+        """
+        Get a timestamp representing the source code's freshness.
+        Can return any picklable Python object.
+        """
+        return 0
+
+    def get_disambiguator(self):
+        """
+        Get a string disambiguator for this locator's function.
+        It should allow disambiguating different but similarly-named functions.
+        """
+        return self._hash
+
+    @classmethod
+    def from_function(cls, py_func, py_file):
+        """
+        Create a locator instance for the given function located in the given file.
+        """
+        # py_file = Path(py_file).parent
+        # if py_file == (config.base_compiledir / "numba"):
+        if NUMBA_PYTENSOR_CACHE_ENABLED and py_func in CACHED_SRC_FUNCTIONS:
+            # print(f"Applies to {py_file}")
+            return cls(py_func, Path(py_file).parent, CACHED_SRC_FUNCTIONS[py_func])
+
+
+CacheImpl._locator_classes.insert(0, NumbaPyTensorCacheLocator)
+
+
+def cache_node_key(node: Apply, extra_key="") -> str:
+    op = node.op
+    return sha256(
+        str(
+            (
+                # Op signature
+                (type(op), op._props_dict() if hasattr(op, "_props_dict") else ""),
+                # Node signature
+                tuple((type(inp_type := inp.type), inp_type) for inp in node.inputs),
+                # Extra key given by the caller
+                extra_key,
+            ),
+        ).encode()
+    ).hexdigest()
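
cache_node_key folds the Op type and its props, the types of the node's inputs, and any caller-supplied context into a single digest, so structurally identical nodes map to the same on-disk cache entry. A rough sketch of computing a key for one Apply node (the small graph is illustrative only):

import pytensor.tensor as pt

from pytensor.link.numba.cache import cache_node_key

x = pt.vector("x")
y = pt.exp(x)

# `y.owner` is the Apply node applying Exp to x; nodes built from the same Op
# configuration and input types produce the same 64-character hex digest.
key = cache_node_key(y.owner, extra_key="dispatch-specific-context")
print(key)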

pytensor/link/numba/compile.py

Lines changed: 36 additions & 0 deletions
@@ -1,4 +1,7 @@
 import warnings
+from collections.abc import Callable
+from tempfile import NamedTemporaryFile
+from typing import Any

 import numba
 import numpy as np
@@ -8,6 +11,7 @@

 from pytensor import config
 from pytensor.graph import Apply, FunctionGraph, Type
+from pytensor.link.numba.cache import CACHED_SRC_FUNCTIONS
 from pytensor.link.numba.dispatch.sparse import CSCMatrixType, CSRMatrixType
 from pytensor.scalar import ScalarType
 from pytensor.sparse import SparseTensorType
@@ -55,6 +59,38 @@ def numba_njit(*args, fastmath=None, final_function: bool = False, **kwargs):
     return func(*args, fastmath=fastmath, **kwargs)


+def compile_and_cache_numba_function_src(
+    src: str,
+    function_name: str,
+    global_env: dict[Any, Any] | None = None,
+    local_env: dict[Any, Any] | None = None,
+    store_to_disk: bool = False,
+    cache_key: str | None = None,
+) -> Callable:
+    if store_to_disk:
+        with NamedTemporaryFile(delete=False) as f:
+            filename = f.name
+            f.write(src.encode())
+    else:
+        filename = "<string>"
+
+    if global_env is None:
+        global_env = {}
+
+    if local_env is None:
+        local_env = {}
+
+    mod_code = compile(src, filename, mode="exec")
+    exec(mod_code, global_env, local_env)
+
+    res = local_env[function_name]
+    res.__source__ = src  # type: ignore
+
+    if cache_key is not None:
+        CACHED_SRC_FUNCTIONS[res] = cache_key
+    return res
+
+
 def get_numba_type(
     pytensor_type: Type,
     layout: str = "A",
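
A short usage sketch of the new helper (the source string and key are illustrative): it exec-compiles the source, attaches the text as __source__, and, when a cache_key is given, registers the resulting function in CACHED_SRC_FUNCTIONS so the locator above can claim it.

from pytensor.link.numba.compile import compile_and_cache_numba_function_src

src = "def add_one(x): return x + 1"
add_one = compile_and_cache_numba_function_src(
    src,
    function_name="add_one",
    cache_key="example-key",  # in practice a digest produced by cache_node_key
)

assert add_one(41) == 42
assert add_one.__source__ == src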

pytensor/link/numba/dispatch/basic.py

Lines changed: 69 additions & 19 deletions
@@ -1,6 +1,7 @@
 import operator
 import sys
 import warnings
+from collections.abc import Callable
 from functools import singledispatch

 import numba
@@ -18,7 +19,11 @@
 from pytensor.compile.ops import DeepCopyOp
 from pytensor.graph.fg import FunctionGraph
 from pytensor.ifelse import IfElse
+from pytensor.link.numba.cache import (
+    cache_node_key,
+)
 from pytensor.link.numba.compile import (
+    compile_and_cache_numba_function_src,
     get_numba_type,
     numba_njit,
 )
@@ -208,20 +213,80 @@ def perform(*inputs):
                ret = py_perform_return(inputs)
            return ret

-    return perform
+    # Assume we can't cache python functions
+    return perform, None


 @singledispatch
-def numba_funcify(op, node=None, storage_map=None, **kwargs):
+def numba_funcify(
+    op, node=None, storage_map=None, **kwargs
+) -> Callable | tuple[Callable, str | int | None]:
     """Generate a numba function for a given op and apply node.

     The resulting function will usually use the `no_cpython_wrapper`
     argument in numba, so it can not be called directly from python,
     but only from other jit functions.
+
+    Optionally, the function can return a key that can be used to provide
+    extra caching context or to disable caching (by returning `None`).
+    When nothing is returned, PyTensor will assume the function can be cached
+    based on the op and node signature alone.
     """
     return generate_fallback_impl(op, node, storage_map, **kwargs)


+def numba_funcify_njit(op, node, **kwargs):
+    jitable_func_and_key = numba_funcify(op, node=node, **kwargs)
+
+    match jitable_func_and_key:
+        case Callable():
+            jitable_func = jitable_func_and_key
+            key = cache_node_key(node)
+        case (Callable(), str() | int()):
+            jitable_func, funcify_key = jitable_func_and_key
+            key = cache_node_key(node, funcify_key)
+        case (Callable(), None):
+            # We were explicitly told by the dispatch not to try and cache this function
+            jitable_func, key = jitable_func_and_key
+        case _:
+            raise TypeError(
+                f"numpy_funcify should return a callable or a (callable, key) pair, got {jitable_func_and_key}"
+            )
+
+    if key is not None:
+        # To force numba to use our cache, we must compile the function so that any closure
+        # becomes a global variable...
+        op_name = op.__class__.__name__
+        cached_func = compile_and_cache_numba_function_src(
+            src=f"def {op_name}(*args): return jitable_func(*args)",
+            function_name=op_name,
+            global_env=globals() | dict(jitable_func=jitable_func),
+            cache_key=key,
+        )
+        return numba_njit(cached_func, final_function=True, cache=True)
+    else:
+        return numba_njit(
+            lambda *args: jitable_func(*args), final_function=True, cache=False
+        )
+
+
+@numba_funcify.register(FunctionGraph)
+def numba_funcify_FunctionGraph(
+    fgraph,
+    node=None,
+    fgraph_name="numba_funcified_fgraph",
+    **kwargs,
+):
+    # TODO: Create hash key for whole graph
+    return fgraph_to_python(
+        fgraph,
+        op_conversion_fn=numba_funcify_njit,
+        type_conversion_fn=numba_typify,
+        fgraph_name=fgraph_name,
+        **kwargs,
+    )
+
+
 @numba_funcify.register(OpFromGraph)
 def numba_funcify_OpFromGraph(op, node=None, **kwargs):
     _ = kwargs.pop("storage_map", None)
@@ -251,23 +316,8 @@ def opfromgraph(*inputs):
         def opfromgraph(*inputs):
             return fgraph_fn(*inputs)

-    return opfromgraph
-
-
-@numba_funcify.register(FunctionGraph)
-def numba_funcify_FunctionGraph(
-    fgraph,
-    node=None,
-    fgraph_name="numba_funcified_fgraph",
-    **kwargs,
-):
-    return fgraph_to_python(
-        fgraph,
-        numba_funcify,
-        type_conversion_fn=numba_typify,
-        fgraph_name=fgraph_name,
-        **kwargs,
-    )
+    # We can't cache this correctly until we can define a key for it
+    return opfromgraph, None


 @numba_funcify.register(DeepCopyOp)
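
Under the new contract, a numba_funcify registration can return a bare callable (cached from the op/node signature alone), a (callable, key) pair that folds extra context into the cache key, or (callable, None) to opt out of caching. A hedged sketch of what an Op-specific dispatch might look like; ScaleOp and its single prop are illustrative and not part of this commit:

from pytensor.graph.basic import Apply
from pytensor.graph.op import Op
from pytensor.link.numba.compile import numba_njit
from pytensor.link.numba.dispatch.basic import numba_funcify
from pytensor.tensor import as_tensor_variable


class ScaleOp(Op):
    """Illustrative Op with a single prop that changes the generated code."""

    __props__ = ("scale",)

    def __init__(self, scale):
        self.scale = scale

    def make_node(self, x):
        x = as_tensor_variable(x)
        return Apply(self, [x], [x.type()])

    def perform(self, node, inputs, output_storage):
        output_storage[0][0] = inputs[0] * self.scale


@numba_funcify.register(ScaleOp)
def numba_funcify_ScaleOp(op, node=None, **kwargs):
    scale = op.scale  # prop baked into the jitted closure

    @numba_njit
    def scale_op(x):
        return x * scale

    # Folding the prop into the key keeps ScaleOp(2.0) and ScaleOp(3.0) from
    # sharing a cached kernel; returning just `scale_op` would cache on the
    # op/node signature alone, and (scale_op, None) would disable caching.
    return scale_op, str(scale)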

pytensor/link/numba/dispatch/blockwise.py

Lines changed: 26 additions & 2 deletions
@@ -1,4 +1,5 @@
 import sys
+from hashlib import sha256
 from typing import cast

 from numba.core.extending import overload
@@ -30,12 +31,17 @@ def numba_funcify_Blockwise(op: BlockwiseWithCoreShape, node, **kwargs):
         cast(tuple[TensorVariable], node.inputs[:nin]),
         propagate_unbatched_core_inputs=True,
     )
-    core_op_fn = numba_funcify(
+    core_op_fn_and_key = numba_funcify(
         core_op,
         node=core_node,
         parent_node=node,
         **kwargs,
     )
+    if isinstance(core_op_fn_and_key, tuple):
+        core_op_fn, core_op_key = core_op_fn_and_key
+    else:
+        # Assume we can cache core_op_fn
+        core_op_fn, core_op_key = core_op_fn_and_key, 0
     core_op_fn = store_core_outputs(core_op_fn, nin=nin, nout=nout)

     batch_ndim = blockwise_op.batch_ndim(node)
@@ -90,4 +96,22 @@ def blockwise(*inputs_and_core_shapes):
     def ov_blockwise(*inputs_and_core_shapes):
         return blockwise_wrapper

-    return blockwise
+    if core_op_key is None:
+        # We were told the scalar op cannot be cached
+        blockwise_key = None
+    else:
+        blockwise_key = "_".join(
+            map(
+                str,
+                (
+                    type(op),
+                    type(op.scalar_op),
+                    tuple(op.inplace_pattern.items()),
+                    tuple(getattr(op.scalar_op, "props_dict", lambda: {})().items()),
+                    core_op_key,
+                ),
+            )
+        )
+        blockwise_key = sha256(blockwise_key.encode()).hexdigest()
+
+    return blockwise, blockwise_key
