Skip to content

Commit f203dd7

Browse files
committed
Optimize Scan inner graph when compiling to Numba
1 parent 0451e00 commit f203dd7

File tree

2 files changed

+32
-0
lines changed

2 files changed

+32
-0
lines changed

pytensor/link/numba/dispatch/scan.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,13 @@ def range_arr(x):
4747

4848
@numba_funcify.register(Scan)
4949
def numba_funcify_Scan(op, node, **kwargs):
50+
# Apply inner rewrites
51+
# TODO: Not sure this is the right place to do this, should we have a rewrite that
52+
# explicitly triggers the optimization of the inner graphs of Scan?
53+
# The C-code deffers it to the make_thunk phase
54+
rewriter = op.mode_instance.optimizer
55+
rewriter(op.fgraph)
56+
5057
scan_inner_func = numba_basic.numba_njit(numba_funcify(op.fgraph))
5158

5259
outer_in_names_to_vars = {

tests/link/numba/test_scan.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,12 @@
55
from pytensor import config, function, grad
66
from pytensor.compile.mode import Mode, get_mode
77
from pytensor.graph.fg import FunctionGraph
8+
from pytensor.scalar import Log1p
89
from pytensor.scan.basic import scan
910
from pytensor.scan.op import Scan
1011
from pytensor.scan.utils import until
12+
from pytensor.tensor import log, vector
13+
from pytensor.tensor.elemwise import Elemwise
1114
from pytensor.tensor.random.utils import RandomStream
1215
from tests import unittest_tools as utt
1316
from tests.link.numba.test_basic import compare_numba_and_py
@@ -417,3 +420,25 @@ def inner_fct(seq, state_old, state_current):
417420
compare_numba_and_py(
418421
out_fg, test_input_vals, numba_mode=numba_mode, py_mode=py_mode
419422
)
423+
424+
425+
def test_inner_graph_optimized():
426+
"""Test that inner graph of Scan is optimized"""
427+
xs = vector("xs")
428+
seq, _ = scan(
429+
fn=lambda x: log(1 + x),
430+
sequences=[xs],
431+
mode=get_mode("NUMBA"),
432+
)
433+
434+
# Disable scan pushout, in which case the whole scan is replaced by an Elemwise
435+
f = function([xs], seq, mode=get_mode("NUMBA").excluding("scan_pushout"))
436+
(scan_node,) = [
437+
node for node in f.maker.fgraph.apply_nodes if isinstance(node.op, Scan)
438+
]
439+
inner_scan_nodes = scan_node.op.fgraph.apply_nodes
440+
assert len(inner_scan_nodes) == 1
441+
(inner_scan_node,) = scan_node.op.fgraph.apply_nodes
442+
assert isinstance(inner_scan_node.op, Elemwise) and isinstance(
443+
inner_scan_node.op.scalar_op, Log1p
444+
)

0 commit comments

Comments
 (0)