Skip to content

Commit e9d6055

Browse files
authored
Naive chainable decompose minor fix (#345)
* support checking model redundancy * revert change of vision_model_test * reformat python code. * reformat bert_model_test.py and utils.py * minor fix * fix failed check by comparing directories after os.path.realpath() * fix bugs in check_validate.sh * set dynamic=False in single_device_runner.py * reset graph hash * minor fix for naive_graph_decomposer * bug fix for chain_style
1 parent 06b8dc2 commit e9d6055

File tree

2 files changed

+33
-25
lines changed

2 files changed

+33
-25
lines changed

graph_net/test/chain_naive_graph_decomposer_test.sh

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,8 @@ MODEL_PATH_IN_SAMPLES=/timm/resnet18
77
read -r -d '' json_str <<'EOF'
88
{
99
"output_dir": "/tmp/naive_decompose_workspace",
10-
"split_positions": [2, 4],
11-
"group_head_and_tail": false,
10+
"split_positions": [8, 16, 32],
11+
"group_head_and_tail": true,
1212
"chain_style": true
1313
}
1414
EOF

graph_net/torch/decompose_util.py

Lines changed: 31 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,7 @@ def print_submodule_call(prompt, gm):
8282
(
8383
submodule_input_nodes,
8484
submodule_output_nodes,
85+
identity_nodes,
8586
) = _get_submodule_inputs_and_outputs(
8687
original_gm=original_gm,
8788
start_node_idx=get_start_node_idx(range_idx),
@@ -141,8 +142,10 @@ def get_output_nodes(range_idx):
141142
prev_node = new_output_node
142143

143144
# Replace all use of outputs
145+
identity_node_set = set(identity_nodes)
144146
for original_output in get_output_nodes(range_idx):
145-
original_output.replace_all_uses_with(node_map[original_output])
147+
if original_output not in identity_node_set:
148+
original_output.replace_all_uses_with(node_map[original_output])
146149

147150
# Erase old nodes
148151
for node in reversed(get_body_nodes(range_idx)):
@@ -215,33 +218,38 @@ def get_related_node(node):
215218
for related_node in get_related_node(node):
216219
count_ctx.node2after_output[related_node] += 1
217220

218-
if chain_style:
219-
input_nodes = [
220-
node
221-
for node in node_list
222-
if (count_ctx.node2before_input[node] > 0)
223-
if (count_ctx.node2body[node] > 0 or count_ctx.node2after_output[node] > 0)
224-
]
225-
input_nodes_set = set(input_nodes)
226-
output_nodes = [
227-
node
228-
for node in node_list
229-
if (count_ctx.node2before_input[node] > 0 or count_ctx.node2body[node] > 0)
230-
if (count_ctx.node2after_output[node] > 0)
231-
]
221+
input_nodes = [
222+
node
223+
for node in node_list
224+
if count_ctx.node2before_input[node] > 0
225+
if count_ctx.node2body[node] > 0
226+
]
227+
output_nodes = [
228+
node
229+
for node in node_list
230+
if not (count_ctx.node2before_input[node] > 0)
231+
if count_ctx.node2body[node] > 0
232+
if count_ctx.node2after_output[node] > 0
233+
]
234+
if not chain_style:
235+
identity_nodes = []
232236
else:
233-
input_nodes = [
237+
identity_nodes = [
234238
node
235239
for node in node_list
236240
if count_ctx.node2before_input[node] > 0
237-
if count_ctx.node2body[node] > 0
241+
if count_ctx.node2body[node] == 0
242+
if count_ctx.node2after_output[node] > 0
243+
][:1]
244+
input_nodes_set = set(input_nodes)
245+
input_nodes = [
246+
*input_nodes,
247+
*[node for node in identity_nodes if node not in input_nodes_set],
238248
]
249+
output_nodes_set = set(output_nodes)
239250
output_nodes = [
240-
node
241-
for node in node_list
242-
if not (count_ctx.node2before_input[node] > 0)
243-
if count_ctx.node2body[node] > 0
244-
if count_ctx.node2after_output[node] > 0
251+
*output_nodes,
252+
*[node for node in identity_nodes if node not in output_nodes_set],
245253
]
246254

247-
return input_nodes, output_nodes
255+
return input_nodes, output_nodes, identity_nodes

0 commit comments

Comments
 (0)