Skip to content

Commit 138e1c0

Browse files
committed
add DeviceRewriteSamplePass
1 parent 57d92b5 commit 138e1c0

File tree

9 files changed

+271
-6
lines changed

9 files changed

+271
-6
lines changed
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
#!/bin/bash
2+
3+
GRAPH_NET_ROOT=$(python3 -c "import graph_net; import os; print(os.path.dirname(os.path.dirname(graph_net.__file__)))")
4+
model_path_handler_config_json_str=$(cat <<EOF
5+
{
6+
"handler_path": "$GRAPH_NET_ROOT/graph_net/customize_your_sample_pass.py",
7+
"handler_class_name": "customize_your_class_name",
8+
"handler_config": {
9+
"resume": true,
10+
"model_path_prefix": "/customize_your_model_path_prefix",
11+
"output_dir": "/customize_your_output_file"
12+
}
13+
}
14+
EOF
15+
)
16+
17+
model_path_handler_model_path_list="customize_your_model_path_list"
18+
MODEL_PATH_HANDLER_CONFIG=$(echo $model_path_handler_config_json_str | base64 -w 0)
19+
20+
python3 -m graph_net.model_path_handler \
21+
--model-path-list $model_path_handler_model_path_list \
22+
--handler-config $MODEL_PATH_HANDLER_CONFIG \
23+
24+
unset model_path_handler_model_path_list
25+
unset MODEL_PATH_HANDLER_CONFIG
26+
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
import abc
2+
import shutil
3+
from graph_net.sample_pass.sample_pass_mixin import SamplePassMixin
4+
from pathlib import Path
5+
6+
7+
class OnlyModelFileRewriteSamplePassMixin(SamplePassMixin):
8+
def declare_config(
9+
self,
10+
model_path_prefix: str,
11+
output_dir: str,
12+
):
13+
pass
14+
15+
@abc.abstractmethod
16+
def handle_model_py_file(self, rel_model_path: str) -> str:
17+
"""
18+
return rewrited model.py file contents
19+
"""
20+
raise NotImplementedError()
21+
22+
def copy_sample_and_handle_model_py_file(self, rel_model_path: str):
23+
src_model_path = Path(self.config["model_path_prefix"]) / rel_model_path
24+
dst_model_path = Path(self.config["output_dir"]) / rel_model_path
25+
dst_model_path.mkdir(parents=True, exist_ok=True)
26+
shutil.copytree(src_model_path, dst_model_path, dirs_exist_ok=True)
27+
model_py_code = self.handle_model_py_file(rel_model_path)
28+
(dst_model_path / "model.py").write_text(model_py_code)
Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
import abc
2+
import sys
3+
from graph_net.sample_pass.sample_pass_mixin import SamplePassMixin
4+
from pathlib import Path
5+
import os
6+
7+
8+
class ResumableSamplePassMixin(SamplePassMixin):
9+
def __init__(self, *args, **kwargs):
10+
self.num_handled_models = 0
11+
12+
def declare_config(
13+
self,
14+
model_path_prefix: str,
15+
output_dir: str,
16+
resume: bool = False,
17+
limits_handled_models: int = None,
18+
):
19+
pass
20+
21+
def sample_handled(self, rel_model_path: str) -> bool:
22+
dst_model_path = Path(self.config["output_dir"]) / rel_model_path
23+
if not dst_model_path.exists():
24+
return False
25+
num_model_py_files = len(list(dst_model_path.rglob("model.py")))
26+
assert num_model_py_files <= 1
27+
return num_model_py_files == 1
28+
29+
@abc.abstractmethod
30+
def resume(self, rel_model_path: str):
31+
raise NotImplementedError()
32+
33+
def resumable_handle_sample(self, rel_model_path: str):
34+
assert os.path.realpath(self.config["model_path_prefix"]) != os.path.realpath(
35+
self.config["output_dir"]
36+
)
37+
if self.config["resume"] and self.sample_handled(rel_model_path):
38+
return
39+
self.resume(rel_model_path)
40+
self._inc_num_handled_models_or_exit()
41+
42+
def _inc_num_handled_models_or_exit(self):
43+
if self.config["limits_handled_models"] is None:
44+
return
45+
self.num_handled_models += 1
46+
if self.num_handled_models >= self.config["limits_handled_models"]:
47+
print("limits_handled_models expired.", flush=True)
48+
sys.exit(0)
Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
import abc
2+
import copy
3+
import inspect
4+
5+
6+
class SamplePass(abc.ABC):
7+
def __init__(self, config=None):
8+
if config is None:
9+
config = {}
10+
11+
self._check_config_declaration_valid()
12+
self.config = self._make_config_by_config_declare(config)
13+
14+
@abc.abstractmethod
15+
def declare_config(self):
16+
raise NotImplementedError()
17+
18+
@abc.abstractmethod
19+
def __call__(self, rel_model_path: str):
20+
raise NotImplementedError()
21+
22+
def _recursively_check_mixin_declare_config(self, base_class):
23+
from graph_net.sample_pass.sample_pass_mixin import SamplePassMixin
24+
25+
if issubclass(base_class, (SamplePass, SamplePassMixin)):
26+
check_is_base_signature(
27+
base_class=base_class,
28+
derived_class=type(self),
29+
method_name="declare_config",
30+
)
31+
for sub_class in base_class.__bases__:
32+
self._recursively_check_mixin_declare_config(sub_class)
33+
34+
def _check_config_declaration_parameters(self):
35+
sig = inspect.signature(self.declare_config)
36+
for name, param in sig.parameters.items():
37+
assert param.annotation in {
38+
int,
39+
bool,
40+
float,
41+
str,
42+
list,
43+
dict,
44+
}, f"{name=} {param.annotation}"
45+
assert param.kind in {
46+
inspect.Parameter.POSITIONAL_OR_KEYWORD,
47+
inspect.Parameter.VAR_KEYWORD,
48+
}, f"{name=} {param.kind=}"
49+
50+
def _check_config_declaration_valid(self):
51+
self._recursively_check_mixin_declare_config(type(self))
52+
self._check_config_declaration_parameters()
53+
54+
def _make_config_by_config_declare(self, config):
55+
sig = inspect.signature(self.declare_config)
56+
mut_config = copy.deepcopy(config)
57+
for name, param in sig.parameters.items():
58+
self._complete_default(name, param, mut_config)
59+
class_name = type(self).__name__
60+
assert name in mut_config, f"{name=} {class_name=}"
61+
62+
def get_extra_config_fields():
63+
return set(name for name, _ in mut_config.items()) - set(
64+
name for name, _ in sig.parameters.items()
65+
)
66+
67+
no_varadic_keyword = all(
68+
param.kind != inspect.Parameter.VAR_KEYWORD
69+
for _, param in sig.parameters.items()
70+
)
71+
if no_varadic_keyword:
72+
no_extra_config_fields = all(
73+
name in sig.parameters for name, _ in mut_config.items()
74+
)
75+
assert no_extra_config_fields, f"{get_extra_config_fields()=}"
76+
return mut_config
77+
78+
def _complete_default(self, name, param, mut_config):
79+
if param.default is inspect.Parameter.empty:
80+
return
81+
mut_config[name] = copy.deepcopy(param.default)
82+
83+
84+
def check_is_base_signature(base_class, derived_class, method_name):
85+
base = getattr(base_class, method_name)
86+
derived = getattr(derived_class, method_name)
87+
base_parameters = inspect.signature(base).parameters
88+
derived_parameters = inspect.signature(derived).parameters
89+
assert len(derived_parameters) >= len(base_parameters)
90+
for name, param in base_parameters.items():
91+
assert name in base_parameters, f"{name=}"
92+
assert param == base_parameters[name]
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
import abc
2+
3+
4+
class SamplePassMixin(abc.ABC):
5+
@abc.abstractmethod
6+
def declare_config(self):
7+
raise NotImplementedError()
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
#!/bin/bash
2+
3+
GRAPH_NET_ROOT=$(python3 -c "import graph_net; import os; print(os.path.dirname(os.path.dirname(graph_net.__file__)))")
4+
model_path_handler_config_json_str=$(cat <<EOF
5+
{
6+
"handler_path": "$GRAPH_NET_ROOT/graph_net/torch/sample_passes/device_rewrite_sample_pass.py",
7+
"handler_class_name": "DeviceRewriteSamplePass",
8+
"handler_config": {
9+
"device": "cuda",
10+
"resume": false,
11+
"model_path_prefix": "$GRAPH_NET_ROOT",
12+
"output_dir": "/tmp/device_rewrited"
13+
}
14+
}
15+
EOF
16+
)
17+
18+
model_path_handler_model_path_list="$GRAPH_NET_ROOT/graph_net/test/dev_model_list/validation_error_model_list.txt"
19+
MODEL_PATH_HANDLER_CONFIG=$(echo $model_path_handler_config_json_str | base64 -w 0)
20+
21+
python3 -m graph_net.model_path_handler \
22+
--model-path-list $model_path_handler_model_path_list \
23+
--handler-config $MODEL_PATH_HANDLER_CONFIG \
24+
25+
unset model_path_handler_model_path_list
26+
unset MODEL_PATH_HANDLER_CONFIG
27+

