Skip to content

Commit 801845c

Browse files
committed
Skip scan rewrites if there is no Scan Op in the graph
1 parent 051b32d commit 801845c

File tree

2 files changed

+21
-7
lines changed

2 files changed

+21
-7
lines changed

pytensor/graph/rewriting/db.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -310,25 +310,29 @@ class EquilibriumDB(RewriteDatabase):
310310
"""
311311

312312
def __init__(
313-
self, ignore_newtrees: bool = True, tracks_on_change_inputs: bool = False
313+
self,
314+
ignore_newtrees: bool = True,
315+
tracks_on_change_inputs: bool = False,
316+
eq_rewriter_class=pytensor_rewriting.EquilibriumGraphRewriter,
314317
):
315318
"""
316319
317320
Parameters
318321
----------
319322
ignore_newtrees
320-
If ``False``, apply rewrites to new nodes introduced during
321-
rewriting.
322-
323+
If ``False``, apply rewrites to new nodes introduced during rewritings.
323324
tracks_on_change_inputs
324325
If ``True``, re-apply rewrites on nodes with changed inputs.
326+
eq_rewriter_class: EquilibriumGraphRewriter class, optional
327+
The class used to create the equilibrium rewriter. Defaults to EquilibriumGraphRewriter.
325328
326329
"""
327330
super().__init__()
328331
self.ignore_newtrees = ignore_newtrees
329332
self.tracks_on_change_inputs = tracks_on_change_inputs
330333
self.__final__: dict[str, bool] = {}
331334
self.__cleanup__: dict[str, bool] = {}
335+
self.eq_rewriter_class = eq_rewriter_class
332336

333337
def register(
334338
self,
@@ -360,7 +364,7 @@ def query(self, *tags, **kwtags):
360364
final_rewriters = None
361365
if len(cleanup_rewriters) == 0:
362366
cleanup_rewriters = None
363-
return pytensor_rewriting.EquilibriumGraphRewriter(
367+
return self.eq_rewriter_class(
364368
rewriters,
365369
max_use_ratio=config.optdb__max_use_ratio,
366370
ignore_newtrees=self.ignore_newtrees,

pytensor/scan/rewriting.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
from pytensor.graph.op import compute_test_value
3131
from pytensor.graph.replace import clone_replace
3232
from pytensor.graph.rewriting.basic import (
33+
EquilibriumGraphRewriter,
3334
GraphRewriter,
3435
copy_stack_trace,
3536
in2out,
@@ -2517,12 +2518,21 @@ def scan_push_out_dot1(fgraph, node):
25172518
return False
25182519

25192520

2521+
class ScanEquilibriumGraphRewriter(EquilibriumGraphRewriter):
2522+
"""Subclass of EquilibriumGraphRewriter that aborts early if there are no Scan Ops in the graph"""
2523+
2524+
def apply(self, fgraph, start_from=None):
2525+
if not any(isinstance(node.op, Scan) for node in fgraph.apply_nodes):
2526+
return
2527+
super().apply(fgraph=fgraph, start_from=start_from)
2528+
2529+
25202530
# I've added an equilibrium because later scan optimization in the sequence
25212531
# can make it such that earlier optimizations should apply. However, in
25222532
# general I do not expect the sequence to run more then once
2523-
scan_eqopt1 = EquilibriumDB()
2533+
scan_eqopt1 = EquilibriumDB(eq_rewriter_class=ScanEquilibriumGraphRewriter)
25242534
scan_seqopt1 = SequenceDB()
2525-
scan_eqopt2 = EquilibriumDB()
2535+
scan_eqopt2 = EquilibriumDB(eq_rewriter_class=ScanEquilibriumGraphRewriter)
25262536

25272537
# scan_eqopt1 before ShapeOpt at 0.1
25282538
# This is needed to don't have ShapeFeature trac old Scan that we

0 commit comments

Comments
 (0)