Skip to content

Commit 6a0c00a

Browse files
committed
Type two functions in graph/rewriting/utils.py
1 parent 56c30e0 commit 6a0c00a

File tree

2 files changed

+30
-9
lines changed

2 files changed

+30
-9
lines changed

pytensor/graph/basic.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -986,7 +986,7 @@ def explicit_graph_inputs(
986986

987987

988988
def vars_between(
989-
ins: Collection[Variable], outs: Iterable[Variable]
989+
ins: Iterable[Variable], outs: Iterable[Variable]
990990
) -> Generator[Variable, None, None]:
991991
r"""Extract the `Variable`\s within the sub-graph between input and output nodes.
992992
@@ -1006,6 +1006,8 @@ def vars_between(
10061006
10071007
"""
10081008

1009+
ins = set(ins)
1010+
10091011
def expand(r: Variable) -> Iterable[Variable] | None:
10101012
if r.owner and r not in ins:
10111013
return reversed(r.owner.inputs + r.owner.outputs)

pytensor/graph/rewriting/utils.py

Lines changed: 27 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,16 @@ def rewrite_graph(
7272
return fgraph.outputs[0]
7373

7474

75-
def is_same_graph_with_merge(var1, var2, givens=None):
75+
def is_same_graph_with_merge(
76+
var1: Variable,
77+
var2: Variable,
78+
givens: (
79+
list[tuple[Variable, Variable]]
80+
| tuple[tuple[Variable, Variable], ...]
81+
| dict[Variable, Variable]
82+
| None
83+
) = None,
84+
) -> bool:
7685
"""
7786
Merge-based implementation of `pytensor.graph.basic.is_same_graph`.
7887
@@ -83,8 +92,10 @@ def is_same_graph_with_merge(var1, var2, givens=None):
8392

8493
if givens is None:
8594
givens = {}
95+
givens = dict(givens)
96+
8697
# Copy variables since the MergeOptimizer will modify them.
87-
copied = copy.deepcopy([var1, var2, givens])
98+
copied = copy.deepcopy((var1, var2, givens))
8899
vars = copied[0:2]
89100
givens = copied[2]
90101
# Create FunctionGraph.
@@ -113,7 +124,16 @@ def is_same_graph_with_merge(var1, var2, givens=None):
113124
return o1 is o2
114125

115126

116-
def is_same_graph(var1, var2, givens=None):
127+
def is_same_graph(
128+
var1: Variable,
129+
var2: Variable,
130+
givens: (
131+
list[tuple[Variable, Variable]]
132+
| tuple[tuple[Variable, Variable], ...]
133+
| dict[Variable, Variable]
134+
| None
135+
) = None,
136+
) -> bool:
117137
"""
118138
Return True iff Variables `var1` and `var2` perform the same computation.
119139
@@ -155,9 +175,7 @@ def is_same_graph(var1, var2, givens=None):
155175

156176
if givens is None:
157177
givens = {}
158-
159-
if not isinstance(givens, dict):
160-
givens = dict(givens)
178+
givens = dict(givens)
161179

162180
# Get result from the merge-based function.
163181
rval1 = is_same_graph_with_merge(var1=var1, var2=var2, givens=givens)
@@ -174,10 +192,11 @@ def is_same_graph(var1, var2, givens=None):
174192
in_xs = []
175193
in_ys = []
176194
# Compute the sets of all variables found in each computational graph.
177-
inputs_var = list(map(graph_inputs, ([var1], [var2])))
195+
inputs_var1 = graph_inputs([var1])
196+
inputs_var2 = graph_inputs([var2])
178197
all_vars = [
179198
set(vars_between(v_i, v_o))
180-
for v_i, v_o in ((inputs_var[0], [var1]), (inputs_var[1], [var2]))
199+
for v_i, v_o in ((inputs_var1, [var1]), (inputs_var2, [var2]))
181200
]
182201

183202
def in_var(x, k):

0 commit comments

Comments
 (0)