graph_net/torch/fx_graph_parse_util.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -130,8 +130,6 @@ def parse_sole_graph_module_without_varify(module, inputs):
130130
def my_backend(gm, sample_inputs):
131131
nonlocal traced_module
132132
nonlocal traced_sample_inputs
133-
assert traced_module is None
134-
assert traced_sample_inputs is None
135133
traced_module = gm
136134
traced_sample_inputs = sample_inputs
137135
return gm.forward

graph_net/torch/graph_decomposer.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import os
2+
import shutil
23
from pathlib import Path
34
import torch
45
import json
@@ -197,13 +198,15 @@ def _is_model_handled(self, rel_model_path, split_positions):
197198
num_subgraphs = len(split_positions) + 1
198199
decomposed_model_path = Path(self.config["output_dir"]) / rel_model_path
199200
num_decomposed = len(list(decomposed_model_path.rglob("model.py")))
200-
if num_decomposed > 0:
201-
assert (
202-
num_subgraphs <= num_decomposed
203-
), f"{num_subgraphs=} {num_decomposed=} {str(decomposed_model_path)=}"
201+
if num_decomposed > 0 and num_subgraphs != num_decomposed:
202+
shutil.rmtree(decomposed_model_path / "_decomposed")
203+
return False
204204
return num_subgraphs == num_decomposed
205205

