Skip to content

Commit 066307f

Browse files
committed
Faster graph traversal functions
* Avoid reversing inputs as we traverse graph * Simplify io_toposort without ordering (and refactor into its own function) * Removes client side-effect on previous toposort functions * Remove duplicated logic across methods
1 parent f1a2ac6 commit 066307f

File tree

16 files changed

+516
-438
lines changed

16 files changed

+516
-438
lines changed

pytensor/graph/basic.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
from collections.abc import (
66
Hashable,
77
Iterable,
8-
Reversible,
98
Sequence,
109
)
1110
from copy import copy
@@ -961,7 +960,7 @@ def clone_node_and_cache(
961960

962961
def clone_get_equiv(
963962
inputs: Iterable[Variable],
964-
outputs: Reversible[Variable],
963+
outputs: Iterable[Variable],
965964
copy_inputs: bool = True,
966965
copy_orphans: bool = True,
967966
memo: dict[Union[Apply, Variable, "Op"], Union[Apply, Variable, "Op"]]
@@ -1002,7 +1001,7 @@ def clone_get_equiv(
10021001
Keywords passed to `Apply.clone_with_new_inputs`.
10031002
10041003
"""
1005-
from pytensor.graph.traversal import io_toposort
1004+
from pytensor.graph.traversal import toposort
10061005

10071006
if memo is None:
10081007
memo = {}
@@ -1018,7 +1017,7 @@ def clone_get_equiv(
10181017
memo.setdefault(input, input)
10191018

10201019
# go through the inputs -> outputs graph cloning as we go
1021-
for apply in io_toposort(inputs, outputs):
1020+
for apply in toposort(outputs, blockers=inputs):
10221021
for input in apply.inputs:
10231022
if input not in memo:
10241023
if not isinstance(input, Constant) and copy_orphans:

pytensor/graph/features.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
import pytensor
1111
from pytensor.configdefaults import config
1212
from pytensor.graph.basic import Variable
13-
from pytensor.graph.traversal import io_toposort
13+
from pytensor.graph.traversal import toposort
1414
from pytensor.graph.utils import InconsistencyError
1515

1616

@@ -340,11 +340,11 @@ def clone(self):
340340

341341
class Bookkeeper(Feature):
342342
def on_attach(self, fgraph):
343-
for node in io_toposort(fgraph.inputs, fgraph.outputs):
343+
for node in toposort(fgraph.outputs):
344344
self.on_import(fgraph, node, "on_attach")
345345

346346
def on_detach(self, fgraph):
347-
for node in io_toposort(fgraph.inputs, fgraph.outputs):
347+
for node in toposort(fgraph.outputs):
348348
self.on_prune(fgraph, node, "Bookkeeper.detach")
349349

350350

pytensor/graph/fg.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,8 @@
1919
from pytensor.graph.traversal import (
2020
applys_between,
2121
graph_inputs,
22-
io_toposort,
22+
toposort,
23+
toposort_with_orderings,
2324
vars_between,
2425
)
2526
from pytensor.graph.utils import MetaObject, MissingInputError, TestValueError
@@ -366,7 +367,7 @@ def import_node(
366367
# new nodes, so we use all variables we know of as if they were the
367368
# input set. (The functions in the graph module only use the input set
368369
# to know where to stop going down.)
369-
new_nodes = io_toposort(self.variables, apply_node.outputs)
370+
new_nodes = tuple(toposort(apply_node.outputs, blockers=self.variables))
370371

371372
if check:
372373
for node in new_nodes:
@@ -759,7 +760,7 @@ def toposort(self) -> list[Apply]:
759760
# No sorting is necessary
760761
return list(self.apply_nodes)
761762

762-
return io_toposort(self.inputs, self.outputs, self.orderings())
763+
return list(toposort_with_orderings(self.outputs, orderings=self.orderings()))
763764

764765
def orderings(self) -> dict[Apply, list[Apply]]:
765766
"""Return a map of node to node evaluation dependencies.

pytensor/graph/replace.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,10 @@
1010
)
1111
from pytensor.graph.fg import FunctionGraph
1212
from pytensor.graph.op import Op
13-
from pytensor.graph.traversal import io_toposort, truncated_graph_inputs
13+
from pytensor.graph.traversal import (
14+
toposort,
15+
truncated_graph_inputs,
16+
)
1417

1518

1619
ReplaceTypes = Iterable[tuple[Variable, Variable]] | dict[Variable, Variable]
@@ -295,7 +298,7 @@ def vectorize_graph(
295298
new_inputs = [replace.get(inp, inp) for inp in inputs]
296299

297300
vect_vars = dict(zip(inputs, new_inputs, strict=True))
298-
for node in io_toposort(inputs, seq_outputs):
301+
for node in toposort(seq_outputs, blockers=inputs):
299302
vect_inputs = [vect_vars.get(inp, inp) for inp in node.inputs]
300303
vect_node = vectorize_node(node, *vect_inputs)
301304
for output, vect_output in zip(node.outputs, vect_node.outputs, strict=True):

pytensor/graph/rewriting/basic.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
from pytensor.graph.fg import FunctionGraph, Output
2828
from pytensor.graph.op import Op
2929
from pytensor.graph.rewriting.unify import OpPattern, Var, convert_strs_to_vars
30-
from pytensor.graph.traversal import applys_between, io_toposort, vars_between
30+
from pytensor.graph.traversal import applys_between, toposort, vars_between
3131
from pytensor.graph.utils import AssocList, InconsistencyError
3232
from pytensor.misc.ordered_set import OrderedSet
3333
from pytensor.utils import flatten
@@ -2010,7 +2010,7 @@ def apply(self, fgraph, start_from=None):
20102010
callback_before = fgraph.execute_callbacks_time
20112011
nb_nodes_start = len(fgraph.apply_nodes)
20122012
t0 = time.perf_counter()
2013-
q = deque(io_toposort(fgraph.inputs, start_from))
2013+
q = deque(toposort(start_from))
20142014
io_t = time.perf_counter() - t0
20152015

20162016
def importer(node):
@@ -2341,7 +2341,7 @@ def apply_cleanup(profs_dict):
23412341
changed |= apply_cleanup(iter_cleanup_sub_profs)
23422342

23432343
topo_t0 = time.perf_counter()
2344-
q = deque(io_toposort(fgraph.inputs, start_from))
2344+
q = deque(toposort(start_from))
23452345
io_toposort_timing.append(time.perf_counter() - topo_t0)
23462346

23472347
nb_nodes.append(len(q))

0 commit comments

Comments
 (0)