Skip to content

Commit d6bccea

Browse files
authored
1) fix bugs in decompose_util.py; 2) support tools/typical_sequence_decompose.sh (#445)
* init 'symbolic_dimension_reifier' field in graph_net.json * remove unused files * add DeviceRewriteSamplePass * fix output node order bug in torch/decompose_util.py * update graph_hash.txt in OnlyModelFileRewriteSamplePassMixin * 1) fix bugs in decompose_util.py; 2) support tools/typical_sequence_decompose.sh
1 parent b33b2d7 commit d6bccea

File tree

4 files changed

+119
-35
lines changed

4 files changed

+119
-35
lines changed

graph_net/tensor_meta.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import importlib.util as imp
22
import inspect
33
from dataclasses import dataclass
4+
import math
45

56

67
@dataclass
@@ -83,6 +84,8 @@ def serialize_to_py_str(self) -> str:
8384
def _get_limited_precision_float_str(self, value):
8485
if not isinstance(value, float):
8586
return value
87+
if math.isnan(value) or math.isinf(value):
88+
return f'float("{value}")'
8689
return f"{value:.3f}"
8790

8891
def _format_data(self, data):
Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
#!/bin/bash
# Typical-sequence decomposition pipeline. Stages, in order:
#   1. run OpNamesExtractor over every model in the model list;
#   2. run typical_sequence_split_points to write split_results.json;
#   3. run RangeDecomposerExtractor using those split results;
#   4. collect the directories of every generated model.py into a sample list;
#   5. run GraphVariableRenamer over each collected subgraph sample.
#
# Relies on GNU tools (base64 -w, realpath --relative-to, xargs -r).
# `set -e` is intentionally not enabled: the listing pipeline in stage 4 may
# legitimately return non-zero (e.g. a model directory that was never created).
set -x

GRAPH_NET_ROOT=$(python3 -c "import graph_net; import os; print(os.path.dirname(os.path.dirname(graph_net.__file__)))")
if [ -z "$GRAPH_NET_ROOT" ]; then
    # Without the root every path below would silently resolve under "/".
    echo "failed to locate the graph_net package; is it importable by python3?" >&2
    exit 1
fi
DECOMPOSE_WORKSPACE=/tmp/typical_sequence_decompose_workspace

mkdir -p "$DECOMPOSE_WORKSPACE"

model_list="$GRAPH_NET_ROOT/graph_net/test/dev_model_list/validation_error_model_list.txt"

# Stage 1: the handler config is passed as single-line base64-encoded JSON.
python3 -m graph_net.model_path_handler \
    --model-path-list "$model_list" \
    --handler-config="$(base64 -w 0 <<EOF
{
    "handler_path": "$GRAPH_NET_ROOT/graph_net/torch/typical_sequence_split_points.py",
    "handler_class_name": "OpNamesExtractor",
    "handler_config": {
        "resume": true,
        "model_path_prefix": "$GRAPH_NET_ROOT",
        "output_dir": "$DECOMPOSE_WORKSPACE"
    }
}
EOF
)"

# Stage 2: compute split points from the files written under the workspace.
python3 -m graph_net.torch.typical_sequence_split_points \
    --enable-resume \
    --model-list "$model_list" \
    --op-names-path-prefix "$DECOMPOSE_WORKSPACE" \
    --device "cuda" \
    --window-size 10 \
    --fold-policy default \
    --fold-times 10 \
    --output-json "$DECOMPOSE_WORKSPACE/split_results.json"

# Stage 3: decompose each model according to split_results.json.
python3 -m graph_net.model_path_handler \
    --model-path-list "$model_list" \
    --handler-config="$(base64 -w 0 <<EOF
{
    "handler_path": "$GRAPH_NET_ROOT/graph_net/torch/graph_decomposer.py",
    "handler_class_name": "RangeDecomposerExtractor",
    "handler_config": {
        "resume": true,
        "model_path_prefix": "$GRAPH_NET_ROOT",
        "output_dir": "$DECOMPOSE_WORKSPACE",
        "split_results_path": "$DECOMPOSE_WORKSPACE/split_results.json",
        "group_head_and_tail": true,
        "chain_style": false
    }
}
EOF
)"

# Stage 4: list every directory that contains a model.py, relative to the
# workspace. `xargs -r` skips dirname/realpath when nothing was found, instead
# of invoking them with no operands.
subgraph_sample_list="$DECOMPOSE_WORKSPACE/subgraph_sample_list.txt"
grep -v '# ' "$model_list" \
    | xargs -I {} find "$DECOMPOSE_WORKSPACE/{}" -name "model.py" \
    | xargs -r dirname \
    | xargs -r realpath --relative-to="$DECOMPOSE_WORKSPACE" \
    | tee "$subgraph_sample_list"

GRAPH_VAR_RENAME_WORKSPACE="$DECOMPOSE_WORKSPACE/graph_var_renamed"

# Stage 5: rename graph variables in each subgraph sample.
python3 -m graph_net.model_path_handler \
    --model-path-list "$subgraph_sample_list" \
    --handler-config="$(base64 -w 0 <<EOF
{
    "handler_path": "$GRAPH_NET_ROOT/graph_net/torch/graph_variable_renamer.py",
    "handler_class_name": "GraphVariableRenamer",
    "handler_config": {
        "model_path_prefix": "$DECOMPOSE_WORKSPACE",
        "data_input_predicator_filepath": "$GRAPH_NET_ROOT/graph_net/torch/constraint_util.py",
        "data_input_predicator_class_name": "NaiveDataInputPredicator",
        "model_runnable_predicator_filepath": "$GRAPH_NET_ROOT/graph_net/torch/constraint_util.py",
        "model_runnable_predicator_class_name": "ModelRunnablePredicator",
        "output_dir": "$GRAPH_VAR_RENAME_WORKSPACE"
    }
}
EOF
)"
81+

