Skip to content

Commit aa7e4d6

Browse files
committed
Speedup FunctionGraph methods
1 parent 066307f commit aa7e4d6

File tree

1 file changed

+44
-69
lines changed

1 file changed

+44
-69
lines changed

pytensor/graph/fg.py

Lines changed: 44 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@
2424
vars_between,
2525
)
2626
from pytensor.graph.utils import MetaObject, MissingInputError, TestValueError
27-
from pytensor.misc.ordered_set import OrderedSet
2827

2928

3029
ClientType = tuple[Apply, int]
@@ -133,7 +132,6 @@ def __init__(
133132
features = []
134133

135134
self._features: list[Feature] = []
136-
137135
# All apply nodes in the subgraph defined by inputs and
138136
# outputs are cached in this field
139137
self.apply_nodes: set[Apply] = set()
@@ -161,7 +159,8 @@ def __init__(
161159
"input's owner or use graph.clone."
162160
)
163161

164-
self.add_input(in_var, check=False)
162+
self.inputs.append(in_var)
163+
self.clients.setdefault(in_var, [])
165164

166165
for output in outputs:
167166
self.add_output(output, reason="init")
@@ -189,16 +188,6 @@ def add_input(self, var: Variable, check: bool = True) -> None:
189188
return
190189

191190
self.inputs.append(var)
192-
self.setup_var(var)
193-
194-
def setup_var(self, var: Variable) -> None:
195-
"""Set up a variable so it belongs to this `FunctionGraph`.
196-
197-
Parameters
198-
----------
199-
var : pytensor.graph.basic.Variable
200-
201-
"""
202191
self.clients.setdefault(var, [])
203192

204193
def get_clients(self, var: Variable) -> list[ClientType]:
@@ -322,10 +311,11 @@ def import_var(
322311
323312
"""
324313
# Imports the owners of the variables
325-
if var.owner and var.owner not in self.apply_nodes:
326-
self.import_node(var.owner, reason=reason, import_missing=import_missing)
314+
apply = var.owner
315+
if apply is not None and apply not in self.apply_nodes:
316+
self.import_node(apply, reason=reason, import_missing=import_missing)
327317
elif (
328-
var.owner is None
318+
apply is None
329319
and not isinstance(var, AtomicVariable)
330320
and var not in self.inputs
331321
):
@@ -336,10 +326,11 @@ def import_var(
336326
f"Computation graph contains a NaN. {var.type.why_null}"
337327
)
338328
if import_missing:
339-
self.add_input(var)
329+
self.inputs.append(var)
330+
self.clients.setdefault(var, [])
340331
else:
341332
raise MissingInputError(f"Undeclared input: {var}", variable=var)
342-
self.setup_var(var)
333+
self.clients.setdefault(var, [])
343334
self.variables.add(var)
344335

345336
def import_node(
@@ -356,29 +347,29 @@ def import_node(
356347
apply_node : Apply
357348
The node to be imported.
358349
check : bool
359-
Check that the inputs for the imported nodes are also present in
360-
the `FunctionGraph`.
350+
Check that the inputs for the imported nodes are also present in the `FunctionGraph`.
361351
reason : str
362352
The name of the optimization or operation in progress.
363353
import_missing : bool
364354
Add missing inputs instead of raising an exception.
365355
"""
366356
# We import the nodes in topological order. We are only interested in
367-
# new nodes, so we use all variables we know of as if they were the
368-
# input set. (The functions in the graph module only use the input set
369-
# to know where to stop going down.)
370-
new_nodes = tuple(toposort(apply_node.outputs, blockers=self.variables))
371-
372-
if check:
373-
for node in new_nodes:
357+
# new nodes, so we use all variables we know of as blockers to interrupt the toposort
358+
self_variables = self.variables
359+
self_clients = self.clients
360+
self_apply_nodes = self.apply_nodes
361+
self_inputs = self.inputs
362+
for node in toposort(apply_node.outputs, blockers=self_variables):
363+
if check:
374364
for var in node.inputs:
375365
if (
376366
var.owner is None
377367
and not isinstance(var, AtomicVariable)
378-
and var not in self.inputs
368+
and var not in self_inputs
379369
):
380370
if import_missing:
381-
self.add_input(var)
371+
self_inputs.append(var)
372+
self_clients.setdefault(var, [])
382373
else:
383374
error_msg = (
384375
f"Input {node.inputs.index(var)} ({var})"
@@ -390,20 +381,20 @@ def import_node(
390381
)
391382
raise MissingInputError(error_msg, variable=var)
392383

393-
for node in new_nodes:
394-
assert node not in self.apply_nodes
395-
self.apply_nodes.add(node)
396-
if not hasattr(node.tag, "imported_by"):
397-
node.tag.imported_by = []
398-
node.tag.imported_by.append(str(reason))
384+
self_apply_nodes.add(node)
385+
tag = node.tag
386+
if not hasattr(tag, "imported_by"):
387+
tag.imported_by = [str(reason)]
388+
else:
389+
tag.imported_by.append(str(reason))
399390
for output in node.outputs:
400-
self.setup_var(output)
401-
self.variables.add(output)
402-
for i, input in enumerate(node.inputs):
403-
if input not in self.variables:
404-
self.setup_var(input)
405-
self.variables.add(input)
406-
self.add_client(input, (node, i))
391+
self_clients.setdefault(output, [])
392+
self_variables.add(output)
393+
for i, inp in enumerate(node.inputs):
394+
if inp not in self_variables:
395+
self_clients.setdefault(inp, [])
396+
self_variables.add(inp)
397+
self_clients[inp].append((node, i))
407398
self.execute_callbacks("on_import", node, reason)
408399

409400
def change_node_input(
@@ -457,7 +448,7 @@ def change_node_input(
457448
self.outputs[node.op.idx] = new_var
458449

459450
self.import_var(new_var, reason=reason, import_missing=import_missing)
460-
self.add_client(new_var, (node, i))
451+
self.clients[new_var].append((node, i))
461452
self.remove_client(r, (node, i), reason=reason)
462453
# Precondition: the substitution is semantically valid. However, it may
463454
# introduce cycles to the graph, in which case the transaction will be
@@ -756,10 +747,6 @@ def toposort(self) -> list[Apply]:
756747
:meth:`FunctionGraph.orderings`.
757748
758749
"""
759-
if len(self.apply_nodes) < 2:
760-
# No sorting is necessary
761-
return list(self.apply_nodes)
762-
763750
return list(toposort_with_orderings(self.outputs, orderings=self.orderings()))
764751

765752
def orderings(self) -> dict[Apply, list[Apply]]:
@@ -779,29 +766,17 @@ def orderings(self) -> dict[Apply, list[Apply]]:
779766
take care of computing the dependencies by itself.
780767
781768
"""
782-
assert isinstance(self._features, list)
783-
all_orderings: list[dict] = []
769+
all_orderings: list[dict] = [
770+
orderings
771+
for feature in self._features
772+
if (
773+
hasattr(feature, "orderings") and (orderings := feature.orderings(self))
774+
)
775+
]
784776

785-
for feature in self._features:
786-
if hasattr(feature, "orderings"):
787-
orderings = feature.orderings(self)
788-
if not isinstance(orderings, dict):
789-
raise TypeError(
790-
"Non-deterministic return value from "
791-
+ str(feature.orderings)
792-
+ ". Nondeterministic object is "
793-
+ str(orderings)
794-
)
795-
if len(orderings) > 0:
796-
all_orderings.append(orderings)
797-
for node, prereqs in orderings.items():
798-
if not isinstance(prereqs, list | OrderedSet):
799-
raise TypeError(
800-
"prereqs must be a type with a "
801-
"deterministic iteration order, or toposort "
802-
" will be non-deterministic."
803-
)
804-
if len(all_orderings) == 1:
777+
if not all_orderings:
778+
return {}
779+
elif len(all_orderings) == 1:
805780
# If there is only 1 ordering, we reuse it directly.
806781
return all_orderings[0].copy()
807782
else:

0 commit comments

Comments
 (0)