Skip to content

Commit 6c46a71

Browse files
committed
Merge branch 'develop' into support_meta_restore
2 parents 65c21f3 + b203ea1 commit 6c46a71

15 files changed

+764
-129
lines changed

graph_net/optional.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
from typing import TypeVar, Generic, Union
2+
3+
T = TypeVar("T")
4+
5+
6+
class Optional(Generic[T]):
7+
def __init__(self, value: Union[T, None]):
8+
self._value = value
9+
10+
def reset(self, that):
11+
assert isinstance(that, Optional)
12+
self._value = that._value
13+
14+
def is_some(self) -> bool:
15+
return self._value is not None
16+
17+
def unwrap(self) -> T:
18+
"""Returns the value or raises an error if None."""
19+
if self._value is None:
20+
raise ValueError("Tried to unwrap a None value!")
21+
return self._value
22+
23+
def unwrap_or(self, default: T) -> T:
24+
"""Returns the value or a default if None."""
25+
return self._value if self._value is not None else default

graph_net/sample_pass/agent_unittest_generator.py

Lines changed: 21 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -224,7 +224,7 @@ def __init__(
224224
self.output_dir = Path(output_dir)
225225
self.device = self._choose_device(device)
226226
self.generate_main = generate_main
227-
self.try_run = try_run and generate_main
227+
self.try_run = try_run
228228
self.data_input_predicator = self._make_data_input_predicator(
229229
data_input_predicator_filepath, data_input_predicator_class_name
230230
)
@@ -244,20 +244,26 @@ def generate(self):
244244
input_tensor_metas,
245245
weight_tensor_metas,
246246
) = self._get_input_and_weight_tensor_metas(input_arg_names, weight_arg_names)
247-
graph_module_desc = GraphModuleDescriptor(
248-
device=self.device,
249-
generate_main=self.generate_main,
250-
model_name=model_name,
251-
input_arg_names=input_arg_names,
252-
input_tensor_metas=input_tensor_metas,
253-
weight_arg_names=weight_arg_names,
254-
weight_tensor_metas=weight_tensor_metas,
255-
forward_body=self._get_forward_body(
256-
graph_module, input_arg_names, weight_arg_names
257-
),
258-
)
259-
unittest = self._render_template(graph_module_desc)
260-
if self._try_to_run_unittest(unittest):
247+
248+
def _generate_unittest(generate_main):
249+
graph_module_desc = GraphModuleDescriptor(
250+
device=self.device,
251+
generate_main=generate_main,
252+
model_name=model_name,
253+
input_arg_names=input_arg_names,
254+
input_tensor_metas=input_tensor_metas,
255+
weight_arg_names=weight_arg_names,
256+
weight_tensor_metas=weight_tensor_metas,
257+
forward_body=self._get_forward_body(
258+
graph_module, input_arg_names, weight_arg_names
259+
),
260+
)
261+
return self._render_template(graph_module_desc)
262+
263+
# Generate unittest with main for try-run.
264+
unittest_for_try_run = _generate_unittest(generate_main=self.try_run)
265+
if self._try_to_run_unittest(unittest_for_try_run):
266+
unittest = _generate_unittest(generate_main=self.generate_main)
261267
self._write_to_file(unittest, self.output_dir)
262268

