Skip to content

Commit b06e8ec

Browse files
committed
bug fix for chain_style
1 parent 269309b commit b06e8ec

File tree

3 files changed

+34
-37
lines changed

3 files changed

+34
-37
lines changed

graph_net/test/chain_naive_graph_decomposer_test.sh

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,8 @@ MODEL_PATH_IN_SAMPLES=/timm/resnet18
77
read -r -d '' json_str <<'EOF'
88
{
99
"output_dir": "/tmp/naive_decompose_workspace",
10-
"split_positions": [2, 4],
11-
"group_head_and_tail": false,
10+
"split_positions": [8, 16, 32],
11+
"group_head_and_tail": true,
1212
"chain_style": true
1313
}
1414
EOF

graph_net/test/naive_graph_decomposer_test.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ MODEL_PATH_IN_SAMPLES=/timm/resnet18
55
read -r -d '' json_str <<'EOF'
66
{
77
"output_dir": "/tmp/naive_decompose_workspace",
8-
"split_positions": [8, 32],
8+
"split_positions": [0, 32],
99
"group_head_and_tail": true
1010
}
1111
EOF

graph_net/torch/decompose_util.py

Lines changed: 31 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -72,16 +72,11 @@ def get_end_node_idx(range_idx):
7272
return i + 1
7373
raise NotImplementedError("Dead code.")
7474

75-
def print_submodule_call(prompt, gm):
76-
submodule_call_stmts = [
77-
stmt for stmt in gm.code.split("\n") if "self.extracted_submodule" in stmt
78-
]
79-
print(f"{prompt} ", submodule_call_stmts)
80-
8175
for range_idx in range(len(range_idx2submodule_body_nodes)):
8276
(
8377
submodule_input_nodes,
8478
submodule_output_nodes,
79+
identity_nodes,
8580
) = _get_submodule_inputs_and_outputs(
8681
original_gm=original_gm,
8782
start_node_idx=get_start_node_idx(range_idx),
@@ -141,20 +136,17 @@ def get_output_nodes(range_idx):
141136
prev_node = new_output_node
142137

143138
# Replace all use of outputs
139+
identity_node_set = set(identity_nodes)
144140
for original_output in get_output_nodes(range_idx):
145-
original_output.replace_all_uses_with(node_map[original_output])
141+
if original_output not in identity_node_set:
142+
original_output.replace_all_uses_with(node_map[original_output])
146143

147144
# Erase old nodes
148145
for node in reversed(get_body_nodes(range_idx)):
149146
original_gm.graph.erase_node(node)
150-
# print_submodule_call("(fx) after Erase old nodes", original_gm)
151-
152-
# print_submodule_call("(fx) before recompile", original_gm)
153147

154148
original_gm.recompile()
155149

156-
# print_submodule_call("(fx) after recompile", original_gm)
157-
158150
return original_gm
159151

160152

@@ -215,33 +207,38 @@ def get_related_node(node):
215207
for related_node in get_related_node(node):
216208
count_ctx.node2after_output[related_node] += 1
217209

218-
if chain_style:
219-
input_nodes = [
220-
node
221-
for node in node_list
222-
if (count_ctx.node2before_input[node] > 0)
223-
if (count_ctx.node2body[node] > 0 or count_ctx.node2after_output[node] > 0)
224-
]
225-
input_nodes_set = set(input_nodes)
226-
output_nodes = [
227-
node
228-
for node in node_list
229-
if (count_ctx.node2before_input[node] > 0 or count_ctx.node2body[node] > 0)
230-
if (count_ctx.node2after_output[node] > 0)
231-
]
210+
input_nodes = [
211+
node
212+
for node in node_list
213+
if count_ctx.node2before_input[node] > 0
214+
if count_ctx.node2body[node] > 0
215+
]
216+
output_nodes = [
217+
node
218+
for node in node_list
219+
if not (count_ctx.node2before_input[node] > 0)
220+
if count_ctx.node2body[node] > 0
221+
if count_ctx.node2after_output[node] > 0
222+
]
223+
if not chain_style:
224+
identity_nodes = []
232225
else:
233-
input_nodes = [
226+
identity_nodes = [
234227
node
235228
for node in node_list
236229
if count_ctx.node2before_input[node] > 0
237-
if count_ctx.node2body[node] > 0
230+
if count_ctx.node2body[node] == 0
231+
if count_ctx.node2after_output[node] > 0
232+
][:1]
233+
input_nodes_set = set(input_nodes)
234+
input_nodes = [
235+
*input_nodes,
236+
*[node for node in identity_nodes if node not in input_nodes_set],
238237
]
238+
output_nodes_set = set(output_nodes)
239239
output_nodes = [
240-
node
241-
for node in node_list
242-
if not (count_ctx.node2before_input[node] > 0)
243-
if count_ctx.node2body[node] > 0
244-
if count_ctx.node2after_output[node] > 0
240+
*output_nodes,
241+
*[node for node in identity_nodes if node not in output_nodes_set],
245242
]
246243

247-
return input_nodes, output_nodes
244+
return input_nodes, output_nodes, identity_nodes

0 commit comments

Comments
 (0)