Skip to content

Commit 8483b04

Browse files
committed
Move io_connection_pattern to graph/op.py
1 parent c6dae89 commit 8483b04

File tree

6 files changed

+73
-75
lines changed

6 files changed

+73
-75
lines changed

pytensor/compile/builders.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,10 @@
1717
NominalVariable,
1818
Variable,
1919
graph_inputs,
20-
io_connection_pattern,
2120
)
2221
from pytensor.graph.fg import FunctionGraph
2322
from pytensor.graph.null_type import NullType
24-
from pytensor.graph.op import HasInnerGraph, Op
23+
from pytensor.graph.op import HasInnerGraph, Op, io_connection_pattern
2524
from pytensor.graph.replace import clone_replace
2625
from pytensor.graph.utils import MissingInputError
2726

pytensor/graph/basic.py

Lines changed: 0 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -1633,71 +1633,6 @@ def default_node_formatter(op, argstrings):
16331633
return f"{op.op}({', '.join(argstrings)})"
16341634

16351635

1636-
def io_connection_pattern(inputs, outputs):
1637-
"""Return the connection pattern of a subgraph defined by given inputs and outputs."""
1638-
inner_nodes = io_toposort(inputs, outputs)
1639-
1640-
# Initialize 'connect_pattern_by_var' by establishing each input as
1641-
# connected only to itself
1642-
connect_pattern_by_var = {}
1643-
nb_inputs = len(inputs)
1644-
1645-
for i in range(nb_inputs):
1646-
input = inputs[i]
1647-
inp_connection_pattern = [i == j for j in range(nb_inputs)]
1648-
connect_pattern_by_var[input] = inp_connection_pattern
1649-
1650-
# Iterate through the nodes used to produce the outputs from the
1651-
# inputs and, for every node, infer their connection pattern to
1652-
# every input from the connection patterns of their parents.
1653-
for n in inner_nodes:
1654-
# Get the connection pattern of the inner node's op. If the op
1655-
# does not define a connection_pattern method, assume that
1656-
# every node output is connected to every node input
1657-
try:
1658-
op_connection_pattern = n.op.connection_pattern(n)
1659-
except AttributeError:
1660-
op_connection_pattern = [[True] * len(n.outputs)] * len(n.inputs)
1661-
1662-
# For every output of the inner node, figure out which inputs it
1663-
# is connected to by combining the connection pattern of the inner
1664-
# node and the connection patterns of the inner node's inputs.
1665-
for out_idx in range(len(n.outputs)):
1666-
out = n.outputs[out_idx]
1667-
out_connection_pattern = [False] * nb_inputs
1668-
1669-
for inp_idx in range(len(n.inputs)):
1670-
inp = n.inputs[inp_idx]
1671-
1672-
if inp in connect_pattern_by_var:
1673-
inp_connection_pattern = connect_pattern_by_var[inp]
1674-
1675-
# If the node output is connected to the node input, it
1676-
# means it is connected to every inner input that the
1677-
# node inputs is connected to
1678-
if op_connection_pattern[inp_idx][out_idx]:
1679-
out_connection_pattern = [
1680-
out_connection_pattern[i] or inp_connection_pattern[i]
1681-
for i in range(nb_inputs)
1682-
]
1683-
1684-
# Store the connection pattern of the node output
1685-
connect_pattern_by_var[out] = out_connection_pattern
1686-
1687-
# Obtain the global connection pattern by combining the
1688-
# connection patterns of the individual outputs
1689-
global_connection_pattern = [[] for o in range(len(inputs))]
1690-
for out in outputs:
1691-
out_connection_pattern = connect_pattern_by_var.get(out)
1692-
if out_connection_pattern is None:
1693-
# the output is completely isolated from inputs
1694-
out_connection_pattern = [False] * len(inputs)
1695-
for i in range(len(inputs)):
1696-
global_connection_pattern[i].append(out_connection_pattern[i])
1697-
1698-
return global_connection_pattern
1699-
1700-
17011636
def op_as_string(
17021637
i, op, leaf_formatter=default_leaf_formatter, node_formatter=default_node_formatter
17031638
):

pytensor/graph/op.py

Lines changed: 66 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313

