Skip to content

Commit 857cda0

Browse files
committed
Refactor graph/rewriting/utils.py
1 parent 0484e1e commit 857cda0

File tree

1 file changed

+56
-83
lines changed

1 file changed

+56
-83
lines changed

pytensor/graph/rewriting/utils.py

Lines changed: 56 additions & 83 deletions
Original file line numberDiff line numberDiff line change
@@ -44,32 +44,23 @@ def rewrite_graph(
4444
"""
4545
from pytensor.compile import optdb
4646

47-
return_fgraph = False
4847
if isinstance(graph, FunctionGraph):
4948
fgraph = graph
50-
return_fgraph = True
5149
else:
52-
if isinstance(graph, list | tuple):
53-
outputs = graph
54-
else:
55-
assert isinstance(graph, Variable)
56-
outputs = [graph]
57-
50+
outputs = [graph] if isinstance(graph, Variable) else graph
5851
fgraph = FunctionGraph(outputs=outputs, clone=clone)
5952

6053
query_rewrites = optdb.query(RewriteDatabaseQuery(include=include, **kwargs))
61-
_ = query_rewrites.rewrite(fgraph)
54+
query_rewrites.rewrite(fgraph)
6255

63-
if custom_rewrite:
56+
if custom_rewrite is not None:
6457
custom_rewrite.rewrite(fgraph)
6558

66-
if return_fgraph:
59+
if isinstance(graph, FunctionGraph):
6760
return fgraph
68-
else:
69-
if isinstance(graph, list | tuple):
70-
return fgraph.outputs
71-
else:
72-
return fgraph.outputs[0]
61+
if isinstance(graph, Variable):
62+
return fgraph.outputs[0]
63+
return fgraph.outputs
7364

7465

7566
def is_same_graph_with_merge(
@@ -90,14 +81,10 @@ def is_same_graph_with_merge(
9081
"""
9182
from pytensor.graph.rewriting.basic import MergeOptimizer
9283

93-
if givens is None:
94-
givens = {}
95-
givens = dict(givens)
84+
givens = {} if givens is None else dict(givens)
9685

9786
# Copy variables since the MergeOptimizer will modify them.
98-
copied = copy.deepcopy((var1, var2, givens))
99-
vars = copied[0:2]
100-
givens = copied[2]
87+
*vars, givens = copy.deepcopy((var1, var2, givens))
10188
# Create FunctionGraph.
10289
inputs = list(graph_inputs(vars))
10390
# The clone isn't needed as we did a deepcopy and we cloning will
@@ -120,8 +107,7 @@ def is_same_graph_with_merge(
120107
# Comparing two single-Variable graphs: they are equal if they are
121108
# the same Variable.
122109
return vars_replaced[0] == vars_replaced[1]
123-
else:
124-
return o1 is o2
110+
return o1 is o2
125111

126112

127113
def is_same_graph(
@@ -171,71 +157,58 @@ def is_same_graph(
171157
====== ====== ====== ======
172158
173159
"""
174-
use_equal_computations = True
175-
176-
if givens is None:
177-
givens = {}
178-
givens = dict(givens)
160+
givens = {} if givens is None else dict(givens)
179161

180162
# Get result from the merge-based function.
181163
rval1 = is_same_graph_with_merge(var1=var1, var2=var2, givens=givens)
182164

183-
if givens:
184-
# We need to build the `in_xs` and `in_ys` lists. To do this, we need
185-
# to be able to tell whether a variable belongs to the computational
186-
# graph of `var1` or `var2`.
187-
# The typical case we want to handle is when `to_replace` belongs to
188-
# one of these graphs, and `replace_by` belongs to the other one. In
189-
# other situations, the current implementation of `equal_computations`
190-
# is probably not appropriate, so we do not call it.
191-
ok = True
192-
in_xs = []
193-
in_ys = []
194-
# Compute the sets of all variables found in each computational graph.
195-
inputs_var1 = graph_inputs([var1])
196-
inputs_var2 = graph_inputs([var2])
197-
all_vars = [
198-
set(vars_between(v_i, v_o))
199-
for v_i, v_o in ((inputs_var1, [var1]), (inputs_var2, [var2]))
200-
]
201-
202-
def in_var(x, k):
203-
# Return True iff `x` is in computation graph of variable `vark`.
204-
return x in all_vars[k - 1]
165+
if not givens:
166+
rval2 = equal_computations(xs=[var1], ys=[var2])
167+
assert rval1 == rval2
168+
return rval1
169+
170+
# We need to build the `in_xs` and `in_ys` lists. To do this, we need
171+
# to be able to tell whether a variable belongs to the computational
172+
# graph of `var1` or `var2`.
173+
# The typical case we want to handle is when `to_replace` belongs to
174+
# one of these graphs, and `replace_by` belongs to the other one. In
175+
# other situations, the current implementation of `equal_computations`
176+
# is probably not appropriate, so we do not call it.
177+
use_equal_computations = True
178+
in_xs = []
179+
in_ys = []
180+
# Compute the sets of all variables found in each computational graph.
181+
inputs_var1 = graph_inputs([var1])
182+
inputs_var2 = graph_inputs([var2])
183+
all_vars1 = set(vars_between(inputs_var1, [var1]))
184+
all_vars2 = set(vars_between(inputs_var2, [var2]))
205185

206-
for to_replace, replace_by in givens.items():
207-
# Map a substitution variable to the computational graphs it
208-
# belongs to.
209-
inside = {
210-
v: [in_var(v, k) for k in (1, 2)] for v in (to_replace, replace_by)
211-
}
212-
if (
213-
inside[to_replace][0]
214-
and not inside[to_replace][1]
215-
and inside[replace_by][1]
216-
and not inside[replace_by][0]
217-
):
218-
# Substitute variable in `var1` by one from `var2`.
219-
in_xs.append(to_replace)
220-
in_ys.append(replace_by)
221-
elif (
222-
inside[to_replace][1]
223-
and not inside[to_replace][0]
224-
and inside[replace_by][0]
225-
and not inside[replace_by][1]
226-
):
227-
# Substitute variable in `var2` by one from `var1`.
228-
in_xs.append(replace_by)
229-
in_ys.append(to_replace)
230-
else:
231-
ok = False
232-
break
233-
if not ok:
234-
# We cannot directly use `equal_computations`.
186+
for to_replace, replace_by in givens.items():
187+
# Map a substitution variable to the computational graphs it
188+
# belongs to.
189+
inside = {v: [v in all_vars1, v in all_vars2] for v in (to_replace, replace_by)}
190+
if (
191+
inside[to_replace][0]
192+
and not inside[to_replace][1]
193+
and inside[replace_by][1]
194+
and not inside[replace_by][0]
195+
):
196+
# Substitute variable in `var1` by one from `var2`.
197+
in_xs.append(to_replace)
198+
in_ys.append(replace_by)
199+
elif (
200+
inside[to_replace][1]
201+
and not inside[to_replace][0]
202+
and inside[replace_by][0]
203+
and not inside[replace_by][1]
204+
):
205+
# Substitute variable in `var2` by one from `var1`.
206+
in_xs.append(replace_by)
207+
in_ys.append(to_replace)
208+
else:
235209
use_equal_computations = False
236-
else:
237-
in_xs = None
238-
in_ys = None
210+
break
211+
239212
if use_equal_computations:
240213
rval2 = equal_computations(xs=[var1], ys=[var2], in_xs=in_xs, in_ys=in_ys)
241214
assert rval2 == rval1

0 commit comments

Comments
 (0)