Skip to content

Commit 799edf0

Browse files
committed
Cache stuff haphazardly
1 parent 32311ba commit 799edf0

File tree

3 files changed

+160
-6
lines changed

3 files changed

+160
-6
lines changed

pytensor/link/numba/dispatch/elemwise.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -264,6 +264,12 @@ def axis_apply_fn(x):
264264

265265
@numba_funcify.register(Elemwise)
266266
def numba_funcify_Elemwise(op, node, **kwargs):
267+
# op = getattr(np, str(op.scalar_op).lower())
268+
# @numba_njit
269+
# def elemwise_is_numpy(x):
270+
# return op(x)
271+
# return elemwise_is_numpy
272+
267273
scalar_inputs = [get_scalar_type(dtype=input.dtype)() for input in node.inputs]
268274
scalar_node = op.scalar_op.make_node(*scalar_inputs)
269275

@@ -276,7 +282,7 @@ def numba_funcify_Elemwise(op, node, **kwargs):
276282

277283
nin = len(node.inputs)
278284
nout = len(node.outputs)
279-
core_op_fn = store_core_outputs(scalar_op_fn, nin=nin, nout=nout)
285+
core_op_fn = store_core_outputs(scalar_op_fn, op.scalar_op, nin=nin, nout=nout)
280286

281287
input_bc_patterns = tuple(inp.type.broadcastable for inp in node.inputs)
282288
output_bc_patterns = tuple(out.type.broadcastable for out in node.outputs)

pytensor/link/numba/dispatch/vectorize_codegen.py

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import base64
44
import pickle
55
from collections.abc import Callable, Sequence
6+
from hashlib import sha256
67
from textwrap import indent
78
from typing import Any, cast
89

@@ -15,15 +16,19 @@
1516
from numba.core.types.misc import NoneType
1617
from numba.np import arrayobj
1718

19+
from pytensor.graph.op import HasInnerGraph
1820
from pytensor.link.numba.dispatch import basic as numba_basic
19-
from pytensor.link.utils import compile_function_src
21+
from pytensor.link.numba.super_utils import compile_function_src2
22+
from pytensor.scalar import ScalarOp
2023

2124

2225
def encode_literals(literals: Sequence) -> str:
2326
return base64.encodebytes(pickle.dumps(literals)).decode()
2427

2528

26-
def store_core_outputs(core_op_fn: Callable, nin: int, nout: int) -> Callable:
29+
def store_core_outputs(
30+
core_op_fn: Callable, core_op: ScalarOp, nin: int, nout: int
31+
) -> Callable:
2732
"""Create a Numba function that wraps a core function and stores its vectorized outputs.
2833
2934
@njit
@@ -52,9 +57,14 @@ def store_core_outputs({inp_signature}, {out_signature}):
5257
{indent(store_outputs, " " * 4)}
5358
"""
5459
global_env = {"core_op_fn": core_op_fn}
55-
func = compile_function_src(
56-
func_src, "store_core_outputs", {**globals(), **global_env}
57-
)
60+
# func = compile_function_src(
61+
# func_src, "store_core_outputs", {**globals(), **global_env},
62+
# )
63+
if isinstance(core_op, HasInnerGraph):
64+
key = sha256(core_op.c_code_template.encode()).hexdigest()
65+
else:
66+
key = str(core_op)
67+
func = compile_function_src2(key, func_src, "store_core_outputs", global_env)
5868
return cast(Callable, numba_basic.numba_njit(func))
5969

6070

pytensor/link/numba/super_utils.py

Lines changed: 138 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,138 @@
1+
import importlib
2+
import os
3+
import sys
4+
import tempfile
5+
from collections.abc import Callable
6+
from typing import Any
7+
8+
import numba
9+
import numba.core.caching
10+
from numba.core.caching import CacheImpl
11+
12+
13+
class PyTensorLoader(importlib.abc.SourceLoader):
    """In-memory source loader for generated PyTensor modules.

    Sources are registered with :meth:`add_module` under a unique name
    (here: ``"pytensor_generated_"`` + a hash of the pytensor graph) and
    later imported through the normal importlib machinery, which gives
    numba's caching layer a stable "file" to key on.
    """

    def __init__(self):
        # Maps module name -> source string / global env / local env.
        self._module_sources = {}
        self._module_globals = {}
        self._module_locals = {}

    def get_source(self, fullname):
        """Return the registered source for *fullname* (importlib hook)."""
        if fullname not in self._module_sources:
            raise ImportError(f"Unknown PyTensor module {fullname!r}")
        return self._module_sources[fullname]

    def get_data(self, path):
        """Return the registered source as UTF-8 bytes.

        The ``SourceLoader`` contract specifies ``OSError`` (not
        ``ImportError``) for paths that cannot be handled, so the import
        machinery can react correctly.
        """
        if path not in self._module_sources:
            raise OSError(f"Unknown PyTensor module {path!r}")
        return self._module_sources[path].encode("utf-8")

    def get_filename(self, path):
        """Use the module name itself as its "filename"."""
        if path not in self._module_sources:
            raise ImportError(f"Unknown PyTensor module {path!r}")
        return path

    def add_module(self, name, src, global_env, local_env):
        """Register *src* plus the environments it must execute in."""
        self._module_sources[name] = src
        self._module_globals[name] = global_env
        self._module_locals[name] = local_env

    def exec_module(self, module):
        """Execute the registered source inside *module*'s namespace."""
        name = module.__name__
        variables = module.__dict__
        # Locals are applied last so they win over globals on key clashes.
        variables.update(self._module_globals[name])
        variables.update(self._module_locals[name])
        code = compile(self._module_sources[name], name, "exec")
        exec(code, variables)

    def create_module(self, spec):
        # Defer to the default module creation semantics.
        return None