206206
def __call__(self, rel_model_path):
207+
assert os.path.realpath(self.config["model_path_prefix"]) != os.path.realpath(
208+
self.config["output_dir"]
209+
)
207210
model_path = os.path.join(self.config["model_path_prefix"], rel_model_path)
208211
split_results = load_json(self.config["split_results_path"])
209212
split_positions = split_results[rel_model_path]["split_positions"]
Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
from graph_net.sample_pass.sample_pass import SamplePass
2+
from graph_net.sample_pass.resumable_sample_pass_mixin import ResumableSamplePassMixin
3+
from graph_net.sample_pass.only_model_file_rewrite_sample_pass_mixin import (
4+
OnlyModelFileRewriteSamplePassMixin,
5+
)
6+
from graph_net.torch import utils
7+
from pathlib import Path
8+
9+
10+
class DeviceRewriteSamplePass(
11+
SamplePass, ResumableSamplePassMixin, OnlyModelFileRewriteSamplePassMixin
12+
):
13+
def __init__(self, config):
14+
super().__init__(config)
15+
16+
def declare_config(
17+
self,
18+
model_path_prefix: str,
19+
output_dir: str,
20+
device: str,
21+
resume: bool = False,
22+
limits_handled_models: int = None,
23+
):
24+
pass
25+
26+
def __call__(self, rel_model_path: str):
27+
self.resumable_handle_sample(rel_model_path)
28+
29+
def resume(self, rel_model_path: str):
30+
return self.copy_sample_and_handle_model_py_file(rel_model_path)
31+
32+
def handle_model_py_file(self, rel_model_path: str) -> str:
33+
src_model_path = Path(self.config["model_path_prefix"]) / rel_model_path
34+
model_py_code = (src_model_path / "model.py").read_text()
35+
device = self.config["device"]
36+
return utils.modify_code_by_device(model_py_code, device)

0 commit comments

Comments
 (0)