Skip to content

Commit 1b67356

Browse files
ferrinericardoV94
authored and committed
add truncated_graph_inputs function
1 parent 5fdc130 commit 1b67356

File tree

2 files changed

+184
-0
lines changed

2 files changed

+184
-0
lines changed

pytensor/graph/basic.py

Lines changed: 130 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -970,6 +970,136 @@ def applys_between(
970970
)
971971

972972

973+
def truncated_graph_inputs(
    outputs: Sequence[Variable],
    ancestors_to_include: Optional[Collection[Variable]] = None,
) -> List[Variable]:
    """Get the truncated graph inputs.

    Unlike :func:`graph_inputs` this function will return
    the closest variables to ``outputs`` that do not depend on
    ``ancestors_to_include``. So given all the returned
    variables, there is no missing variable needed to
    compute the outputs, and every returned variable is listed
    exactly once.

    Parameters
    ----------
    outputs : Sequence[Variable]
        Variables to get conditions for
    ancestors_to_include : Optional[Collection[Variable]]
        Additional ancestors to assume, by default None

    Returns
    -------
    List[Variable]
        Variables required to compute ``outputs``

    Examples
    --------
    The returned variables are marked in (parentheses), ancestors to include
    are ``c``, output variables are ``o``, other variables are ``n``.

    * No ancestors to include

    .. code-block::

        n - n - (o)

    * One ancestor to include

    .. code-block::

        n - (c) - o

    * Two ancestors to include where one depends on another, both returned

    .. code-block::

        (c) - (c) - o

    * Additional nodes are present

    .. code-block::

        (c) - n - o
         n - (n) -'

    * Disconnected ancestors to include are not returned

    .. code-block::

        (c) - n - o
         c

    * Disconnected output is present and returned

    .. code-block::

        (c) - (c) - o
                    (o)

    * An ancestor to include that is itself an output adds itself

    .. code-block::

        n - (c) - (o/c)

    """
    truncated_inputs: List[Variable] = []
    candidates = list(outputs)

    # Simple case, no additional ancestors to include:
    # return the outputs themselves, dropping repeated entries.
    if not ancestors_to_include:  # None or empty
        for node in candidates:
            if node not in truncated_inputs:
                truncated_inputs.append(node)
        return truncated_inputs

    # A set gives an O(1) "is this an ancestor to include" check.
    included: Set[Variable] = set(ancestors_to_include)
    # Blockers are the ancestors to include plus every regular variable
    # examined so far; the search never needs to go above them.
    blockers: Set[Variable] = set(included)
    # Variables that were already examined. A variable can be reached through
    # several paths (or be repeated in ``outputs``); examining it again would
    # duplicate it in the result, and a regular variable would be checked
    # against a blocker set that already contains itself, which makes it look
    # dependent and leads to dereferencing ``owner`` on a root variable.
    seen: Set[Variable] = set()

    while candidates:
        node = candidates.pop()
        if node in seen:
            continue
        seen.add(node)

        if node in included:
            # An ancestor to include that is present in the graph (not
            # disconnected) is always returned. It is compared only against
            # the *other* ancestors to include, so the outcome does not
            # depend on which regular variables happened to be visited first.
            dependent = variable_depends_on(node, included - {node})
            truncated_inputs.append(node)
        else:
            # A regular variable: either it is dependent (the search moves on
            # to its inputs) or it is independent (it closes this branch).
            # In both cases it becomes a blocker for the rest of the search.
            dependent = variable_depends_on(node, blockers)
            blockers.add(node)
            if not dependent:
                truncated_inputs.append(node)

        if dependent:
            # A dependent variable is computed from something else,
            # so its owner is never None here.
            candidates.extend(
                inp for inp in node.owner.inputs if inp not in seen
            )
    return truncated_inputs
1101+
1102+
9731103
def clone(
9741104
inputs: List[Variable],
9751105
outputs: List[Variable],

tests/graph/test_basic.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
io_toposort,
2424
list_of_nodes,
2525
orphans_between,
26+
truncated_graph_inputs,
2627
variable_depends_on,
2728
vars_between,
2829
walk,
@@ -695,3 +696,56 @@ def test_variable_depends_on():
695696
assert not variable_depends_on(y, [y2])
696697
assert variable_depends_on(y, [y])
697698

699+
700+
def test_truncated_graph_inputs():
    """Exercise ``truncated_graph_inputs`` on a small graph.

    Returned variables are marked in (parentheses), conditions are ``c``,
    outputs are ``o``, other variables are ``n``.

    * No conditions
        n - n - (o)

    * One condition
        n - (c) - o

    * Two conditions where one depends on another, both returned
        (c) - (c) - o

    * Additional nodes are present
        (c) - n - o
         n - (n) -'

    * Disconnected condition not returned
        (c) - n - o
         c

    * Disconnected output is present and returned
        (c) - (c) - o
                    (o)

    * Condition on itself adds itself
        n - (c) - (o/c)
    """
    # Root variables; ``z`` stays disconnected from everything else.
    x, y, z = MyVariable(1), MyVariable(1), MyVariable(1)
    x.name, y.name, z.name = "x", "y", "z"

    # Chain: x -> x2 -> y2 -> o -> o2, with y feeding y2 as well.
    x2 = MyOp(x)
    x2.name = "x2"
    y2 = MyOp(y, x2)
    y2.name = "y2"
    o = MyOp(y2)
    o2 = MyOp(o)

    # Each entry: (outputs, ancestors_to_include, expected result)
    scenarios = [
        # No conditions
        ([o], None, [o]),
        # One condition
        ([o2], [y2], [y2]),
        # Condition on itself adds itself
        ([o], [y2, o], [o, y2]),
        # Two conditions where one depends on another, both returned
        ([o2], [y2, o], [o, y2]),
        # Additional nodes are present
        ([o], [y], [x2, y]),
        # Disconnected condition
        ([o2], [y2, z], [y2]),
        # Disconnected output is present
        ([o2, z], [y2], [z, y2]),
    ]
    for outs, conditions, expected in scenarios:
        assert truncated_graph_inputs(outs, conditions) == expected

0 commit comments

Comments
 (0)