@@ -44,32 +44,23 @@ def rewrite_graph(
44
44
"""
45
45
from pytensor .compile import optdb
46
46
47
- return_fgraph = False
48
47
if isinstance (graph , FunctionGraph ):
49
48
fgraph = graph
50
- return_fgraph = True
51
49
else :
52
- if isinstance (graph , list | tuple ):
53
- outputs = graph
54
- else :
55
- assert isinstance (graph , Variable )
56
- outputs = [graph ]
57
-
50
+ outputs = [graph ] if isinstance (graph , Variable ) else graph
58
51
fgraph = FunctionGraph (outputs = outputs , clone = clone )
59
52
60
53
query_rewrites = optdb .query (RewriteDatabaseQuery (include = include , ** kwargs ))
61
- _ = query_rewrites .rewrite (fgraph )
54
+ query_rewrites .rewrite (fgraph )
62
55
63
- if custom_rewrite :
56
+ if custom_rewrite is not None :
64
57
custom_rewrite .rewrite (fgraph )
65
58
66
- if return_fgraph :
59
+ if isinstance ( graph , FunctionGraph ) :
67
60
return fgraph
68
- else :
69
- if isinstance (graph , list | tuple ):
70
- return fgraph .outputs
71
- else :
72
- return fgraph .outputs [0 ]
61
+ if isinstance (graph , Variable ):
62
+ return fgraph .outputs [0 ]
63
+ return fgraph .outputs
73
64
74
65
75
66
def is_same_graph_with_merge (
@@ -90,14 +81,10 @@ def is_same_graph_with_merge(
90
81
"""
91
82
from pytensor .graph .rewriting .basic import MergeOptimizer
92
83
93
- if givens is None :
94
- givens = {}
95
- givens = dict (givens )
84
+ givens = {} if givens is None else dict (givens )
96
85
97
86
# Copy variables since the MergeOptimizer will modify them.
98
- copied = copy .deepcopy ((var1 , var2 , givens ))
99
- vars = copied [0 :2 ]
100
- givens = copied [2 ]
87
+ * vars , givens = copy .deepcopy ((var1 , var2 , givens ))
101
88
# Create FunctionGraph.
102
89
inputs = list (graph_inputs (vars ))
103
90
# The clone isn't needed as we did a deepcopy and we cloning will
@@ -120,8 +107,7 @@ def is_same_graph_with_merge(
120
107
# Comparing two single-Variable graphs: they are equal if they are
121
108
# the same Variable.
122
109
return vars_replaced [0 ] == vars_replaced [1 ]
123
- else :
124
- return o1 is o2
110
+ return o1 is o2
125
111
126
112
127
113
def is_same_graph (
@@ -171,71 +157,58 @@ def is_same_graph(
171
157
====== ====== ====== ======
172
158
173
159
"""
174
- use_equal_computations = True
175
-
176
- if givens is None :
177
- givens = {}
178
- givens = dict (givens )
160
+ givens = {} if givens is None else dict (givens )
179
161
180
162
# Get result from the merge-based function.
181
163
rval1 = is_same_graph_with_merge (var1 = var1 , var2 = var2 , givens = givens )
182
164
183
- if givens :
184
- # We need to build the `in_xs` and `in_ys` lists. To do this, we need
185
- # to be able to tell whether a variable belongs to the computational
186
- # graph of `var1` or `var2`.
187
- # The typical case we want to handle is when `to_replace` belongs to
188
- # one of these graphs, and `replace_by` belongs to the other one. In
189
- # other situations, the current implementation of `equal_computations`
190
- # is probably not appropriate, so we do not call it.
191
- ok = True
192
- in_xs = []
193
- in_ys = []
194
- # Compute the sets of all variables found in each computational graph.
195
- inputs_var1 = graph_inputs ([var1 ])
196
- inputs_var2 = graph_inputs ([var2 ])
197
- all_vars = [
198
- set (vars_between (v_i , v_o ))
199
- for v_i , v_o in ((inputs_var1 , [var1 ]), (inputs_var2 , [var2 ]))
200
- ]
201
-
202
- def in_var (x , k ):
203
- # Return True iff `x` is in computation graph of variable `vark`.
204
- return x in all_vars [k - 1 ]
165
+ if not givens :
166
+ rval2 = equal_computations (xs = [var1 ], ys = [var2 ])
167
+ assert rval1 == rval2
168
+ return rval1
169
+
170
+ # We need to build the `in_xs` and `in_ys` lists. To do this, we need
171
+ # to be able to tell whether a variable belongs to the computational
172
+ # graph of `var1` or `var2`.
173
+ # The typical case we want to handle is when `to_replace` belongs to
174
+ # one of these graphs, and `replace_by` belongs to the other one. In
175
+ # other situations, the current implementation of `equal_computations`
176
+ # is probably not appropriate, so we do not call it.
177
+ use_equal_computations = True
178
+ in_xs = []
179
+ in_ys = []
180
+ # Compute the sets of all variables found in each computational graph.
181
+ inputs_var1 = graph_inputs ([var1 ])
182
+ inputs_var2 = graph_inputs ([var2 ])
183
+ all_vars1 = set (vars_between (inputs_var1 , [var1 ]))
184
+ all_vars2 = set (vars_between (inputs_var2 , [var2 ]))
205
185
206
- for to_replace , replace_by in givens .items ():
207
- # Map a substitution variable to the computational graphs it
208
- # belongs to.
209
- inside = {
210
- v : [in_var (v , k ) for k in (1 , 2 )] for v in (to_replace , replace_by )
211
- }
212
- if (
213
- inside [to_replace ][0 ]
214
- and not inside [to_replace ][1 ]
215
- and inside [replace_by ][1 ]
216
- and not inside [replace_by ][0 ]
217
- ):
218
- # Substitute variable in `var1` by one from `var2`.
219
- in_xs .append (to_replace )
220
- in_ys .append (replace_by )
221
- elif (
222
- inside [to_replace ][1 ]
223
- and not inside [to_replace ][0 ]
224
- and inside [replace_by ][0 ]
225
- and not inside [replace_by ][1 ]
226
- ):
227
- # Substitute variable in `var2` by one from `var1`.
228
- in_xs .append (replace_by )
229
- in_ys .append (to_replace )
230
- else :
231
- ok = False
232
- break
233
- if not ok :
234
- # We cannot directly use `equal_computations`.
186
+ for to_replace , replace_by in givens .items ():
187
+ # Map a substitution variable to the computational graphs it
188
+ # belongs to.
189
+ inside = {v : [v in all_vars1 , v in all_vars2 ] for v in (to_replace , replace_by )}
190
+ if (
191
+ inside [to_replace ][0 ]
192
+ and not inside [to_replace ][1 ]
193
+ and inside [replace_by ][1 ]
194
+ and not inside [replace_by ][0 ]
195
+ ):
196
+ # Substitute variable in `var1` by one from `var2`.
197
+ in_xs .append (to_replace )
198
+ in_ys .append (replace_by )
199
+ elif (
200
+ inside [to_replace ][1 ]
201
+ and not inside [to_replace ][0 ]
202
+ and inside [replace_by ][0 ]
203
+ and not inside [replace_by ][1 ]
204
+ ):
205
+ # Substitute variable in `var2` by one from `var1`.
206
+ in_xs .append (replace_by )
207
+ in_ys .append (to_replace )
208
+ else :
235
209
use_equal_computations = False
236
- else :
237
- in_xs = None
238
- in_ys = None
210
+ break
211
+
239
212
if use_equal_computations :
240
213
rval2 = equal_computations (xs = [var1 ], ys = [var2 ], in_xs = in_xs , in_ys = in_ys )
241
214
assert rval2 == rval1
0 commit comments