Skip to content

Commit 40b7f5b

Browse files
committed
refactor AstGraphVariableRenamer with samplepass
1 parent 0d31e5d commit 40b7f5b

File tree

2 files changed

+138
-158
lines changed

2 files changed

+138
-158
lines changed
Lines changed: 137 additions & 157 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,151 @@
1+
from graph_net.sample_pass.sample_pass import SamplePass
2+
from graph_net.sample_pass.resumable_sample_pass_mixin import ResumableSamplePassMixin
3+
from pathlib import Path
14
import os
25
import shutil
36
import tempfile
47
import ast
58
import inspect
69
import torch
7-
from pathlib import Path
810
from graph_net.imp_util import load_module
911
from graph_net.tensor_meta import TensorMeta
1012
from graph_net.hash_util import get_sha256_hash
1113

1214

15+
class AstGraphVariableRenamer(SamplePass, ResumableSamplePassMixin):
16+
def __init__(self, config):
17+
super().__init__(config)
18+
self.data_input_predicator = self._make_data_input_predicator(self.config)
19+
self.model_runnable_predicator = self._make_model_runnable_predicator(
20+
self.config
21+
)
22+
23+
def _make_data_input_predicator(self, config):
24+
module = load_module(config["data_input_predicator_filepath"])
25+
cls = getattr(module, config["data_input_predicator_class_name"])
26+
return cls(config["data_input_predicator_config"])
27+
28+
def _make_model_runnable_predicator(self, config):
29+
module = load_module(config["model_runnable_predicator_filepath"])
30+
cls = getattr(module, config["model_runnable_predicator_class_name"])
31+
return cls(config["model_runnable_predicator_config"])
32+
33+
def declare_config(
34+
self,
35+
model_path_prefix: str,
36+
output_dir: str,
37+
device: str,
38+
resume: bool = False,
39+
limits_handled_models: int = None,
40+
data_input_predicator_filepath: str = None,
41+
data_input_predicator_class_name: str = None,
42+
data_input_predicator_config: dict = None,
43+
model_runnable_predicator_filepath: str = None,
44+
model_runnable_predicator_class_name: str = None,
45+
model_runnable_predicator_config: dict = None,
46+
):
47+
pass
48+
49+
def __call__(self, rel_model_path: str):
50+
self.resumable_handle_sample(rel_model_path)
51+
52+
def sample_handled(self, rel_model_path: str) -> bool:
53+
return self.naive_sample_handled(rel_model_path, search_file_name="model.py")
54+
55+
def resume(self, rel_model_path: str):
56+
torch.cuda.empty_cache()
57+
dst_model_path = os.path.realpath(
58+
os.path.join(self.config["output_dir"], rel_model_path)
59+
)
60+
src_model_path = os.path.join(self.config["model_path_prefix"], rel_model_path)
61+
graph_module_class = load_class_from_file(
62+
os.path.join(src_model_path, "model.py"), class_name="GraphModule"
63+
)
64+
input_arg_names, weight_arg_names = self._get_input_and_weight_arg_names(
65+
graph_module_class, src_model_path
66+
)
67+
rename_map = self._create_rename_map(input_arg_names, weight_arg_names)
68+
with tempfile.TemporaryDirectory(prefix="graph_variable_renamer_") as temp_dir:
69+
temp_model_path = os.path.join(temp_dir, os.path.basename(dst_model_path))
70+
shutil.copytree(src_model_path, temp_model_path, dirs_exist_ok=True)
71+
self._update_model_py_file(
72+
temp_model_path, rename_map, input_arg_names, weight_arg_names
73+
)
74+
self._update_meta_file(temp_model_path, "weight_meta.py", rename_map)
75+
self._update_meta_file(temp_model_path, "input_meta.py", rename_map)
76+
self._try_run(temp_model_path)
77+
shutil.copytree(temp_model_path, dst_model_path, dirs_exist_ok=True)
78+
79+
def _get_input_and_weight_arg_names(self, graph_module, model_path):
80+
input_arg_names = []
81+
weight_arg_names = []
82+
sig = inspect.signature(graph_module.forward)
83+
for name, param in sig.parameters.items():
84+
if name == "self":
85+
continue
86+
is_not_data_input = not self.data_input_predicator(model_path, name)
87+
if is_not_data_input:
88+
weight_arg_names.append(name)
89+
else:
90+
input_arg_names.append(name)
91+
return input_arg_names, weight_arg_names
92+
93+
def _create_rename_map(self, input_arg_names, weight_arg_names):
94+
rename_map = {}
95+
for idx, name in enumerate(input_arg_names):
96+
rename_map[name] = f"in_{idx}"
97+
for idx, name in enumerate(weight_arg_names):
98+
rename_map[name] = f"w_{idx}"
99+
return rename_map
100+
101+
def _update_model_py_file(
102+
self, model_path, rename_map, input_arg_names, weight_arg_names
103+
):
104+
model_file = Path(model_path) / "model.py"
105+
source = model_file.read_text(encoding="utf-8")
106+
tree = ast.parse(source)
107+
node = self._get_graph_module_ast(tree)
108+
graph_renamer = AstGraphRenamer(rename_map, input_arg_names, weight_arg_names)
109+
graph_renamer.visit(node)
110+
py_code = ast.unparse(tree)
111+
model_file.write_text(py_code, encoding="utf-8")
112+
file_hash = get_sha256_hash(py_code)
113+
(Path(model_path) / "graph_hash.txt").write_text(file_hash)
114+
115+
def _get_graph_module_ast(self, tree):
116+
for node in tree.body:
117+
if isinstance(node, ast.ClassDef) and node.name == "GraphModule":
118+
return node
119+
return None
120+
121+
def _update_meta_file(self, model_path, meta_filename, rename_map):
122+
meta_file = Path(model_path) / meta_filename
123+
tensor_metas = TensorMeta.unserialize_from_py_file(str(meta_file))
124+
for meta in tensor_metas:
125+
assert (
126+
meta.name in rename_map
127+
), f"[Warning] {meta.name} in {meta_filename} not found in rename_map."
128+
if meta.original_name is None:
129+
meta.original_name = meta.name
130+
meta.name = rename_map[meta.name]
131+
132+
py_code = "\n\n".join([meta.serialize_to_py_str() for meta in tensor_metas])
133+
meta_file.write_text(py_code)
134+
135+
def _try_run(self, model_path):
136+
(f"[AstGraphVariableRenamer] Try to run {model_path}")
137+
assert self.model_runnable_predicator(
138+
model_path
139+
), f"{model_path} is not a runnable model"
140+
141+
142+
def load_class_from_file(file_path: str, class_name: str):
143+
print(f"Load {class_name} from {file_path}")
144+
module = load_module(file_path, "unnamed_graph_module")
145+
model_class = getattr(module, class_name, None)
146+
return model_class
147+
148+
13149
class AstGraphRenamer(ast.NodeTransformer):
14150
def __init__(self, rename_map, input_arg_names, weight_arg_names):
15151
self.rename_map = rename_map
@@ -135,159 +271,3 @@ def visit_Name(self, node):
135271
if node.id in self.rename_map:
136272
return ast.Name(id=self.rename_map[node.id], ctx=node.ctx)
137273
return node
138-
139-
140-
def load_class_from_file(file_path: str, class_name: str):
141-
print(f"Load {class_name} from {file_path}")
142-
module = load_module(file_path, "unnamed_graph_module")
143-
model_class = getattr(module, class_name, None)
144-
return model_class
145-
146-
147-
class AstGraphVariableRenamer:
148-
"""
149-
Used by graph_net.model_path_handler
150-
"""
151-
152-
def __init__(self, config: dict = None):
153-
if config is None:
154-
config = {}
155-
self.config = self._make_config(**config)
156-
self.data_input_predicator = self._make_data_input_predicator(self.config)
157-
self.model_runnable_predicator = self._make_model_runnable_predicator(
158-
self.config
159-
)
160-
161-
def _make_data_input_predicator(self, config):
162-
module = load_module(config["data_input_predicator_filepath"])
163-
cls = getattr(module, config["data_input_predicator_class_name"])
164-
return cls(config["data_input_predicator_config"])
165-
166-
def _make_model_runnable_predicator(self, config):
167-
module = load_module(config["model_runnable_predicator_filepath"])
168-
cls = getattr(module, config["model_runnable_predicator_class_name"])
169-
return cls(config["model_runnable_predicator_config"])
170-
171-
def _make_config(
172-
self,
173-
resume: bool = False,
174-
data_input_predicator_filepath=None,
175-
model_runnable_predicator_filepath=None,
176-
output_dir="./tmp/graph_variable_renamer_dir",
177-
filter_path=None,
178-
filter_config=None,
179-
data_input_predicator_class_name="DataInputPredicator",
180-
model_runnable_predicator_class_name="ModelRunner",
181-
data_input_predicator_config=None,
182-
model_runnable_predicator_config=None,
183-
model_path_prefix="",
184-
**kwargs,
185-
):
186-
if data_input_predicator_config is None:
187-
data_input_predicator_config = {}
188-
if model_runnable_predicator_config is None:
189-
model_runnable_predicator_config = {}
190-
return {
191-
"resume": resume,
192-
"output_dir": output_dir,
193-
"filter_path": filter_path,
194-
"filter_config": filter_config if filter_config is not None else {},
195-
"data_input_predicator_filepath": data_input_predicator_filepath,
196-
"data_input_predicator_class_name": data_input_predicator_class_name,
197-
"data_input_predicator_config": data_input_predicator_config,
198-
"model_runnable_predicator_filepath": model_runnable_predicator_filepath,
199-
"model_runnable_predicator_class_name": model_runnable_predicator_class_name,
200-
"model_runnable_predicator_config": model_runnable_predicator_config,
201-
"model_path_prefix": model_path_prefix,
202-
}
203-
204-
def __call__(self, rel_model_path):
205-
torch.cuda.empty_cache()
206-
207-
dst_model_path = os.path.realpath(
208-
os.path.join(self.config["output_dir"], rel_model_path)
209-
)
210-
if self.config["resume"] and os.path.exists(
211-
os.path.join(dst_model_path, "model.py")
212-
):
213-
return
214-
215-
src_model_path = os.path.join(self.config["model_path_prefix"], rel_model_path)
216-
Path(dst_model_path).parent.mkdir(parents=True, exist_ok=True)
217-
graph_module_class = load_class_from_file(
218-
os.path.join(src_model_path, "model.py"), class_name="GraphModule"
219-
)
220-
input_arg_names, weight_arg_names = self._get_input_and_weight_arg_names(
221-
graph_module_class, src_model_path
222-
)
223-
224-
rename_map = {}
225-
for idx, name in enumerate(input_arg_names):
226-
rename_map[name] = f"in_{idx}"
227-
for idx, name in enumerate(weight_arg_names):
228-
rename_map[name] = f"w_{idx}"
229-
230-
with tempfile.TemporaryDirectory(prefix="graph_variable_renamer_") as temp_dir:
231-
temp_model_path = os.path.join(temp_dir, os.path.basename(dst_model_path))
232-
shutil.copytree(src_model_path, temp_model_path, dirs_exist_ok=True)
233-
self._update_model_py_file(
234-
temp_model_path, rename_map, input_arg_names, weight_arg_names
235-
)
236-
self._update_meta_file(temp_model_path, "weight_meta.py", rename_map)
237-
self._update_meta_file(temp_model_path, "input_meta.py", rename_map)
238-
self._try_run(temp_model_path)
239-
shutil.copytree(temp_model_path, dst_model_path, dirs_exist_ok=True)
240-
241-
def _get_input_and_weight_arg_names(self, graph_module, model_path):
242-
input_arg_names = []
243-
weight_arg_names = []
244-
sig = inspect.signature(graph_module.forward)
245-
for name, param in sig.parameters.items():
246-
if name == "self":
247-
continue
248-
is_not_data_input = not self.data_input_predicator(model_path, name)
249-
if is_not_data_input:
250-
weight_arg_names.append(name)
251-
else:
252-
input_arg_names.append(name)
253-
return input_arg_names, weight_arg_names
254-
255-
def _update_model_py_file(
256-
self, model_path, rename_map, input_arg_names, weight_arg_names
257-
):
258-
model_file = Path(model_path) / "model.py"
259-
source = model_file.read_text(encoding="utf-8")
260-
tree = ast.parse(source)
261-
node = self._get_graph_module_ast(tree)
262-
graph_renamer = AstGraphRenamer(rename_map, input_arg_names, weight_arg_names)
263-
graph_renamer.visit(node)
264-
py_code = ast.unparse(tree)
265-
model_file.write_text(py_code, encoding="utf-8")
266-
file_hash = get_sha256_hash(py_code)
267-
(Path(model_path) / "graph_hash.txt").write_text(file_hash)
268-
269-
def _get_graph_module_ast(self, tree):
270-
for node in tree.body:
271-
if isinstance(node, ast.ClassDef) and node.name == "GraphModule":
272-
return node
273-
return None
274-
275-
def _update_meta_file(self, model_path, meta_filename, rename_map):
276-
meta_file = Path(model_path) / meta_filename
277-
tensor_metas = TensorMeta.unserialize_from_py_file(str(meta_file))
278-
for meta in tensor_metas:
279-
assert (
280-
meta.name in rename_map
281-
), f"[Warning] {meta.name} in {meta_filename} not found in rename_map."
282-
if meta.original_name is None:
283-
meta.original_name = meta.name
284-
meta.name = rename_map[meta.name]
285-
286-
py_code = "\n\n".join([meta.serialize_to_py_str() for meta in tensor_metas])
287-
meta_file.write_text(py_code)
288-
289-
def _try_run(self, model_path):
290-
(f"[AstGraphVariableRenamer] Try to run {model_path}")
291-
assert self.model_runnable_predicator(
292-
model_path
293-
), f"{model_path} is not a runnable model"

graph_net/test/ast_graph_variable_rename_test.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ python3 -m graph_net.model_path_handler \
1010
--model-path-list $model_list \
1111
--handler-config=$(base64 -w 0 <<EOF
1212
{
13-
"handler_path": "$GRAPH_NET_ROOT/graph_net/torch/ast_graph_variable_renamer.py",
13+
"handler_path": "$GRAPH_NET_ROOT/graph_net/sample_pass/ast_graph_variable_renamer.py",
1414
"handler_class_name": "AstGraphVariableRenamer",
1515
"handler_config": {
1616
"device": "cuda",

0 commit comments

Comments
 (0)