Skip to content

Commit 1d5b1d9

Browse files
committed
Replace uses of in2out and out2in by a depth-first search rewriter
1 parent aa7e4d6 commit 1d5b1d9

File tree

17 files changed

+86
-101
lines changed

17 files changed

+86
-101
lines changed

pytensor/graph/rewriting/basic.py

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,12 @@
2727
from pytensor.graph.fg import FunctionGraph, Output
2828
from pytensor.graph.op import Op
2929
from pytensor.graph.rewriting.unify import OpPattern, Var, convert_strs_to_vars
30-
from pytensor.graph.traversal import applys_between, toposort, vars_between
30+
from pytensor.graph.traversal import (
31+
apply_ancestors,
32+
applys_between,
33+
toposort,
34+
vars_between,
35+
)
3136
from pytensor.graph.utils import AssocList, InconsistencyError
3237
from pytensor.misc.ordered_set import OrderedSet
3338
from pytensor.utils import flatten
@@ -1995,12 +2000,13 @@ class WalkingGraphRewriter(NodeProcessingGraphRewriter):
19952000
def __init__(
19962001
self,
19972002
node_rewriter: NodeRewriter,
1998-
order: Literal["out_to_in", "in_to_out"] = "in_to_out",
2003+
order: Literal["out_to_in", "in_to_out", "dfs"] = "in_to_out",
19992004
ignore_newtrees: bool = False,
20002005
failure_callback: FailureCallbackType | None = None,
20012006
):
2002-
if order not in ("out_to_in", "in_to_out"):
2003-
raise ValueError("order must be 'out_to_in' or 'in_to_out'")
2007+
valid_orders = ("out_to_in", "in_to_out", "dfs")
2008+
if order not in valid_orders:
2009+
raise ValueError(f"order must be one of {valid_orders}, got {order}")
20042010
self.order = order
20052011
super().__init__(node_rewriter, ignore_newtrees, failure_callback)
20062012

@@ -2010,7 +2016,11 @@ def apply(self, fgraph, start_from=None):
20102016
callback_before = fgraph.execute_callbacks_time
20112017
nb_nodes_start = len(fgraph.apply_nodes)
20122018
t0 = time.perf_counter()
2013-
q = deque(toposort(start_from))
2019+
q = deque(
2020+
apply_ancestors(start_from)
2021+
if (self.order == "dfs")
2022+
else toposort(start_from)
2023+
)
20142024
io_t = time.perf_counter() - t0
20152025

