Skip to content

Commit 1841d0f

Browse files
committed
fix output node order bug in torch/decompose_util.py
1 parent 646f4c8 commit 1841d0f

File tree

3 files changed

+81
-42
lines changed

3 files changed

+81
-42
lines changed
Lines changed: 4 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,10 @@
11
#!/bin/bash
22

33
GRAPH_NET_ROOT=$(python3 -c "import graph_net; import os; print(os.path.dirname(os.path.dirname(graph_net.__file__)))")
4-
model_path_handler_config_json_str=$(cat <<EOF
4+
5+
python3 -m graph_net.model_path_handler \
6+
--model-path-list "customize_your_model_path_list" \
7+
--handler-config $(base64 -w 0 <<EOF
58
{
69
"handler_path": "$GRAPH_NET_ROOT/graph_net/customize_your_sample_pass.py",
710
"handler_class_name": "customize_your_class_name",
@@ -13,14 +16,3 @@ model_path_handler_config_json_str=$(cat <<EOF
1316
}
1417
EOF
1518
)
16-
17-
model_path_handler_model_path_list="customize_your_model_path_list"
18-
MODEL_PATH_HANDLER_CONFIG=$(echo $model_path_handler_config_json_str | base64 -w 0)
19-
20-
python3 -m graph_net.model_path_handler \
21-
--model-path-list $model_path_handler_model_path_list \
22-
--handler-config $MODEL_PATH_HANDLER_CONFIG \
23-
24-
unset model_path_handler_model_path_list
25-
unset MODEL_PATH_HANDLER_CONFIG
26-

graph_net/test/typical_sequence_decomposer_test.sh

Lines changed: 36 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
#!/bin/bash
2+
set -x
23

34
GRAPH_NET_ROOT=$(python3 -c "import graph_net; import os; print(os.path.dirname(os.path.dirname(graph_net.__file__)))")
45
DECOMPOSE_PATH=/tmp/decompose_workspace
@@ -9,7 +10,9 @@ mkdir -p "$DECOMPOSE_PATH"
910
# model_list="$GRAPH_NET_ROOT/graph_net/config/small100_torch_samples_list.txt"
1011
model_list="$GRAPH_NET_ROOT/graph_net/test/dev_model_list/validation_error_model_list.txt"
1112

12-
op_names_extractor_config_json_str=$(cat <<EOF
13+
python3 -m graph_net.model_path_handler \
14+
--model-path-list $model_list \
15+
--handler-config=$(base64 -w 0 <<EOF
1316
{
1417
"handler_path": "$GRAPH_NET_ROOT/graph_net/torch/typical_sequence_split_points.py",
1518
"handler_class_name": "OpNamesExtractor",
@@ -21,11 +24,6 @@ op_names_extractor_config_json_str=$(cat <<EOF
2124
}
2225
EOF
2326
)
24-
OP_NAMES_EXTRACTOR_CONFIG=$(echo $op_names_extractor_config_json_str | base64 -w 0)
25-
26-
python3 -m graph_net.model_path_handler \
27-
--model-path-list $model_list \
28-
--handler-config=$OP_NAMES_EXTRACTOR_CONFIG \
2927

3028
python3 -m graph_net.torch.typical_sequence_split_points \
3129
--enable-resume \
@@ -37,7 +35,9 @@ python3 -m graph_net.torch.typical_sequence_split_points \
3735
--fold-times 10 \
3836
--output-json "$DECOMPOSE_PATH/split_results.json"
3937

40-
decompose_config_json_str=$(cat <<EOF
38+
python3 -m graph_net.model_path_handler \
39+
--model-path-list $model_list \
40+
--handler-config=$(base64 -w 0 <<EOF
4141
{
4242
"handler_path": "$GRAPH_NET_ROOT/graph_net/torch/graph_decomposer.py",
4343
"handler_class_name": "RangeDecomposerExtractor",
@@ -52,27 +52,46 @@ decompose_config_json_str=$(cat <<EOF
5252
}
5353
EOF
5454
)
55-
DECOMPOSE_CONFIG=$(echo $decompose_config_json_str | base64 -w 0)
5655

57-
python3 -m graph_net.model_path_handler \
58-
--model-path-list $model_list \
59-
--handler-config=$DECOMPOSE_CONFIG \
56+
device_rewrite_sample_list=$DECOMPOSE_PATH/device_rewrite_sample_list.txt
57+
cat $model_list \
58+
| grep -v '# ' \
59+
| xargs -I {} find $DECOMPOSE_PATH/{} -name "model.py" \
60+
| xargs dirname \
61+
| xargs realpath --relative-to=$DECOMPOSE_PATH \
62+
| tee $device_rewrite_sample_list
6063

61-
test_compiler_config_json_str=$(cat <<EOF
64+
DEVICE_REWRITE_WORKSPACE=$DECOMPOSE_PATH/device_rewrite
65+
66+
python3 -m graph_net.model_path_handler \
67+
--model-path-list $device_rewrite_sample_list \
68+
--handler-config $(base64 -w 0 <<EOF
6269
{
63-
"model_path_prefix": "$GRAPH_NET_ROOT",
64-
"decomposed_root": "$DECOMPOSE_PATH"
70+
"handler_path": "$GRAPH_NET_ROOT/graph_net/torch/sample_passes/device_rewrite_sample_pass.py",
71+
"handler_class_name": "DeviceRewriteSamplePass",
72+
"handler_config": {
73+
"device": "cuda",
74+
"resume": false,
75+
"model_path_prefix": "$DECOMPOSE_PATH",
76+
"output_dir": "$DEVICE_REWRITE_WORKSPACE"
77+
}
6578
}
6679
EOF
6780
)
68-
TEST_COMPILER_CONFIG=$(echo $test_compiler_config_json_str | base64 -w 0)
81+
6982

7083
python3 -m graph_net.torch.test_compiler \
84+
--model-path-prefix $GRAPH_NET_ROOT \
7185
--allow-list $model_list \
7286
--compiler range_decomposer_validator \
7387
--device cuda \
74-
--config $TEST_COMPILER_CONFIG \
75-
--model-path-prefix $GRAPH_NET_ROOT \
88+
--config $(base64 -w 0 <<EOF
89+
{
90+
"model_path_prefix": "$GRAPH_NET_ROOT",
91+
"decomposed_root": "$DEVICE_REWRITE_WORKSPACE"
92+
}
93+
EOF
94+
) \
7695
2>&1 | tee "$DECOMPOSE_PATH/validation.log"
7796

7897
python3 -m graph_net.plot_ESt \

graph_net/torch/decompose_util.py

Lines changed: 41 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -237,6 +237,7 @@ def _get_minimal_submodule_inputs_and_outputs(
237237
defaultdict(int),
238238
)
239239
node_list = list(gm.graph.nodes)
240+
assert end_node_idx <= len(node_list)
240241

241242
def get_args_node(arg):
242243
if isinstance(arg, torch.fx.Node):
@@ -274,17 +275,44 @@ def get_args_node_and_self_node(node):
274275
for related_node in get_args_node_and_self_node(node):
275276
count_ctx.node2after_output[related_node] += 1
276277

277-
input_nodes = [
278-
node
279-
for node in node_list
280-
if count_ctx.node2before_input[node] > 0
281-
if count_ctx.node2body[node] > 0
282-
]
283-
output_nodes = [
284-
node
285-
for node in node_list
286-
if not (count_ctx.node2before_input[node] > 0)
287-
if count_ctx.node2body[node] > 0
288-
if count_ctx.node2after_output[node] > 0
289-
]
278+
if _is_node_idx_out_of_range(gm, start_node_idx):
279+
input_nodes = list(_get_return_nodes(gm))
280+
else:
281+
input_nodes = [
282+
node
283+
for node in node_list
284+
if count_ctx.node2before_input[node] > 0
285+
if count_ctx.node2body[node] > 0
286+
]
287+
if _is_node_idx_out_of_range(gm, end_node_idx):
288+
output_nodes = list(_get_return_nodes(gm))
289+
else:
290+
output_nodes = [
291+
node
292+
for node in node_list
293+
if not (count_ctx.node2before_input[node] > 0)
294+
if count_ctx.node2body[node] > 0
295+
if count_ctx.node2after_output[node] > 0
296+
]
290297
return input_nodes, output_nodes
298+
299+
300+
def _get_return_nodes(gm):
301+
for node in gm.graph.nodes:
302+
if node.op != "output":
303+
continue
304+
for arg in node.args:
305+
if isinstance(arg, (tuple, list)):
306+
yield from arg
307+
else:
308+
yield arg
309+
310+
311+
def _is_node_idx_out_of_range(gm, node_idx: int):
312+
node_list = list(gm.graph.nodes)
313+
num_nodes = len(node_list)
314+
if node_idx < 0:
315+
return True
316+
if node_idx >= num_nodes:
317+
return True
318+
return node_list[node_idx].op in {"output", "placeholder"}

0 commit comments

Comments
 (0)