|
5 | 5 | from pytensor import config, function, grad
|
6 | 6 | from pytensor.compile.mode import Mode, get_mode
|
7 | 7 | from pytensor.graph.fg import FunctionGraph
|
| 8 | +from pytensor.scalar import Log1p |
8 | 9 | from pytensor.scan.basic import scan
|
9 | 10 | from pytensor.scan.op import Scan
|
10 | 11 | from pytensor.scan.utils import until
|
| 12 | +from pytensor.tensor import log, vector |
| 13 | +from pytensor.tensor.elemwise import Elemwise |
11 | 14 | from pytensor.tensor.random.utils import RandomStream
|
12 | 15 | from tests import unittest_tools as utt
|
13 | 16 | from tests.link.numba.test_basic import compare_numba_and_py
|
@@ -417,3 +420,25 @@ def inner_fct(seq, state_old, state_current):
|
417 | 420 | compare_numba_and_py(
|
418 | 421 | out_fg, test_input_vals, numba_mode=numba_mode, py_mode=py_mode
|
419 | 422 | )
|
| 423 | + |
| 424 | + |
| 425 | +def test_inner_graph_optimized(): |
| 426 | + """Test that inner graph of Scan is optimized""" |
| 427 | + xs = vector("xs") |
| 428 | + seq, _ = scan( |
| 429 | + fn=lambda x: log(1 + x), |
| 430 | + sequences=[xs], |
| 431 | + mode=get_mode("NUMBA"), |
| 432 | + ) |
| 433 | + |
| 434 | + # Disable scan pushout, in which case the whole scan is replaced by an Elemwise |
| 435 | + f = function([xs], seq, mode=get_mode("NUMBA").excluding("scan_pushout")) |
| 436 | + (scan_node,) = [ |
| 437 | + node for node in f.maker.fgraph.apply_nodes if isinstance(node.op, Scan) |
| 438 | + ] |
| 439 | + inner_scan_nodes = scan_node.op.fgraph.apply_nodes |
| 440 | + assert len(inner_scan_nodes) == 1 |
| 441 | + (inner_scan_node,) = scan_node.op.fgraph.apply_nodes |
| 442 | + assert isinstance(inner_scan_node.op, Elemwise) and isinstance( |
| 443 | + inner_scan_node.op.scalar_op, Log1p |
| 444 | + ) |
0 commit comments