50+
51+
52+
# Singleton loader shared by all generated modules.
pytensor_loader = PyTensorLoader()


def load_module(key, src, global_env, local_env):
    """Create, register, and execute an in-memory module named *key*.

    The module is placed in ``sys.modules`` *before* its code runs, as
    required by the import protocol: code executing at module top level
    (or numba introspecting the defining module) must be able to find
    the module under its name. On failure the half-initialized module is
    removed again.
    """
    pytensor_loader.add_module(key, src, global_env, local_env)
    spec = importlib.util.spec_from_loader(key, pytensor_loader)
    module = importlib.util.module_from_spec(spec)
    sys.modules[key] = module
    try:
        spec.loader.exec_module(module)
    except BaseException:
        # Don't leave a broken module importable under *key*.
        del sys.modules[key]
        raise
    return module
62+
63+
64+
class NumbaPyTensorCacheLocator(numba.core.caching._CacheLocator):
    """Cache locator letting numba disk-cache functions compiled from
    PyTensor's in-memory generated modules.

    It only applies to functions whose module was registered with
    ``pytensor_loader``; for anything else :meth:`from_function` returns
    ``None`` and numba falls back to its other locator classes.
    """

    def __init__(self, py_func, py_file):
        self._py_func = py_func
        self._py_file = py_file
        # The "file" is the generated module key, which already encodes a
        # hash of the graph, so it doubles as the freshness stamp.
        # NOTE(review): the stamp does not include the generated source or
        # the pytensor version, so a cache entry could go stale across
        # releases — consider hash((src_hash, py_file, pytensor.__version__)).
        self._hash = py_file

    def ensure_cache_path(self):
        """Create the cache directory and verify it is writable."""
        path = self.get_cache_path()
        os.makedirs(path, exist_ok=True)
        # Ensure the directory is writable by trying to write a temporary file
        tempfile.TemporaryFile(dir=path).close()

    def get_cache_path(self):
        """Return the directory the function is cached in.

        ``os.makedirs`` does not expand ``~``, so expand it here;
        otherwise a literal ``~`` directory is created in the cwd.
        """
        return os.path.expanduser("~/.cache/pytensor")

    def get_source_stamp(self):
        """Get a timestamp representing the source code's freshness.

        Can return any picklable Python object.
        """
        return self._hash

    def get_disambiguator(self):
        """Get a string disambiguator for this locator's function.

        It should allow disambiguating different but similarly-named
        functions. The module key is already unique, so none is needed.
        """
        return None

    @classmethod
    def from_function(cls, py_func, py_file):
        """Create a locator for *py_func* if it was defined in a module
        registered with ``pytensor_loader``; otherwise return ``None`` so
        numba tries the next locator class."""
        if py_func.__module__ in pytensor_loader._module_sources:
            return cls(py_func, py_file)
        return None


# Make numba consider this locator when resolving cached functions.
CacheImpl._locator_classes.append(NumbaPyTensorCacheLocator)
111+
112+
113+
def compile_function_src2(
    key: str,
    src: str,
    function_name: str,
    global_env: dict[Any, Any] | None = None,
    local_env: dict[Any, Any] | None = None,
) -> Callable:
    """Compile *src* as an importable in-memory module and return the
    function named *function_name* from it.

    Unlike a plain ``exec``-based compile, routing through
    ``load_module`` gives the function a real module with a loader,
    which ``NumbaPyTensorCacheLocator`` needs in order to cache the
    compiled result on disk.

    Parameters
    ----------
    key
        Unique module name; should encode a hash of the generated graph
        so different sources never collide.
    src
        Python source that defines ``function_name`` at module level.
    function_name
        Name of the function to extract from the executed module.
    global_env, local_env
        Extra names injected into the module namespace before execution.
    """
    if global_env is None:
        global_env = {}

    if local_env is None:
        local_env = {}

    module = load_module(key, src, global_env, local_env)
    return getattr(module, function_name)

0 commit comments

Comments
 (0)