1414
import pytensor
1515
from pytensor.configdefaults import config
16-
from pytensor.graph.basic import Apply, Variable
16+
from pytensor.graph.basic import Apply, Variable, io_toposort
1717
from pytensor.graph.utils import (
1818
MetaObject,
1919
TestValueError,
@@ -753,3 +753,68 @@ def get_test_values(*args: Variable) -> Any | list[Any]:
753753
return rval
754754

755755
return [tuple(rval)]
756+
757+
758+
def io_connection_pattern(inputs, outputs):
759+
"""Return the connection pattern of a subgraph defined by given inputs and outputs."""
760+
inner_nodes = io_toposort(inputs, outputs)
761+
762+
# Initialize 'connect_pattern_by_var' by establishing each input as
763+
# connected only to itself
764+
connect_pattern_by_var = {}
765+
nb_inputs = len(inputs)
766+
767+
for i in range(nb_inputs):
768+
input = inputs[i]
769+
inp_connection_pattern = [i == j for j in range(nb_inputs)]
770+
connect_pattern_by_var[input] = inp_connection_pattern
771+
772+
# Iterate through the nodes used to produce the outputs from the
773+
# inputs and, for every node, infer their connection pattern to
774+
# every input from the connection patterns of their parents.
775+
for n in inner_nodes:
776+
# Get the connection pattern of the inner node's op. If the op
777+
# does not define a connection_pattern method, assume that
778+
# every node output is connected to every node input
779+
try:
780+
op_connection_pattern = n.op.connection_pattern(n)
781+
except AttributeError:
782+
op_connection_pattern = [[True] * len(n.outputs)] * len(n.inputs)
783+
784+
# For every output of the inner node, figure out which inputs it
785+
# is connected to by combining the connection pattern of the inner
786+
# node and the connection patterns of the inner node's inputs.
787+
for out_idx in range(len(n.outputs)):
788+
out = n.outputs[out_idx]
789+
out_connection_pattern = [False] * nb_inputs
790+
791+
for inp_idx in range(len(n.inputs)):
792+
inp = n.inputs[inp_idx]
793+
794+
if inp in connect_pattern_by_var:
795+
inp_connection_pattern = connect_pattern_by_var[inp]
796+
797+
# If the node output is connected to the node input, it
798+
# means it is connected to every inner input that the
799+
# node inputs is connected to
800+
if op_connection_pattern[inp_idx][out_idx]:
801+
out_connection_pattern = [
802+
out_connection_pattern[i] or inp_connection_pattern[i]
803+
for i in range(nb_inputs)
804+
]
805+
806+
# Store the connection pattern of the node output
807+
connect_pattern_by_var[out] = out_connection_pattern
808+
809+
# Obtain the global connection pattern by combining the
810+
# connection patterns of the individual outputs
811+
global_connection_pattern = [[] for o in range(len(inputs))]
812+
for out in outputs:
813+
out_connection_pattern = connect_pattern_by_var.get(out)
814+
if out_connection_pattern is None:
815+
# the output is completely isolated from inputs
816+
out_connection_pattern = [False] * len(inputs)
817+
for i in range(len(inputs)):
818+
global_connection_pattern[i].append(out_connection_pattern[i])
819+
820+
return global_connection_pattern

pytensor/scan/op.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -68,10 +68,9 @@
6868
Variable,
6969
equal_computations,
7070
graph_inputs,
71-
io_connection_pattern,
7271
)
7372
from pytensor.graph.features import NoOutputFromInplace
74-
from pytensor.graph.op import HasInnerGraph, Op
73+
from pytensor.graph.op import HasInnerGraph, Op, io_connection_pattern
7574
from pytensor.graph.replace import clone_replace
7675
from pytensor.graph.type import HasShape
7776
from pytensor.graph.utils import InconsistencyError, MissingInputError

tests/graph/test_basic.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -584,11 +584,6 @@ def test_apply_depends_on():
584584
assert apply_depends_on(o3.owner, [o1.owner, o2.owner])
585585

586586

587-
@pytest.mark.xfail(reason="Not implemented")
588-
def test_io_connection_pattern():
589-
raise AssertionError()
590-
591-
592587
def test_get_var_by_name():
593588
r1, r2, r3 = MyVariable(1), MyVariable(2), MyVariable(3)
594589
o1 = MyOp(r1, r2)

tests/graph/test_op.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -275,3 +275,8 @@ def perform(self, node, inputs, outputs):
275275

276276
res_nameless = single_op(x)
277277
assert res_nameless.name is None
278+
279+
280+
@pytest.mark.xfail(reason="Not implemented")
281+
def test_io_connection_pattern():
282+
raise AssertionError()

0 commit comments

Comments
 (0)