Skip to content

Commit cb1ae56

Browse files
committed
sort submodule inputs and outputs
1 parent 528d46c commit cb1ae56

File tree

3 files changed

+44
-34
lines changed

3 files changed

+44
-34
lines changed

graph_net/test/chain_naive_graph_decomposer_test.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ read -r -d '' extractor_config_json_str <<EOF
99
{
1010
"custom_extractor_path": "$GRAPH_NET_ROOT/torch/naive_graph_decomposer.py",
1111
"custom_extractor_config": {
12-
"output_dir": "/tmp/naive_decompose_workspace",
12+
"output_dir": "/tmp/chain_naive_decompose_workspace",
1313
"split_positions": [8, 16, 32],
1414
"group_head_and_tail": true,
1515
"chain_style": true

graph_net/test/naive_graph_decomposer_test.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ read -r -d '' extractor_config_json_str <<EOF
1010
"custom_extractor_path": "$GRAPH_NET_ROOT/torch/naive_graph_decomposer.py",
1111
"custom_extractor_config": {
1212
"output_dir": "/tmp/naive_decompose_workspace",
13-
"split_positions": [8, 32],
13+
"split_positions": [8, 16, 32],
1414
"group_head_and_tail": true,
1515
"filter_path":"$GRAPH_NET_ROOT/torch/naive_subgraph_filter.py",
1616
"filter_config": {}

graph_net/torch/decompose_util.py

Lines changed: 42 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -6,23 +6,23 @@
66

77

88
def convert_to_submodules_graph(
9-
original_gm: torch.fx.GraphModule,
9+
gm: torch.fx.GraphModule,
1010
split_positions: list[int],
1111
submodule_hook=None,
1212
submodule_name_prefix="extracted_submodule",
1313
chain_style=False,
1414
group_head_and_tail=True,
1515
):
1616
"""
17-
chain_style=True: decompose original_gm into g0 * g1 * g2 * g3
17+
chain_style=True: decompose gm into g0 * g1 * g2 * g3
1818
"""
19-
original_gm = copy.deepcopy(original_gm)
19+
gm = copy.deepcopy(gm)
2020
num_placeholders = len(
21-
[node for node in original_gm.graph.nodes if node.op == "placeholder"]
21+
[node for node in gm.graph.nodes if node.op == "placeholder"]
2222
)
2323
submodules_body_nodes = [
2424
node
25-
for node in original_gm.graph.nodes
25+
for node in gm.graph.nodes
2626
if node.op
2727
not in {
2828
"placeholder",
@@ -54,20 +54,20 @@ def get_name2sub_submodule():
5454
)
5555
return {
5656
name: module
57-
for name, module in original_gm.named_modules()
57+
for name, module in gm.named_modules()
5858
if name in used_module_names
5959
}
6060

6161
def get_start_node_idx(range_idx):
6262
start_node = get_body_nodes(range_idx)[0]
63-
for i, node in enumerate(original_gm.graph.nodes):
63+
for i, node in enumerate(gm.graph.nodes):
6464
if node == start_node:
6565
return i
6666
raise NotImplementedError("Dead code.")
6767

6868
def get_end_node_idx(range_idx):
6969
last_node = get_body_nodes(range_idx)[-1]
70-
for i, node in enumerate(original_gm.graph.nodes):
70+
for i, node in enumerate(gm.graph.nodes):
7171
if node == last_node:
7272
return i + 1
7373
raise NotImplementedError("Dead code.")
@@ -78,23 +78,30 @@ def print_submodule_call(prompt, gm):
7878
]
7979
print(f"{prompt} ", submodule_call_stmts)
8080

81+
new_node2original_node = {}
82+
for node in gm.graph.nodes:
83+
new_node2original_node[node] = node
84+
85+
def sort_key(node):
86+
return new_node2original_node[node].name
87+
8188
for range_idx in range(len(range_idx2submodule_body_nodes)):
8289
(
8390
submodule_input_nodes,
8491
submodule_output_nodes,
8592
identity_nodes,
8693
) = _get_submodule_inputs_and_outputs(
87-
original_gm=original_gm,
94+
gm=gm,
8895
start_node_idx=get_start_node_idx(range_idx),
8996
end_node_idx=get_end_node_idx(range_idx),
9097
chain_style=chain_style,
9198
)
9299

93100
def get_input_nodes(range_idx):
94-
return submodule_input_nodes
101+
return sorted(submodule_input_nodes, key=sort_key)
95102

96103
def get_output_nodes(range_idx):
97-
return submodule_output_nodes
104+
return sorted(submodule_output_nodes, key=sort_key)
98105

99106
submodule_name = (
100107
f"{submodule_name_prefix}_{range_idx}"
@@ -107,7 +114,8 @@ def get_output_nodes(range_idx):
107114

108115
# Add placeholder nodes for inputs
109116
for original_node in get_input_nodes(range_idx):
110-
new_node = new_graph.placeholder(original_node.name)
117+
name = new_node2original_node[original_node].name
118+
new_node = new_graph.placeholder(name)
111119
node_map[original_node] = new_node
112120

113121
# Copy body nodes
@@ -116,9 +124,9 @@ def get_output_nodes(range_idx):
116124
node_map[original_node] = new_node
117125

118126
# Add output nodes
119-
output_args = []
120-
for original_node in get_output_nodes(range_idx):
121-
output_args.append(node_map[original_node])
127+
output_args = [
128+
node_map[original_node] for original_node in get_output_nodes(range_idx)
129+
]
122130
new_graph.output(tuple(output_args))
123131

124132
# Create the new GraphModule
@@ -127,15 +135,15 @@ def get_output_nodes(range_idx):
127135
if submodule_hook is not None:
128136
new_sub_module = submodule_hook(new_sub_module, range_idx)
129137
# Replace with submodule node
130-
original_gm.add_submodule(submodule_name, new_sub_module)
131-
with original_gm.graph.inserting_after(get_body_nodes(range_idx)[-1]):
132-
submodule_node = original_gm.graph.call_module(
138+
gm.add_submodule(submodule_name, new_sub_module)
139+
with gm.graph.inserting_after(get_body_nodes(range_idx)[-1]):
140+
submodule_node = gm.graph.call_module(
133141
submodule_name, tuple(get_input_nodes(range_idx))
134142
)
135143
prev_node = submodule_node
136144
for idx, original_output in enumerate(get_output_nodes(range_idx)):
137-
with original_gm.graph.inserting_after(prev_node):
138-
new_output_node = original_gm.graph.call_function(
145+
with gm.graph.inserting_after(prev_node):
146+
new_output_node = gm.graph.call_function(
139147
operator.getitem, (submodule_node, idx)
140148
)
141149
node_map[original_output] = new_output_node
@@ -146,31 +154,34 @@ def get_output_nodes(range_idx):
146154
for original_output in get_output_nodes(range_idx):
147155
if original_output not in identity_node_set:
148156
original_output.replace_all_uses_with(node_map[original_output])
157+
new_node2original_node[
158+
node_map[original_output]
159+
] = new_node2original_node[original_output]
149160

150161
# Erase old nodes
151162
for node in reversed(get_body_nodes(range_idx)):
152-
original_gm.graph.erase_node(node)
153-
# print_submodule_call("(fx) after Erase old nodes", original_gm)
163+
gm.graph.erase_node(node)
164+
# print_submodule_call("(fx) after Erase old nodes", gm)
154165

155-
# print_submodule_call("(fx) before recompile", original_gm)
166+
# print_submodule_call("(fx) before recompile", gm)
156167

157-
original_gm.recompile()
168+
gm.recompile()
158169

159-
# print_submodule_call("(fx) after recompile", original_gm)
170+
# print_submodule_call("(fx) after recompile", gm)
160171

161-
return original_gm
172+
return gm
162173

163174

164175
def fold_range_to_submodule(
165-
original_gm: torch.fx.GraphModule,
176+
gm: torch.fx.GraphModule,
166177
start_node_idx: int,
167178
end_node_idx: int,
168179
submodule_hook=None,
169180
submodule_name="extracted_submodule",
170181
group_head_and_tail=True,
171182
):
172183
return convert_to_submodules_graph(
173-
original_gm,
184+
gm,
174185
split_positions=[start_node_idx, end_node_idx],
175186
submodule_hook=submodule_hook,
176187
submodule_name_prefix=submodule_name,
@@ -186,7 +197,7 @@ class NodeProducedOrConsumedCountCtx:
186197

187198

188199
def _get_submodule_inputs_and_outputs(
189-
original_gm: torch.fx.GraphModule,
200+
gm: torch.fx.GraphModule,
190201
start_node_idx: int,
191202
end_node_idx: int,
192203
chain_style=False,
@@ -196,7 +207,7 @@ def _get_submodule_inputs_and_outputs(
196207
defaultdict(int),
197208
defaultdict(int),
198209
)
199-
node_list = list(original_gm.graph.nodes)
210+
node_list = list(gm.graph.nodes)
200211

201212
def get_related_node(node):
202213
for arg in node.args:
@@ -240,7 +251,7 @@ def get_related_node(node):
240251
if count_ctx.node2before_input[node] > 0
241252
if count_ctx.node2body[node] == 0
242253
if count_ctx.node2after_output[node] > 0
243-
][:1]
254+
]
244255
input_nodes_set = set(input_nodes)
245256
input_nodes = [
246257
*input_nodes,
@@ -251,5 +262,4 @@ def get_related_node(node):
251262
*output_nodes,
252263
*[node for node in identity_nodes if node not in output_nodes_set],
253264
]
254-
255265
return input_nodes, output_nodes, identity_nodes

0 commit comments

Comments
 (0)