Skip to content

Commit 4036efa

Browse files
authored
Sort submodule inputs and outputs (#350)
* Support checking model redundancy.
* Revert change of vision_model_test.
* Reformat Python code.
* Reformat bert_model_test.py and utils.py.
* Minor fix.
* Fix failed check by comparing directories after os.path.realpath().
* Fix bugs in check_validate.sh.
* Set dynamic=False in single_device_runner.py.
* Reset graph hash.
* Sort submodule inputs and outputs.
* Avoid sorting outputs of the last subgraph.
1 parent de2d8c4 commit 4036efa

File tree

3 files changed

+52
-36
lines changed

3 files changed

+52
-36
lines changed

graph_net/test/chain_naive_graph_decomposer_test.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ read -r -d '' extractor_config_json_str <<EOF
99
{
1010
"custom_extractor_path": "$GRAPH_NET_ROOT/torch/naive_graph_decomposer.py",
1111
"custom_extractor_config": {
12-
"output_dir": "/tmp/naive_decompose_workspace",
12+
"output_dir": "/tmp/chain_naive_decompose_workspace",
1313
"split_positions": [8, 16, 32],
1414
"group_head_and_tail": true,
1515
"chain_style": true

graph_net/test/naive_graph_decomposer_test.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ read -r -d '' extractor_config_json_str <<EOF
1010
"custom_extractor_path": "$GRAPH_NET_ROOT/torch/naive_graph_decomposer.py",
1111
"custom_extractor_config": {
1212
"output_dir": "/tmp/naive_decompose_workspace",
13-
"split_positions": [8, 32],
13+
"split_positions": [8, 16, 32],
1414
"group_head_and_tail": true,
1515
"filter_path":"$GRAPH_NET_ROOT/torch/naive_subgraph_filter.py",
1616
"filter_config": {}

graph_net/torch/decompose_util.py

Lines changed: 50 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -6,23 +6,23 @@
66

77

88
def convert_to_submodules_graph(
9-
original_gm: torch.fx.GraphModule,
9+
gm: torch.fx.GraphModule,
1010
split_positions: list[int],
1111
submodule_hook=None,
1212
submodule_name_prefix="extracted_submodule",
1313
chain_style=False,
1414
group_head_and_tail=True,
1515
):
1616
"""
17-
chain_style=True: decompose original_gm into g0 * g1 * g2 * g3
17+
chain_style=True: decompose gm into g0 * g1 * g2 * g3
1818
"""
19-
original_gm = copy.deepcopy(original_gm)
19+
gm = copy.deepcopy(gm)
2020
num_placeholders = len(
21-
[node for node in original_gm.graph.nodes if node.op == "placeholder"]
21+
[node for node in gm.graph.nodes if node.op == "placeholder"]
2222
)
2323
submodules_body_nodes = [
2424
node
25-
for node in original_gm.graph.nodes
25+
for node in gm.graph.nodes
2626
if node.op
2727
not in {
2828
"placeholder",
@@ -37,13 +37,16 @@ def convert_to_submodules_graph(
3737
split_positions = [
3838
max(0, min(pos, len(submodules_body_nodes))) for pos in split_positions
3939
]
40-
range_idx2submodule_body_nodes = [
41-
submodules_body_nodes[start:end]
40+
range_idx2range = [
41+
(start, end)
4242
for i in range(len(split_positions) - 1)
4343
for start in [split_positions[i]]
4444
for end in [split_positions[i + 1]]
4545
if end > start
4646
]
47+
range_idx2submodule_body_nodes = [
48+
submodules_body_nodes[start:end] for start, end in range_idx2range
49+
]
4750

4851
def get_body_nodes(range_idx):
4952
return range_idx2submodule_body_nodes[range_idx]
@@ -54,20 +57,20 @@ def get_name2sub_submodule():
5457
)
5558
return {
5659
name: module
57-
for name, module in original_gm.named_modules()
60+
for name, module in gm.named_modules()
5861
if name in used_module_names
5962
}
6063

6164
def get_start_node_idx(range_idx):
6265
start_node = get_body_nodes(range_idx)[0]
63-
for i, node in enumerate(original_gm.graph.nodes):
66+
for i, node in enumerate(gm.graph.nodes):
6467
if node == start_node:
6568
return i
6669
raise NotImplementedError("Dead code.")
6770

6871
def get_end_node_idx(range_idx):
6972
last_node = get_body_nodes(range_idx)[-1]
70-
for i, node in enumerate(original_gm.graph.nodes):
73+
for i, node in enumerate(gm.graph.nodes):
7174
if node == last_node:
7275
return i + 1
7376
raise NotImplementedError("Dead code.")
@@ -78,23 +81,33 @@ def print_submodule_call(prompt, gm):
7881
]
7982
print(f"{prompt} ", submodule_call_stmts)
8083

84+
new_node2original_node = {}
85+
for node in gm.graph.nodes:
86+
new_node2original_node[node] = node
87+
88+
def sort_key(node):
89+
return new_node2original_node[node].name
90+
8191
for range_idx in range(len(range_idx2submodule_body_nodes)):
8292
(
8393
submodule_input_nodes,
8494
submodule_output_nodes,
8595
identity_nodes,
8696
) = _get_submodule_inputs_and_outputs(
87-
original_gm=original_gm,
97+
gm=gm,
8898
start_node_idx=get_start_node_idx(range_idx),
8999
end_node_idx=get_end_node_idx(range_idx),
90100
chain_style=chain_style,
91101
)
92102

93103
def get_input_nodes(range_idx):
94-
return submodule_input_nodes
104+
return sorted(submodule_input_nodes, key=sort_key)
95105

96106
def get_output_nodes(range_idx):
97-
return submodule_output_nodes
107+
end = range_idx2range[range_idx][1]
108+
if end >= len(submodules_body_nodes):
109+
return submodule_output_nodes
110+
return sorted(submodule_output_nodes, key=sort_key)
98111

99112
submodule_name = (
100113
f"{submodule_name_prefix}_{range_idx}"
@@ -107,7 +120,8 @@ def get_output_nodes(range_idx):
107120

108121
# Add placeholder nodes for inputs
109122
for original_node in get_input_nodes(range_idx):
110-
new_node = new_graph.placeholder(original_node.name)
123+
name = new_node2original_node[original_node].name
124+
new_node = new_graph.placeholder(name)
111125
node_map[original_node] = new_node
112126

113127
# Copy body nodes
@@ -116,9 +130,9 @@ def get_output_nodes(range_idx):
116130
node_map[original_node] = new_node
117131

118132
# Add output nodes
119-
output_args = []
120-
for original_node in get_output_nodes(range_idx):
121-
output_args.append(node_map[original_node])
133+
output_args = [
134+
node_map[original_node] for original_node in get_output_nodes(range_idx)
135+
]
122136
new_graph.output(tuple(output_args))
123137

124138
# Create the new GraphModule
@@ -127,15 +141,15 @@ def get_output_nodes(range_idx):
127141
if submodule_hook is not None:
128142
new_sub_module = submodule_hook(new_sub_module, range_idx)
129143
# Replace with submodule node
130-
original_gm.add_submodule(submodule_name, new_sub_module)
131-
with original_gm.graph.inserting_after(get_body_nodes(range_idx)[-1]):
132-
submodule_node = original_gm.graph.call_module(
144+
gm.add_submodule(submodule_name, new_sub_module)
145+
with gm.graph.inserting_after(get_body_nodes(range_idx)[-1]):
146+
submodule_node = gm.graph.call_module(
133147
submodule_name, tuple(get_input_nodes(range_idx))
134148
)
135149
prev_node = submodule_node
136150
for idx, original_output in enumerate(get_output_nodes(range_idx)):
137-
with original_gm.graph.inserting_after(prev_node):
138-
new_output_node = original_gm.graph.call_function(
151+
with gm.graph.inserting_after(prev_node):
152+
new_output_node = gm.graph.call_function(
139153
operator.getitem, (submodule_node, idx)
140154
)
141155
node_map[original_output] = new_output_node
@@ -146,31 +160,34 @@ def get_output_nodes(range_idx):
146160
for original_output in get_output_nodes(range_idx):
147161
if original_output not in identity_node_set:
148162
original_output.replace_all_uses_with(node_map[original_output])
163+
new_node2original_node[
164+
node_map[original_output]
165+
] = new_node2original_node[original_output]
149166

150167
# Erase old nodes
151168
for node in reversed(get_body_nodes(range_idx)):
152-
original_gm.graph.erase_node(node)
153-
# print_submodule_call("(fx) after Erase old nodes", original_gm)
169+
gm.graph.erase_node(node)
170+
# print_submodule_call("(fx) after Erase old nodes", gm)
154171

155-
# print_submodule_call("(fx) before recompile", original_gm)
172+
# print_submodule_call("(fx) before recompile", gm)
156173

157-
original_gm.recompile()
174+
gm.recompile()
158175

159-
# print_submodule_call("(fx) after recompile", original_gm)
176+
# print_submodule_call("(fx) after recompile", gm)
160177

161-
return original_gm
178+
return gm
162179

163180

164181
def fold_range_to_submodule(
165-
original_gm: torch.fx.GraphModule,
182+
gm: torch.fx.GraphModule,
166183
start_node_idx: int,
167184
end_node_idx: int,
168185
submodule_hook=None,
169186
submodule_name="extracted_submodule",
170187
group_head_and_tail=True,
171188
):
172189
return convert_to_submodules_graph(
173-
original_gm,
190+
gm,
174191
split_positions=[start_node_idx, end_node_idx],
175192
submodule_hook=submodule_hook,
176193
submodule_name_prefix=submodule_name,
@@ -186,7 +203,7 @@ class NodeProducedOrConsumedCountCtx:
186203

187204

188205
def _get_submodule_inputs_and_outputs(
189-
original_gm: torch.fx.GraphModule,
206+
gm: torch.fx.GraphModule,
190207
start_node_idx: int,
191208
end_node_idx: int,
192209
chain_style=False,
@@ -196,7 +213,7 @@ def _get_submodule_inputs_and_outputs(
196213
defaultdict(int),
197214
defaultdict(int),
198215
)
199-
node_list = list(original_gm.graph.nodes)
216+
node_list = list(gm.graph.nodes)
200217

201218
def get_related_node(node):
202219
for arg in node.args:
@@ -240,7 +257,7 @@ def get_related_node(node):
240257
if count_ctx.node2before_input[node] > 0
241258
if count_ctx.node2body[node] == 0
242259
if count_ctx.node2after_output[node] > 0
243-
][:1]
260+
]
244261
input_nodes_set = set(input_nodes)
245262
input_nodes = [
246263
*input_nodes,
@@ -251,5 +268,4 @@ def get_related_node(node):
251268
*output_nodes,
252269
*[node for node in identity_nodes if node not in output_nodes_set],
253270
]
254-
255271
return input_nodes, output_nodes, identity_nodes

0 commit comments

Comments
 (0)