Skip to content

Commit 553673f

Browse files
committed
Cache the whole fgraph
1 parent 903f26c commit 553673f

File tree

4 files changed

+5
-12
lines changed

4 files changed

+5
-12
lines changed

pytensor/link/numba/dispatch/basic.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -275,7 +275,7 @@ def type_conversion_and_key_collection(value, variable, **kwargs):
275275
fgraph_key = sha256(
276276
str((tuple(cache_keys), len(fgraph.inputs), len(fgraph.outputs))).encode()
277277
).hexdigest()
278-
return py_func, fgraph_key
278+
return numba_njit(py_func), fgraph_key
279279

280280

281281
@numba_funcify.register(OpFromGraph)

pytensor/link/numba/dispatch/scalar.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -238,7 +238,6 @@ def numba_funcify_Composite(op, node, **kwargs):
238238
composite_fn, fgraph_key = numba_funcify_and_cache_key(
239239
op.fgraph, squeeze_output=True, **kwargs
240240
)
241-
composite_fn = numba_njit(composite_fn)
242241
if fgraph_key is None:
243242
composite_key = None
244243
else:

pytensor/link/numba/dispatch/scan.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,7 @@ def numba_funcify_Scan(op: Scan, node, **kwargs):
100100
)
101101
rewriter(fgraph)
102102

103-
scan_inner_func = numba_njit(numba_funcify(op.fgraph))
103+
scan_inner_func = numba_funcify(op.fgraph)
104104

105105
outer_in_names_to_vars = {
106106
(f"outer_in_{i}" if i > 0 else "n_steps"): v for i, v in enumerate(node.inputs)

pytensor/link/numba/linker.py

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -10,18 +10,12 @@ def __init__(self, *args, vm: bool = False, **kwargs):
1010

1111
def fgraph_convert(self, fgraph, **kwargs):
1212
import pytensor.link.numba.dispatch # noqa
13-
from pytensor.link.numba.compile import numba_funcify
13+
from pytensor.link.numba.cache import numba_njit_and_cache
1414

15-
return numba_funcify(fgraph, **kwargs)
15+
return numba_njit_and_cache(fgraph, **kwargs)[0]
1616

1717
def jit_compile(self, fn):
18-
if self.vm:
19-
return fn
20-
else:
21-
from pytensor.link.numba.cache import numba_njit
22-
23-
jitted_fn = numba_njit(fn, final_function=True)
24-
return jitted_fn
18+
return fn
2519

2620
def create_thunk_inputs(self, storage_map):
2721
return [storage_map[n] for n in self.fgraph.inputs]

0 commit comments

Comments
 (0)