Skip to content

Commit 046c469

Browse files
committed
avoid sorting outputs of last subgraph
1 parent cb1ae56 commit 046c469

File tree

1 file changed

+8
-2
lines changed

1 file changed

+8
-2
lines changed

graph_net/torch/decompose_util.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,13 +37,16 @@ def convert_to_submodules_graph(
3737
split_positions = [
3838
max(0, min(pos, len(submodules_body_nodes))) for pos in split_positions
3939
]
40-
range_idx2submodule_body_nodes = [
41-
submodules_body_nodes[start:end]
40+
range_idx2range = [
41+
(start, end)
4242
for i in range(len(split_positions) - 1)
4343
for start in [split_positions[i]]
4444
for end in [split_positions[i + 1]]
4545
if end > start
4646
]
47+
range_idx2submodule_body_nodes = [
48+
submodules_body_nodes[start:end] for start, end in range_idx2range
49+
]
4750

4851
def get_body_nodes(range_idx):
4952
return range_idx2submodule_body_nodes[range_idx]
@@ -101,6 +104,9 @@ def get_input_nodes(range_idx):
101104
return sorted(submodule_input_nodes, key=sort_key)
102105

103106
def get_output_nodes(range_idx):
107+
end = range_idx2range[range_idx][1]
108+
if end >= len(submodules_body_nodes):
109+
return submodule_output_nodes
104110
return sorted(submodule_output_nodes, key=sort_key)
105111

106112
submodule_name = (

0 commit comments

Comments
 (0)