263269
def _choose_device(self, device) -> str:
Lines changed: 132 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,132 @@
1+
from graph_net.sample_pass.sample_pass import SamplePass
2+
from graph_net.sample_pass.resumable_sample_pass_mixin import ResumableSamplePassMixin
3+
from pathlib import Path
4+
import json
5+
from itertools import groupby
6+
7+
8+
class FusibleSubgraphRangesGenerator(SamplePass, ResumableSamplePassMixin):
9+
def __init__(self, config):
10+
super().__init__(config)
11+
12+
def declare_config(
13+
self,
14+
model_path_prefix: str,
15+
output_dir: str,
16+
input_json_file_name: str,
17+
resume: bool = False,
18+
limits_handled_models: int = None,
19+
output_json_file_name: str = "fusible_subgraph_ranges.json",
20+
):
21+
pass
22+
23+
def __call__(self, rel_model_path: str):
24+
self.resumable_handle_sample(rel_model_path)
25+
26+
def sample_handled(self, rel_model_path: str) -> bool:
27+
file_name = self.config["output_json_file_name"]
28+
return self.naive_sample_handled(rel_model_path, search_file_name=file_name)
29+
30+
def resume(self, rel_model_path: str):
31+
analyzer = self._make_analyzer(rel_model_path)
32+
output_obj = analyzer.analyze()
33+
self._save_output(rel_model_path, output_obj)
34+
35+
def _save_output(self, rel_model_path, output_obj):
36+
output_json = json.dumps(output_obj, indent=4)
37+
output_dir_path = Path(self.config["output_dir"]) / rel_model_path
38+
output_dir_path.mkdir(parents=True, exist_ok=True)
39+
output_file_path = output_dir_path / self.config["output_json_file_name"]
40+
output_file_path.write_text(output_json)
41+
42+
def _make_analyzer(self, rel_model_path: str):
43+
model_path = (
44+
Path(self.config["model_path_prefix"])
45+
/ rel_model_path
46+
/ self.config["input_json_file_name"]
47+
)
48+
json_ctx = self._make_json_ctx(model_path)
49+
return FusibleSubgraphRangesAnalyzer(
50+
num_subgraph_kernels_list=self._get_num_subgraph_kernels_list(json_ctx),
51+
num_subgraph_ops_list=self._get_num_subgraph_ops_list(json_ctx),
52+
start_offset_in_original_graph=self._get_start_offset_in_original_graph(
53+
json_ctx
54+
),
55+
)
56+
57+
def _get_start_offset_in_original_graph(self, json_ctx):
58+
return json_ctx["start_offset_in_original_graph"]
59+
60+
def _get_num_subgraph_kernels_list(self, json_ctx):
61+
return json_ctx["num_subgraph_kernels"]
62+
63+
def _get_num_subgraph_ops_list(self, json_ctx):
64+
return json_ctx["num_subgraph_ops"]
65+
66+
def _make_json_ctx(self, model_path: Path):
67+
obj = json.loads(model_path.read_text())
68+
assert len(obj["num_subgraph_kernels"]) == len(obj["num_subgraph_ops"])
69+
return obj
70+
71+
72+
class FusibleSubgraphRangesAnalyzer:
73+
def __init__(
74+
self,
75+
num_subgraph_kernels_list: list[int],
76+
num_subgraph_ops_list: list[int],
77+
start_offset_in_original_graph: int,
78+
):
79+
assert len(num_subgraph_kernels_list) == len(num_subgraph_ops_list)
80+
self.num_subgraph_kernels_list = num_subgraph_kernels_list
81+
self.num_subgraph_ops_list = num_subgraph_ops_list
82+
self.start_offset_in_original_graph = start_offset_in_original_graph
83+
84+
def analyze(self):
85+
num_kernels_and_num_ops_list: list[
86+
(int, list[int])
87+
] = self._make_num_kernels_and_num_ops_list()
88+
num_kernels_and_num_ops_list = sorted(
89+
num_kernels_and_num_ops_list, key=lambda pair: pair[0]
90+
)
91+
num_ops_lists = [
92+
sorted(num_ops_list)
93+
for _, num_ops_list in num_kernels_and_num_ops_list
94+
if len(set(num_ops_list)) > 1
95+
]
96+
fusible_subgraph_ranges = [
97+
(start, end)
98+
for num_ops_list in num_ops_lists
99+
for start in [num_ops_list[0] - 1]
100+
for end in [num_ops_list[-1]]
101+
]
102+
# sorted by `start`
103+
fusible_subgraph_ranges = sorted(
104+
fusible_subgraph_ranges, key=lambda pair: pair[0]
105+
)
106+
# remove shadowed
107+
fusible_subgraph_ranges = [
108+
fusible_subgraph_ranges[i]
109+
for i in range(len(fusible_subgraph_ranges))
110+
if i == 0
111+
or (fusible_subgraph_ranges[i][0] >= fusible_subgraph_ranges[i - 1][1])
112+
]
113+
return fusible_subgraph_ranges
114+
115+
def _make_num_kernels_and_num_ops_list(self):
116+
num_kernels_and_num_ops = zip(
117+
self.num_subgraph_kernels_list,
118+
self.num_subgraph_ops_list,
119+
)
120+
121+
def get_num_kernels(pair):
122+
return pair[0]
123+
124+
num_kernels_and_num_ops = sorted(num_kernels_and_num_ops, key=get_num_kernels)
125+
grouped_num_kernels_and_num_ops = groupby(
126+
num_kernels_and_num_ops, key=get_num_kernels
127+
)
128+
num_kernels_and_num_ops_list = [
129+
(num_kernels, [num_ops for _, num_ops in group])
130+
for num_kernels, group in grouped_num_kernels_and_num_ops
131+
]
132+
return num_kernels_and_num_ops_list