graph_net/torch/decompose_util.py

Lines changed: 34 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,8 @@ def print_submodule_call(prompt, gm):
8585
def sort_key(node):
8686
return new_node2original_node[node].name
8787

88-
for range_idx in range(len(range_idx2submodule_body_nodes)):
88+
num_subgraphs = len(range_idx2submodule_body_nodes)
89+
for range_idx in range(num_subgraphs):
8990
(
9091
submodule_input_nodes,
9192
submodule_output_nodes,
@@ -96,6 +97,7 @@ def sort_key(node):
9697
end_node_idx=get_end_node_idx(range_idx),
9798
chain_style=chain_style,
9899
)
100+
identity_node_set = set(identity_nodes)
99101

100102
def get_input_nodes(range_idx):
101103
return sorted(submodule_input_nodes, key=sort_key)
@@ -153,13 +155,13 @@ def get_output_nodes(range_idx):
153155
prev_node = new_output_node
154156

155157
# Replace all use of outputs
156-
identity_node_set = set(identity_nodes)
157158
for original_output in get_output_nodes(range_idx):
158-
if original_output not in identity_node_set:
159-
original_output.replace_all_uses_with(node_map[original_output])
160-
new_node2original_node[
161-
node_map[original_output]
162-
] = new_node2original_node[original_output]
159+
if original_output in identity_node_set:
160+
continue
161+
original_output.replace_all_uses_with(node_map[original_output])
162+
new_node2original_node[node_map[original_output]] = new_node2original_node[
163+
original_output
164+
]
163165

164166
# Erase old nodes
165167
for node in reversed(get_body_nodes(range_idx)):
@@ -215,12 +217,18 @@ def _get_submodule_inputs_and_outputs(
215217
return minimal_input_nodes, minimal_output_nodes, []
216218
else:
217219
node_list = list(gm.graph.nodes)
218-
input_nodes, _ = _get_minimal_submodule_inputs_and_outputs(
219-
gm=gm, start_node_idx=start_node_idx, end_node_idx=len(node_list)
220-
)
221-
output_nodes, _ = _get_minimal_submodule_inputs_and_outputs(
222-
gm=gm, start_node_idx=end_node_idx, end_node_idx=len(node_list)
223-
)
220+
if _is_node_idx_out_of_range(gm, start_node_idx):
221+
input_nodes = list(_get_return_nodes(gm))
222+
else:
223+
input_nodes, _ = _get_minimal_submodule_inputs_and_outputs(
224+
gm=gm, start_node_idx=start_node_idx, end_node_idx=len(node_list)
225+
)
226+
if _is_node_idx_out_of_range(gm, end_node_idx):
227+
output_nodes = list(_get_return_nodes(gm))
228+
else:
229+
output_nodes, _ = _get_minimal_submodule_inputs_and_outputs(
230+
gm=gm, start_node_idx=end_node_idx, end_node_idx=len(node_list)
231+
)
224232
identity_nodes_set = set(input_nodes) & set(output_nodes)
225233
identity_nodes = [node for node in input_nodes if node in identity_nodes_set]
226234
return input_nodes, output_nodes, identity_nodes
@@ -275,25 +283,19 @@ def get_args_node_and_self_node(node):
275283
for related_node in get_args_node_and_self_node(node):
276284
count_ctx.node2after_output[related_node] += 1
277285

278-
if _is_node_idx_out_of_range(gm, start_node_idx):
279-
input_nodes = list(_get_return_nodes(gm))
280-
else:
281-
input_nodes = [
282-
node
283-
for node in node_list
284-
if count_ctx.node2before_input[node] > 0
285-
if count_ctx.node2body[node] > 0
286-
]
287-
if _is_node_idx_out_of_range(gm, end_node_idx):
288-
output_nodes = list(_get_return_nodes(gm))
289-
else:
290-
output_nodes = [
291-
node
292-
for node in node_list
293-
if not (count_ctx.node2before_input[node] > 0)
294-
if count_ctx.node2body[node] > 0
295-
if count_ctx.node2after_output[node] > 0
296-
]
286+
input_nodes = [
287+
node
288+
for node in node_list
289+
if count_ctx.node2before_input[node] > 0
290+
if count_ctx.node2body[node] > 0
291+
]
292+
output_nodes = [
293+
node
294+
for node in node_list
295+
if not (count_ctx.node2before_input[node] > 0)
296+
if count_ctx.node2body[node] > 0
297+
if count_ctx.node2after_output[node] > 0
298+
]
297299
return input_nodes, output_nodes
298300

299301

graph_net/torch/graph_variable_renamer.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -79,10 +79,8 @@ def __call__(self, rel_model_path):
7979
module, inputs = get_torch_module_and_inputs(src_model_path)
8080
gm = parse_sole_graph_module(module, inputs)
8181
gm = self.rename_graph_variables(gm, inputs, src_model_path)
82-
model_name = os.path.basename(rel_model_path.rstrip(os.sep))
83-
new_rel_path = f"{model_name}_renamed"
8482
dst_model_path = os.path.realpath(
85-
os.path.join(self.config["output_dir"], new_rel_path)
83+
os.path.join(self.config["output_dir"], rel_model_path)
8684
)
8785
Path(dst_model_path).parent.mkdir(parents=True, exist_ok=True)
8886
shutil.copytree(src_model_path, dst_model_path, dirs_exist_ok=True)

0 commit comments

Comments
 (0)