@@ -72,7 +72,16 @@ def rewrite_graph(
72
72
return fgraph .outputs [0 ]
73
73
74
74
75
- def is_same_graph_with_merge (var1 , var2 , givens = None ):
75
+ def is_same_graph_with_merge (
76
+ var1 : Variable ,
77
+ var2 : Variable ,
78
+ givens : (
79
+ list [tuple [Variable , Variable ]]
80
+ | tuple [tuple [Variable , Variable ], ...]
81
+ | dict [Variable , Variable ]
82
+ | None
83
+ ) = None ,
84
+ ) -> bool :
76
85
"""
77
86
Merge-based implementation of `pytensor.graph.basic.is_same_graph`.
78
87
@@ -83,8 +92,10 @@ def is_same_graph_with_merge(var1, var2, givens=None):
83
92
84
93
if givens is None :
85
94
givens = {}
95
+ givens = dict (givens )
96
+
86
97
# Copy variables since the MergeOptimizer will modify them.
87
- copied = copy .deepcopy ([ var1 , var2 , givens ] )
98
+ copied = copy .deepcopy (( var1 , var2 , givens ) )
88
99
vars = copied [0 :2 ]
89
100
givens = copied [2 ]
90
101
# Create FunctionGraph.
@@ -113,7 +124,16 @@ def is_same_graph_with_merge(var1, var2, givens=None):
113
124
return o1 is o2
114
125
115
126
116
- def is_same_graph (var1 , var2 , givens = None ):
127
+ def is_same_graph (
128
+ var1 : Variable ,
129
+ var2 : Variable ,
130
+ givens : (
131
+ list [tuple [Variable , Variable ]]
132
+ | tuple [tuple [Variable , Variable ], ...]
133
+ | dict [Variable , Variable ]
134
+ | None
135
+ ) = None ,
136
+ ) -> bool :
117
137
"""
118
138
Return True iff Variables `var1` and `var2` perform the same computation.
119
139
@@ -155,9 +175,7 @@ def is_same_graph(var1, var2, givens=None):
155
175
156
176
if givens is None :
157
177
givens = {}
158
-
159
- if not isinstance (givens , dict ):
160
- givens = dict (givens )
178
+ givens = dict (givens )
161
179
162
180
# Get result from the merge-based function.
163
181
rval1 = is_same_graph_with_merge (var1 = var1 , var2 = var2 , givens = givens )
@@ -174,10 +192,11 @@ def is_same_graph(var1, var2, givens=None):
174
192
in_xs = []
175
193
in_ys = []
176
194
# Compute the sets of all variables found in each computational graph.
177
- inputs_var = list (map (graph_inputs , ([var1 ], [var2 ])))
195
+ inputs_var1 = graph_inputs ([var1 ])
196
+ inputs_var2 = graph_inputs ([var2 ])
178
197
all_vars = [
179
198
set (vars_between (v_i , v_o ))
180
- for v_i , v_o in ((inputs_var [ 0 ] , [var1 ]), (inputs_var [ 1 ] , [var2 ]))
199
+ for v_i , v_o in ((inputs_var1 , [var1 ]), (inputs_var2 , [var2 ]))
181
200
]
182
201
183
202
def in_var (x , k ):
0 commit comments