graph_net/sample_pass/resumable_sample_pass_mixin.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,11 +18,15 @@ def declare_config(
1818
):
1919
pass
2020

21+
@abc.abstractmethod
2122
def sample_handled(self, rel_model_path: str) -> bool:
23+
raise NotImplementedError()
24+
25+
def naive_sample_handled(self, rel_model_path: str, search_file_name: str) -> bool:
2226
dst_model_path = Path(self.config["output_dir"]) / rel_model_path
2327
if not dst_model_path.exists():
2428
return False
25-
num_model_py_files = len(list(dst_model_path.rglob("model.py")))
29+
num_model_py_files = len(list(dst_model_path.rglob(search_file_name)))
2630
assert num_model_py_files <= 1
2731
return num_model_py_files == 1
2832

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
#!/bin/bash
2+
3+
GRAPH_NET_ROOT=$(python3 -c "import graph_net; import os; print(os.path.dirname(os.path.dirname(graph_net.__file__)))")
4+
5+
python3 -m graph_net.model_path_handler \
6+
--model-path-list "$GRAPH_NET_ROOT/graph_net/test/dev_model_list/cumsum_num_kernels_sample_list.txt" \
7+
--handler-config $(base64 -w 0 <<EOF
8+
{
9+
"handler_path": "$GRAPH_NET_ROOT/graph_net/torch/sample_passes/cumsum_num_kernels_generator.py",
10+
"handler_class_name": "CumSumNumKernelsGenerator",
11+
"handler_config": {
12+
"output_json_file_name": "cumsum_num_kernels.json",
13+
"resume": false,
14+
"model_path_prefix": "$GRAPH_NET_ROOT",
15+
"output_dir": "/tmp/cumsum_num_kernels_workspace"
16+
}
17+
}
18+
EOF
19+
)
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
samples/timm/resnet18
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
#!/bin/bash
2+
3+
GRAPH_NET_ROOT=$(python3 -c "import graph_net; import os; print(os.path.dirname(os.path.dirname(graph_net.__file__)))")
4+
5+
python3 -m graph_net.model_path_handler \
6+
--model-path-list "$GRAPH_NET_ROOT/graph_net/test/dev_model_list/cumsum_num_kernels_sample_list.txt" \
7+
--handler-config $(base64 -w 0 <<EOF
8+
{
9+
"handler_path": "$GRAPH_NET_ROOT/graph_net/sample_pass/fusible_subgraph_ranges_generator.py",
10+
"handler_class_name": "FusibleSubgraphRangesGenerator",
11+
"handler_config": {
12+
"resume": false,
13+
"model_path_prefix": "$GRAPH_NET_ROOT/graph_net/test/workspace_cumsum_num_kernels",
14+
"input_json_file_name": "cumsum_num_kernels.json",
15+
"output_json_file_name": "fusible_subgraph_ranges.json",
16+
"output_dir": "/tmp/workspace_fusible_subgraph_ranges"
17+
}
18+
}
19+
EOF
20+
)

0 commit comments

Comments
 (0)