Skip to content

Commit 14abe6a

Browse files
authored
Subgraph Shape propagation (#512)
* init 'symbolic_dimension_reifier' field in graph_net.json * remove unused files * backup code * add sample_pass/group_ranges_from_subgraph_sources.py * add sample pass torch/sample_pass/subgraph_input_producer_indexes_generator.py * add sample pass torch/sample_pass/shape_propagator.py * save subgraph relative model paths
1 parent 5068a8b commit 14abe6a

File tree

19 files changed

+569
-5
lines changed

19 files changed

+569
-5
lines changed

graph_net/apply_sample_pass.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,23 +39,32 @@ def _get_handler(args):
3939
def main(args):
4040
handler = _get_handler(args)
4141
if args.model_path is not None:
42+
assert not hasattr(handler, "BEGIN")
43+
assert not hasattr(handler, "END")
4244
handle_model_path(handler, args.model_path)
4345
elif args.use_subprocess:
46+
assert not hasattr(handler, "BEGIN")
47+
assert not hasattr(handler, "END")
4448
handle_model_path_list_in_subprocess(args)
4549
else:
4650
handle_model_path_list_in_current_process(handler, args)
4751

4852

4953
def handle_model_path_list_in_current_process(handler, args):
50-
for model_path in _get_model_path_list(args):
54+
rel_model_paths = list(_get_model_path_list(args))
55+
if hasattr(handler, "BEGIN"):
56+
handler.BEGIN(rel_model_paths)
57+
for rel_model_path in rel_model_paths:
5158
try:
52-
handle_model_path(handler, model_path)
59+
handle_model_path(handler, rel_model_path)
5360
except KeyboardInterrupt:
5461
print("KeyboardInterrupt")
5562
return
5663
except Exception:
5764
print("------------[apply_sample_pass failed]------------", flush=True)
5865
traceback.print_exc()
66+
if hasattr(handler, "END"):
67+
handler.END(rel_model_paths)
5968

6069

6170
def handle_model_path_list_in_subprocess(args):
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
#!/bin/bash
2+
3+
GRAPH_NET_ROOT=$(python3 -c "import graph_net; import os; print(os.path.dirname(os.path.dirname(graph_net.__file__)))")
4+
5+
python3 -m graph_net.apply_sample_pass \
6+
--model-path-list "customize_your_model_path_list" \
7+
--sample-pass-file-path "$GRAPH_NET_ROOT/graph_net/customize_your_sample_pass.py" \
8+
--sample-pass-class-name customize_your_class_name \
9+
--sample-pass-config $(base64 -w 0 <<EOF
10+
{
11+
"resume": true,
12+
"model_path_prefix": "/customize_your_model_path_prefix",
13+
"output_dir": "/customize_your_output_file"
14+
}
15+
EOF
16+
)
Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,102 @@
1+
from graph_net.sample_pass.sample_pass import SamplePass
2+
from pathlib import Path
3+
import json
4+
5+
6+
class GroupRangesFromSubgraphSources(SamplePass):
7+
def __init__(self, config=None):
8+
super().__init__(config)
9+
self.original_graph_rel_model_path2ranges: dict[str, list[(int, int)]] = {}
10+
self.original_graph_rel_model_path2subgraph_rel_model_paths: dict[
11+
str, list[str]
12+
] = {}
13+
14+
def declare_config(
15+
self,
16+
subgraph_model_path_prefix: str,
17+
output_dir: str,
18+
subgraph_sources_json_file_name: str = "subgraph_sources.json",
19+
output_json_file_name: str = "grouped_ranges_from_subgraph_sources.json",
20+
output_json_key: str = "grouped_ranges_from_subgraph_sources",
21+
output_json_subgraph_rel_model_path_key: str = "subgraph_rel_model_paths",
22+
):
23+
pass
24+
25+
def __call__(self, subgraph_rel_model_path: str):
26+
model_path = (
27+
Path(self.config["subgraph_model_path_prefix"])
28+
/ subgraph_rel_model_path
29+
/ self.config["subgraph_sources_json_file_name"]
30+
)
31+
subgraph_sources = json.load(open(model_path))
32+
for original_graph_rel_model_path, subgraph_ranges in subgraph_sources.items():
33+
self._collect_original_graph_rel_model_path2ranges(
34+
original_graph_rel_model_path, subgraph_ranges
35+
)
36+
self._collect_original_graph_rel_model_path2subgraph_rel_model_path(
37+
original_graph_rel_model_path,
38+
[subgraph_rel_model_path] * len(subgraph_ranges),
39+
)
40+
41+
def _collect_original_graph_rel_model_path2subgraph_rel_model_path(
42+
self,
43+
original_graph_rel_model_path: str,
44+
subgraph_rel_model_paths: list[str],
45+
):
46+
old = self.original_graph_rel_model_path2subgraph_rel_model_paths.get(
47+
original_graph_rel_model_path, []
48+
)
49+
self.original_graph_rel_model_path2subgraph_rel_model_paths[
50+
original_graph_rel_model_path
51+
] = [
52+
*old,
53+
*subgraph_rel_model_paths,
54+
]
55+
56+
def _collect_original_graph_rel_model_path2ranges(
57+
self, original_graph_rel_model_path, subgraph_ranges
58+
):
59+
old_ranges = self.original_graph_rel_model_path2ranges.get(
60+
original_graph_rel_model_path, []
61+
)
62+
self.original_graph_rel_model_path2ranges[original_graph_rel_model_path] = [
63+
*old_ranges,
64+
*subgraph_ranges,
65+
]
66+
67+
def END(self, rel_model_paths: list[str]):
68+
for (
69+
original_graph_rel_model_path,
70+
subgraph_ranges,
71+
) in self.original_graph_rel_model_path2ranges.items():
72+
subgraph_rel_model_paths = (
73+
self.original_graph_rel_model_path2subgraph_rel_model_paths[
74+
original_graph_rel_model_path
75+
]
76+
)
77+
self._save_json(
78+
original_graph_rel_model_path, subgraph_ranges, subgraph_rel_model_paths
79+
)
80+
81+
def _save_json(
82+
self, original_graph_rel_model_path, subgraph_ranges, subgraph_rel_model_paths
83+
):
84+
model_dir = Path(self.config["output_dir"]) / original_graph_rel_model_path
85+
model_dir.mkdir(parents=True, exist_ok=True)
86+
ranges_json = self._get_ranges_json(subgraph_ranges)
87+
paths_json = self._get_paths_json(subgraph_rel_model_paths)
88+
json_obj = {**ranges_json, **paths_json}
89+
json_str = json.dumps(json_obj, indent=4)
90+
(model_dir / self.config["output_json_file_name"]).write_text(json_str)
91+
92+
def _get_paths_json(self, subgraph_rel_model_paths: list[str]):
93+
json_obj = {
94+
self.config[
95+
"output_json_subgraph_rel_model_path_key"
96+
]: subgraph_rel_model_paths
97+
}
98+
return json_obj
99+
100+
def _get_ranges_json(self, subgraph_ranges: list[(int, int)]):
101+
json_obj = {self.config["output_json_key"]: subgraph_ranges}
102+
return json_obj

graph_net/sample_pass/sample_pass.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ def _make_config_by_config_declare(self, config):
6161
class_name = type(self).__name__
6262
assert name in mut_config, f"{name=} {class_name=}"
6363

64-
def get_extra_config_fields():
64+
def get_undefined_config_fields():
6565
return set(name for name, _ in mut_config.items()) - set(
6666
name for name, _ in sig.parameters.items()
6767
)
@@ -71,10 +71,10 @@ def get_extra_config_fields():
7171
for _, param in sig.parameters.items()
7272
)
7373
if no_varadic_keyword:
74-
no_extra_config_fields = all(
74+
no_undefined_config_fields = all(
7575
name in sig.parameters for name, _ in mut_config.items()
7676
)
77-
assert no_extra_config_fields, f"{get_extra_config_fields()=}"
77+
assert no_undefined_config_fields, f"{get_undefined_config_fields()=}"
7878
return mut_config
7979

8080
def _complete_default(self, name, param, mut_config):
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
#!/bin/bash
2+
3+
GRAPH_NET_ROOT=$(python3 -c "import graph_net; import os; print(os.path.dirname(os.path.dirname(graph_net.__file__)))")
4+
5+
python3 -m graph_net.apply_sample_pass \
6+
--model-path-list "$GRAPH_NET_ROOT/graph_net/test/workspace_group_ranges_from_subgraph_sources/sample_list.txt" \
7+
--sample-pass-file-path "$GRAPH_NET_ROOT/graph_net/sample_pass/group_ranges_from_subgraph_sources.py" \
8+
--sample-pass-class-name GroupRangesFromSubgraphSources \
9+
--sample-pass-config $(base64 -w 0 <<EOF
10+
{
11+
"subgraph_model_path_prefix": "$GRAPH_NET_ROOT/graph_net/test/workspace_group_ranges_from_subgraph_sources",
12+
"output_dir": "/tmp/workspace_group_ranges_from_subgraph_sources"
13+
}
14+
EOF
15+
)
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
#!/bin/bash
2+
3+
GRAPH_NET_ROOT=$(python3 -c "import graph_net; import os; print(os.path.dirname(os.path.dirname(graph_net.__file__)))")
4+
5+
python3 -m graph_net.apply_sample_pass \
6+
--model-path-list "$GRAPH_NET_ROOT/graph_net/test/workspace_subgraph_input_producer_indexes_generator/sample_list.txt" \
7+
--sample-pass-file-path "$GRAPH_NET_ROOT/graph_net/torch/sample_pass/shape_propagator.py" \
8+
--sample-pass-class-name ShapePropagator \
9+
--sample-pass-config $(base64 -w 0 <<EOF
10+
{
11+
"resume": false,
12+
"model_path_prefix": "$GRAPH_NET_ROOT",
13+
"output_dir": "/tmp/workspace_shape_propagator"
14+
}
15+
EOF
16+
)
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
#!/bin/bash
2+
3+
GRAPH_NET_ROOT=$(python3 -c "import graph_net; import os; print(os.path.dirname(os.path.dirname(graph_net.__file__)))")
4+
5+
python3 -m graph_net.apply_sample_pass \
6+
--model-path-list "$GRAPH_NET_ROOT/graph_net/test/workspace_subgraph_input_producer_indexes_generator/sample_list.txt" \
7+
--sample-pass-file-path "$GRAPH_NET_ROOT/graph_net/torch/sample_pass/subgraph_input_producer_indexes_generator.py" \
8+
--sample-pass-class-name SubgraphInputProducerIndexesGenerator \
9+
--sample-pass-config $(base64 -w 0 <<EOF
10+
{
11+
"resume": false,
12+
"model_path_prefix": "$GRAPH_NET_ROOT",
13+
"subgraph_ranges_json_root": "$GRAPH_NET_ROOT/graph_net/test/workspace_subgraph_input_producer_indexes_generator",
14+
"subgraph_ranges_json_file_name": "grouped_ranges_from_subgraph_sources.json",
15+
"subgraph_ranges_json_key": "grouped_ranges_from_subgraph_sources",
16+
"output_dir": "/tmp/workspace_subgraph_input_producer_indexes_generator"
17+
}
18+
EOF
19+
)
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
samples/timm/mobilenetv4_conv_aa_large.e230_r384_in12k/_decomposed/mobilenetv4_conv_aa_large.e230_r384_in12k_start33_end38_6
2+
samples/timm/mobilenetv4_conv_aa_large.e230_r384_in12k/_decomposed/mobilenetv4_conv_aa_large.e230_r384_in12k_start55_end60_10
3+
samples/timm/mobilenetv4_conv_aa_large.e230_r384_in12k/_decomposed/mobilenetv4_conv_aa_large.e230_r384_in12k_start60_end64_11
4+
samples/timm/convnextv2_base.fcmae_ft_in1k/_decomposed/convnextv2_base.fcmae_ft_in1k_start287_end297_32
5+
samples/timm/convnextv2_base.fcmae_ft_in1k/_decomposed/convnextv2_base.fcmae_ft_in1k_start401_end411_44
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
{
2+
"samples/timm/convnextv2_base.fcmae_ft_in1k": [
3+
[
4+
287,
5+
297
6+
]
7+
]
8+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
{
2+
"samples/timm/convnextv2_base.fcmae_ft_in1k": [
3+
[
4+
401,
5+
411
6+
]
7+
]
8+
}

0 commit comments

Comments
 (0)