20162026
def importer(node):
@@ -2134,6 +2144,7 @@ def walking_rewriter(
21342144

21352145
in2out = partial(walking_rewriter, "in_to_out")
21362146
out2in = partial(walking_rewriter, "out_to_in")
2147+
dfs_rewriter = partial(walking_rewriter, "dfs")
21372148

21382149

21392150
class ChangeTracker(Feature):

pytensor/scan/rewriting.py

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929
EquilibriumGraphRewriter,
3030
GraphRewriter,
3131
copy_stack_trace,
32-
in2out,
32+
dfs_rewriter,
3333
node_rewriter,
3434
)
3535
from pytensor.graph.rewriting.db import EquilibriumDB, SequenceDB
@@ -2558,15 +2558,15 @@ def apply(self, fgraph, start_from=None):
25582558
# ScanSaveMem should execute only once per node.
25592559
optdb.register(
25602560
"scan_save_mem_prealloc",
2561-
in2out(scan_save_mem_prealloc, ignore_newtrees=True),
2561+
dfs_rewriter(scan_save_mem_prealloc, ignore_newtrees=True),
25622562
"fast_run",
25632563
"scan",
25642564
"scan_save_mem",
25652565
position=1.61,
25662566
)
25672567
optdb.register(
25682568
"scan_save_mem_no_prealloc",
2569-
in2out(scan_save_mem_no_prealloc, ignore_newtrees=True),
2569+
dfs_rewriter(scan_save_mem_no_prealloc, ignore_newtrees=True),
25702570
"numba",
25712571
"jax",
25722572
"pytorch",
@@ -2587,7 +2587,7 @@ def apply(self, fgraph, start_from=None):
25872587

25882588
scan_seqopt1.register(
25892589
"scan_remove_constants_and_unused_inputs0",
2590-
in2out(remove_constants_and_unused_inputs_scan, ignore_newtrees=True),
2590+
dfs_rewriter(remove_constants_and_unused_inputs_scan, ignore_newtrees=True),
25912591
"remove_constants_and_unused_inputs_scan",
25922592
"fast_run",
25932593
"scan",
@@ -2596,7 +2596,7 @@ def apply(self, fgraph, start_from=None):
25962596

25972597
scan_seqopt1.register(
25982598
"scan_push_out_non_seq",
2599-
in2out(scan_push_out_non_seq, ignore_newtrees=True),
2599+
dfs_rewriter(scan_push_out_non_seq, ignore_newtrees=True),
26002600
"scan_pushout_nonseqs_ops", # For backcompat: so it can be tagged with old name
26012601
"fast_run",
26022602
"scan",
@@ -2606,7 +2606,7 @@ def apply(self, fgraph, start_from=None):
26062606

26072607
scan_seqopt1.register(
26082608
"scan_push_out_seq",
2609-
in2out(scan_push_out_seq, ignore_newtrees=True),
2609+
dfs_rewriter(scan_push_out_seq, ignore_newtrees=True),
26102610
"scan_pushout_seqs_ops", # For backcompat: so it can be tagged with old name
26112611
"fast_run",
26122612
"scan",
@@ -2617,7 +2617,7 @@ def apply(self, fgraph, start_from=None):
26172617

26182618
scan_seqopt1.register(
26192619
"scan_push_out_dot1",
2620-
in2out(scan_push_out_dot1, ignore_newtrees=True),
2620+
dfs_rewriter(scan_push_out_dot1, ignore_newtrees=True),
26212621
"scan_pushout_dot1", # For backcompat: so it can be tagged with old name
26222622
"fast_run",
26232623
"more_mem",
@@ -2630,7 +2630,7 @@ def apply(self, fgraph, start_from=None):
26302630
scan_seqopt1.register(
26312631
"scan_push_out_add",
26322632
# TODO: Perhaps this should be an `EquilibriumGraphRewriter`?
2633-
in2out(scan_push_out_add, ignore_newtrees=False),
2633+
dfs_rewriter(scan_push_out_add, ignore_newtrees=False),
26342634
"scan_pushout_add", # For backcompat: so it can be tagged with old name
26352635
"fast_run",
26362636
"more_mem",
@@ -2641,22 +2641,22 @@ def apply(self, fgraph, start_from=None):
26412641

26422642
scan_eqopt2.register(
26432643
"while_scan_merge_subtensor_last_element",
2644-
in2out(while_scan_merge_subtensor_last_element, ignore_newtrees=True),
2644+
dfs_rewriter(while_scan_merge_subtensor_last_element, ignore_newtrees=True),
26452645
"fast_run",
26462646
"scan",
26472647
)
26482648

26492649
scan_eqopt2.register(
26502650
"constant_folding_for_scan2",
2651-
in2out(constant_folding, ignore_newtrees=True),
2651+
dfs_rewriter(constant_folding, ignore_newtrees=True),
26522652
"fast_run",
26532653
"scan",
26542654
)
26552655

26562656

26572657
scan_eqopt2.register(
26582658
"scan_remove_constants_and_unused_inputs1",
2659-
in2out(remove_constants_and_unused_inputs_scan, ignore_newtrees=True),
2659+
dfs_rewriter(remove_constants_and_unused_inputs_scan, ignore_newtrees=True),
26602660
"remove_constants_and_unused_inputs_scan",
26612661
"fast_run",
26622662
"scan",
@@ -2671,23 +2671,23 @@ def apply(self, fgraph, start_from=None):
26712671
# After Merge optimization
26722672
scan_eqopt2.register(
26732673
"scan_remove_constants_and_unused_inputs2",
2674-
in2out(remove_constants_and_unused_inputs_scan, ignore_newtrees=True),
2674+
dfs_rewriter(remove_constants_and_unused_inputs_scan, ignore_newtrees=True),
26752675
"remove_constants_and_unused_inputs_scan",
26762676
"fast_run",
26772677
"scan",
26782678
)
26792679

26802680
scan_eqopt2.register(
26812681
"scan_merge_inouts",
2682-
in2out(scan_merge_inouts, ignore_newtrees=True),
2682+
dfs_rewriter(scan_merge_inouts, ignore_newtrees=True),
26832683
"fast_run",
26842684
"scan",
26852685
)
26862686

26872687
# After everything else
26882688
scan_eqopt2.register(
26892689
"scan_remove_constants_and_unused_inputs3",
2690-
in2out(remove_constants_and_unused_inputs_scan, ignore_newtrees=True),
2690+
dfs_rewriter(remove_constants_and_unused_inputs_scan, ignore_newtrees=True),
26912691
"remove_constants_and_unused_inputs_scan",
26922692
"fast_run",
26932693
"scan",

pytensor/tensor/_linalg/solve/rewriting.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
from pytensor.compile import optdb
55
from pytensor.graph import Constant, graph_inputs
6-
from pytensor.graph.rewriting.basic import copy_stack_trace, in2out, node_rewriter
6+
from pytensor.graph.rewriting.basic import copy_stack_trace, dfs_rewriter, node_rewriter
77
from pytensor.scan.op import Scan
88
from pytensor.scan.rewriting import scan_seqopt1
99
from pytensor.tensor._linalg.solve.tridiagonal import (
@@ -244,7 +244,7 @@ def scan_split_non_sequence_decomposition_and_solve(fgraph, node):
244244

245245
scan_seqopt1.register(
246246
scan_split_non_sequence_decomposition_and_solve.__name__,
247-
in2out(scan_split_non_sequence_decomposition_and_solve, ignore_newtrees=True),
247+
dfs_rewriter(scan_split_non_sequence_decomposition_and_solve, ignore_newtrees=True),
248248
"fast_run",
249249
"scan",
250250
"scan_pushout",
@@ -261,7 +261,7 @@ def reuse_decomposition_multiple_solves_jax(fgraph, node):
261261

262262
optdb["specialize"].register(
263263
reuse_decomposition_multiple_solves_jax.__name__,
264-
in2out(reuse_decomposition_multiple_solves_jax, ignore_newtrees=True),
264+
dfs_rewriter(reuse_decomposition_multiple_solves_jax, ignore_newtrees=True),
265265
"jax",
266266
use_db_name_as_tag=False,
267267
)
@@ -276,7 +276,9 @@ def scan_split_non_sequence_decomposition_and_solve_jax(fgraph, node):
276276

277277
scan_seqopt1.register(
278278
scan_split_non_sequence_decomposition_and_solve_jax.__name__,
279-
in2out(scan_split_non_sequence_decomposition_and_solve_jax, ignore_newtrees=True),
279+
dfs_rewriter(
280+
scan_split_non_sequence_decomposition_and_solve_jax, ignore_newtrees=True
281+
),
280282
"jax",
281283
use_db_name_as_tag=False,
282284
position=2,

pytensor/tensor/random/rewriting/basic.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,11 @@
44
from pytensor.configdefaults import config
55
from pytensor.graph import ancestors
66
from pytensor.graph.op import compute_test_value
7-
from pytensor.graph.rewriting.basic import copy_stack_trace, in2out, node_rewriter
7+
from pytensor.graph.rewriting.basic import (
8+
copy_stack_trace,
9+
dfs_rewriter,
10+
node_rewriter,
11+
)
812
from pytensor.tensor import NoneConst, TensorVariable
913
from pytensor.tensor.basic import constant
1014
from pytensor.tensor.elemwise import DimShuffle
@@ -57,7 +61,7 @@ def random_make_inplace(fgraph, node):
5761

5862
optdb.register(
5963
"random_make_inplace",
60-
in2out(random_make_inplace, ignore_newtrees=True),
64+
dfs_rewriter(random_make_inplace, ignore_newtrees=True),
6165
"fast_run",
6266
"inplace",
6367
position=50.9,

pytensor/tensor/random/rewriting/jax.py

Lines changed: 11 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,7 @@
22

33
from pytensor.compile import optdb
44
from pytensor.graph import Constant
5-
from pytensor.graph.rewriting.basic import in2out, node_rewriter
6-
from pytensor.graph.rewriting.db import SequenceDB
5+
from pytensor.graph.rewriting.basic import dfs_rewriter, in2out, node_rewriter
76
from pytensor.tensor import abs as abs_t
87
from pytensor.tensor import broadcast_arrays, exp, floor, log, log1p, reciprocal, sqrt
98
from pytensor.tensor.basic import (
@@ -179,51 +178,16 @@ def materialize_implicit_arange_choice_without_replacement(fgraph, node):
179178
return new_op.make_node(rng, size, a_vector_param, *other_params).outputs
180179

181180

182-
random_vars_opt = SequenceDB()
183-
random_vars_opt.register(
184-
"lognormal_from_normal",
185-
in2out(lognormal_from_normal),
186-
"jax",
187-
)
188-
random_vars_opt.register(
189-
"halfnormal_from_normal",
190-
in2out(halfnormal_from_normal),
191-
"jax",
192-
)
193-
random_vars_opt.register(
194-
"geometric_from_uniform",
195-
in2out(geometric_from_uniform),
196-
"jax",
197-
)
198-
random_vars_opt.register(
199-
"negative_binomial_from_gamma_poisson",
200-
in2out(negative_binomial_from_gamma_poisson),
201-
"jax",
202-
)
203-
random_vars_opt.register(
204-
"inverse_gamma_from_gamma",
205-
in2out(inverse_gamma_from_gamma),
206-
"jax",
207-
)
208-
random_vars_opt.register(
209-
"generalized_gamma_from_gamma",
210-
in2out(generalized_gamma_from_gamma),
211-
"jax",
212-
)
213-
random_vars_opt.register(
214-
"wald_from_normal_uniform",
215-
in2out(wald_from_normal_uniform),
216-
"jax",
217-
)
218-
random_vars_opt.register(
219-
"beta_binomial_from_beta_binomial",
220-
in2out(beta_binomial_from_beta_binomial),
221-
"jax",
222-
)
223-
random_vars_opt.register(
224-
"materialize_implicit_arange_choice_without_replacement",
225-
in2out(materialize_implicit_arange_choice_without_replacement),
226-
"jax",
181+
random_vars_opt = dfs_rewriter(
182+
lognormal_from_normal,
183+
halfnormal_from_normal,
184+
geometric_from_uniform,
185+
negative_binomial_from_gamma_poisson,
186+
inverse_gamma_from_gamma,
187+
generalized_gamma_from_gamma,
188+
wald_from_normal_uniform,
189+
beta_binomial_from_beta_binomial,
190+
materialize_implicit_arange_choice_without_replacement,
227191
)
228192
optdb.register("jax_random_vars_rewrites", random_vars_opt, "jax", position=110)
229193

pytensor/tensor/random/rewriting/numba.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from pytensor.compile import optdb
22
from pytensor.graph import node_rewriter
3-
from pytensor.graph.rewriting.basic import out2in
3+
from pytensor.graph.rewriting.basic import dfs_rewriter
44
from pytensor.tensor import as_tensor, constant
55
from pytensor.tensor.random.op import RandomVariable, RandomVariableWithCoreShape
66
from pytensor.tensor.rewriting.shape import ShapeFeature
@@ -82,7 +82,7 @@ def introduce_explicit_core_shape_rv(fgraph, node):
8282

8383
optdb.register(
8484
introduce_explicit_core_shape_rv.__name__,
85-
out2in(introduce_explicit_core_shape_rv),
85+
dfs_rewriter(introduce_explicit_core_shape_rv),
8686
"numba",
8787
position=100,
8888
)

pytensor/tensor/rewriting/basic.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
NodeRewriter,
3636
Rewriter,
3737
copy_stack_trace,
38+
dfs_rewriter,
3839
in2out,
3940
node_rewriter,
4041
)
@@ -538,7 +539,7 @@ def local_alloc_empty_to_zeros(fgraph, node):
538539

539540
compile.optdb.register(
540541
"local_alloc_empty_to_zeros",
541-
in2out(local_alloc_empty_to_zeros),
542+
dfs_rewriter(local_alloc_empty_to_zeros),
542543
# After move to gpu and merge2, before inplace.
543544
"alloc_empty_to_zeros",
544545
position=49.3,

pytensor/tensor/rewriting/blas.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@
7777
EquilibriumGraphRewriter,
7878
GraphRewriter,
7979
copy_stack_trace,
80-
in2out,
80+
dfs_rewriter,
8181
node_rewriter,
8282
)
8383
from pytensor.graph.rewriting.db import SequenceDB
@@ -721,7 +721,7 @@ def local_dot22_to_ger_or_gemv(fgraph, node):
721721
# fast_compile is needed to have GpuDot22 created.
722722
blas_optdb.register(
723723
"local_dot_to_dot22",
724-
in2out(local_dot_to_dot22),
724+
dfs_rewriter(local_dot_to_dot22),
725725
"fast_run",
726726
"fast_compile",
727727
position=0,
@@ -744,7 +744,7 @@ def local_dot22_to_ger_or_gemv(fgraph, node):
744744
)
745745

746746

747-
blas_opt_inplace = in2out(
747+
blas_opt_inplace = dfs_rewriter(
748748
local_inplace_gemm, local_inplace_gemv, local_inplace_ger, name="blas_opt_inplace"
749749
)
750750
optdb.register(
@@ -883,7 +883,7 @@ def local_dot22_to_dot22scalar(fgraph, node):
883883
# dot22scalar and gemm give more speed up then dot22scalar
884884
blas_optdb.register(
885885
"local_dot22_to_dot22scalar",
886-
in2out(local_dot22_to_dot22scalar),
886+
dfs_rewriter(local_dot22_to_dot22scalar),
887887
"fast_run",
888888
position=12,
889889
)

0 commit comments

Comments
 (0)