|
13 | 13 |
|
14 | 14 | import pytensor
|
15 | 15 | from pytensor.configdefaults import config
|
16 |
| -from pytensor.graph.basic import Apply, Variable |
| 16 | +from pytensor.graph.basic import Apply, Variable, io_toposort |
17 | 17 | from pytensor.graph.utils import (
|
18 | 18 | MetaObject,
|
19 | 19 | TestValueError,
|
@@ -753,3 +753,68 @@ def get_test_values(*args: Variable) -> Any | list[Any]:
|
753 | 753 | return rval
|
754 | 754 |
|
755 | 755 | return [tuple(rval)]
|
| 756 | + |
| 757 | + |
| 758 | +def io_connection_pattern(inputs, outputs): |
| 759 | + """Return the connection pattern of a subgraph defined by given inputs and outputs.""" |
| 760 | + inner_nodes = io_toposort(inputs, outputs) |
| 761 | + |
| 762 | + # Initialize 'connect_pattern_by_var' by establishing each input as |
| 763 | + # connected only to itself |
| 764 | + connect_pattern_by_var = {} |
| 765 | + nb_inputs = len(inputs) |
| 766 | + |
| 767 | + for i in range(nb_inputs): |
| 768 | + input = inputs[i] |
| 769 | + inp_connection_pattern = [i == j for j in range(nb_inputs)] |
| 770 | + connect_pattern_by_var[input] = inp_connection_pattern |
| 771 | + |
| 772 | + # Iterate through the nodes used to produce the outputs from the |
| 773 | + # inputs and, for every node, infer their connection pattern to |
| 774 | + # every input from the connection patterns of their parents. |
| 775 | + for n in inner_nodes: |
| 776 | + # Get the connection pattern of the inner node's op. If the op |
| 777 | + # does not define a connection_pattern method, assume that |
| 778 | + # every node output is connected to every node input |
| 779 | + try: |
| 780 | + op_connection_pattern = n.op.connection_pattern(n) |
| 781 | + except AttributeError: |
| 782 | + op_connection_pattern = [[True] * len(n.outputs)] * len(n.inputs) |
| 783 | + |
| 784 | + # For every output of the inner node, figure out which inputs it |
| 785 | + # is connected to by combining the connection pattern of the inner |
| 786 | + # node and the connection patterns of the inner node's inputs. |
| 787 | + for out_idx in range(len(n.outputs)): |
| 788 | + out = n.outputs[out_idx] |
| 789 | + out_connection_pattern = [False] * nb_inputs |
| 790 | + |
| 791 | + for inp_idx in range(len(n.inputs)): |
| 792 | + inp = n.inputs[inp_idx] |
| 793 | + |
| 794 | + if inp in connect_pattern_by_var: |
| 795 | + inp_connection_pattern = connect_pattern_by_var[inp] |
| 796 | + |
| 797 | + # If the node output is connected to the node input, it |
| 798 | + # means it is connected to every inner input that the |
| 799 | + # node inputs is connected to |
| 800 | + if op_connection_pattern[inp_idx][out_idx]: |
| 801 | + out_connection_pattern = [ |
| 802 | + out_connection_pattern[i] or inp_connection_pattern[i] |
| 803 | + for i in range(nb_inputs) |
| 804 | + ] |
| 805 | + |
| 806 | + # Store the connection pattern of the node output |
| 807 | + connect_pattern_by_var[out] = out_connection_pattern |
| 808 | + |
| 809 | + # Obtain the global connection pattern by combining the |
| 810 | + # connection patterns of the individual outputs |
| 811 | + global_connection_pattern = [[] for o in range(len(inputs))] |
| 812 | + for out in outputs: |
| 813 | + out_connection_pattern = connect_pattern_by_var.get(out) |
| 814 | + if out_connection_pattern is None: |
| 815 | + # the output is completely isolated from inputs |
| 816 | + out_connection_pattern = [False] * len(inputs) |
| 817 | + for i in range(len(inputs)): |
| 818 | + global_connection_pattern[i].append(out_connection_pattern[i]) |
| 819 | + |
| 820 | + return global_connection_pattern |
0 commit comments