|
| 1 | +from graph_net.sample_pass.sample_pass import SamplePass |
| 2 | +from graph_net.sample_pass.resumable_sample_pass_mixin import ResumableSamplePassMixin |
| 3 | +from pathlib import Path |
1 | 4 | import os |
2 | 5 | import shutil |
3 | 6 | import tempfile |
4 | 7 | import ast |
5 | 8 | import inspect |
6 | 9 | import torch |
7 | | -from pathlib import Path |
8 | 10 | from graph_net.imp_util import load_module |
9 | 11 | from graph_net.tensor_meta import TensorMeta |
10 | 12 | from graph_net.hash_util import get_sha256_hash |
11 | 13 |
|
12 | 14 |
|
class AstGraphVariableRenamer(SamplePass, ResumableSamplePassMixin):
    """Sample pass that canonicalizes GraphModule argument names via AST rewriting.

    For each model sample, the ``forward()`` parameters of the ``GraphModule``
    class in ``model.py`` are classified as data inputs or weights (by a
    user-configured predicator) and renamed to ``in_<i>`` / ``w_<i>``.  The
    rename is applied to ``model.py``, ``weight_meta.py`` and ``input_meta.py``
    inside a temporary directory, the rewritten model is validated by actually
    running it, and only then is the result copied to the output directory —
    so a failed rename never clobbers a good destination.
    """

    def __init__(self, config):
        super().__init__(config)
        # Both predicators are user-supplied classes loaded from configured
        # files: one decides whether a forward() arg is a data input (vs a
        # weight), the other checks that a rewritten model still runs.
        self.data_input_predicator = self._make_data_input_predicator(self.config)
        self.model_runnable_predicator = self._make_model_runnable_predicator(
            self.config
        )

    def _make_data_input_predicator(self, config):
        """Instantiate the configured data-input predicator class."""
        module = load_module(config["data_input_predicator_filepath"])
        cls = getattr(module, config["data_input_predicator_class_name"])
        return cls(config["data_input_predicator_config"])

    def _make_model_runnable_predicator(self, config):
        """Instantiate the configured model-runnable predicator class."""
        module = load_module(config["model_runnable_predicator_filepath"])
        cls = getattr(module, config["model_runnable_predicator_class_name"])
        return cls(config["model_runnable_predicator_config"])

    def declare_config(
        self,
        model_path_prefix: str,
        output_dir: str,
        device: str,
        resume: bool = False,
        limits_handled_models: int = None,
        data_input_predicator_filepath: str = None,
        data_input_predicator_class_name: str = None,
        data_input_predicator_config: dict = None,
        model_runnable_predicator_filepath: str = None,
        model_runnable_predicator_class_name: str = None,
        model_runnable_predicator_config: dict = None,
    ):
        # Declarative config schema only; SamplePass consumes the signature.
        pass

    def __call__(self, rel_model_path: str):
        """Handle one sample, skipping it if already processed (resume support)."""
        self.resumable_handle_sample(rel_model_path)

    def sample_handled(self, rel_model_path: str) -> bool:
        """A sample counts as handled once its output model.py exists."""
        return self.naive_sample_handled(rel_model_path, search_file_name="model.py")

    def resume(self, rel_model_path: str):
        """Rename one model's graph variables and write the validated result.

        Steps: load GraphModule from the source model.py, classify its
        forward() args, build the rename map, rewrite model/meta files in a
        temp dir, run the rewritten model, then copy it to the output dir.
        """
        # Free cached GPU memory before loading/running another model.
        torch.cuda.empty_cache()
        dst_model_path = os.path.realpath(
            os.path.join(self.config["output_dir"], rel_model_path)
        )
        src_model_path = os.path.join(self.config["model_path_prefix"], rel_model_path)
        graph_module_class = load_class_from_file(
            os.path.join(src_model_path, "model.py"), class_name="GraphModule"
        )
        input_arg_names, weight_arg_names = self._get_input_and_weight_arg_names(
            graph_module_class, src_model_path
        )
        rename_map = self._create_rename_map(input_arg_names, weight_arg_names)
        with tempfile.TemporaryDirectory(prefix="graph_variable_renamer_") as temp_dir:
            temp_model_path = os.path.join(temp_dir, os.path.basename(dst_model_path))
            shutil.copytree(src_model_path, temp_model_path, dirs_exist_ok=True)
            self._update_model_py_file(
                temp_model_path, rename_map, input_arg_names, weight_arg_names
            )
            self._update_meta_file(temp_model_path, "weight_meta.py", rename_map)
            self._update_meta_file(temp_model_path, "input_meta.py", rename_map)
            # Validate before publishing: _try_run asserts runnability, so a
            # broken rewrite aborts here and never reaches dst_model_path.
            self._try_run(temp_model_path)
            shutil.copytree(temp_model_path, dst_model_path, dirs_exist_ok=True)

    def _get_input_and_weight_arg_names(self, graph_module, model_path):
        """Split GraphModule.forward() parameters into (inputs, weights).

        Classification is delegated to the configured data-input predicator;
        ``self`` is skipped.
        """
        input_arg_names = []
        weight_arg_names = []
        sig = inspect.signature(graph_module.forward)
        for name, param in sig.parameters.items():
            if name == "self":
                continue
            is_not_data_input = not self.data_input_predicator(model_path, name)
            if is_not_data_input:
                weight_arg_names.append(name)
            else:
                input_arg_names.append(name)
        return input_arg_names, weight_arg_names

    def _create_rename_map(self, input_arg_names, weight_arg_names):
        """Map original arg names to canonical ``in_<i>`` / ``w_<i>`` names."""
        rename_map = {}
        for idx, name in enumerate(input_arg_names):
            rename_map[name] = f"in_{idx}"
        for idx, name in enumerate(weight_arg_names):
            rename_map[name] = f"w_{idx}"
        return rename_map

    def _update_model_py_file(
        self, model_path, rename_map, input_arg_names, weight_arg_names
    ):
        """Rewrite model.py with renamed variables and refresh graph_hash.txt."""
        model_file = Path(model_path) / "model.py"
        source = model_file.read_text(encoding="utf-8")
        tree = ast.parse(source)
        node = self._get_graph_module_ast(tree)
        graph_renamer = AstGraphRenamer(rename_map, input_arg_names, weight_arg_names)
        graph_renamer.visit(node)
        py_code = ast.unparse(tree)
        model_file.write_text(py_code, encoding="utf-8")
        # Keep the recorded hash in sync with the rewritten source.
        file_hash = get_sha256_hash(py_code)
        # Use explicit UTF-8 for consistency with the model.py write above
        # (previously relied on the platform-default encoding).
        (Path(model_path) / "graph_hash.txt").write_text(file_hash, encoding="utf-8")

    def _get_graph_module_ast(self, tree):
        """Return the top-level ``GraphModule`` ClassDef node, or None."""
        for node in tree.body:
            if isinstance(node, ast.ClassDef) and node.name == "GraphModule":
                return node
        return None

    def _update_meta_file(self, model_path, meta_filename, rename_map):
        """Apply the rename map to a tensor-meta file, preserving original names."""
        meta_file = Path(model_path) / meta_filename
        tensor_metas = TensorMeta.unserialize_from_py_file(str(meta_file))
        for meta in tensor_metas:
            # Fixed message: this assert is fatal, so it is an error, not a
            # warning (the old text was prefixed "[Warning]").
            assert (
                meta.name in rename_map
            ), f"{meta.name} in {meta_filename} not found in rename_map."
            # Record the pre-rename name once; don't overwrite it if a
            # previous pass already set it.
            if meta.original_name is None:
                meta.original_name = meta.name
            meta.name = rename_map[meta.name]

        py_code = "\n\n".join([meta.serialize_to_py_str() for meta in tensor_metas])
        # Explicit UTF-8 for consistency with model.py handling.
        meta_file.write_text(py_code, encoding="utf-8")

    def _try_run(self, model_path):
        """Assert that the rewritten model at *model_path* is runnable."""
        # Bug fix: the original had a bare f-string expression here (the
        # print() call was dropped), so the log line was never emitted.
        print(f"[AstGraphVariableRenamer] Try to run {model_path}")
        assert self.model_runnable_predicator(
            model_path
        ), f"{model_path} is not a runnable model"
| 140 | + |
| 141 | + |
def load_class_from_file(file_path: str, class_name: str):
    """Return the class named *class_name* defined in the module at *file_path*.

    The file is loaded as a throwaway module; if it defines no attribute
    with that name, ``None`` is returned rather than raising.
    """
    print(f"Load {class_name} from {file_path}")
    loaded = load_module(file_path, "unnamed_graph_module")
    return getattr(loaded, class_name, None)
| 147 | + |
| 148 | + |
13 | 149 | class AstGraphRenamer(ast.NodeTransformer): |
14 | 150 | def __init__(self, rename_map, input_arg_names, weight_arg_names): |
15 | 151 | self.rename_map = rename_map |
@@ -135,159 +271,3 @@ def visit_Name(self, node): |
135 | 271 | if node.id in self.rename_map: |
136 | 272 | return ast.Name(id=self.rename_map[node.id], ctx=node.ctx) |
137 | 273 | return node |
138 | | - |
139 | | - |
140 | | -def load_class_from_file(file_path: str, class_name: str): |
141 | | - print(f"Load {class_name} from {file_path}") |
142 | | - module = load_module(file_path, "unnamed_graph_module") |
143 | | - model_class = getattr(module, class_name, None) |
144 | | - return model_class |
145 | | - |
146 | | - |
147 | | -class AstGraphVariableRenamer: |
148 | | - """ |
149 | | - Used by graph_net.model_path_handler |
150 | | - """ |
151 | | - |
152 | | - def __init__(self, config: dict = None): |
153 | | - if config is None: |
154 | | - config = {} |
155 | | - self.config = self._make_config(**config) |
156 | | - self.data_input_predicator = self._make_data_input_predicator(self.config) |
157 | | - self.model_runnable_predicator = self._make_model_runnable_predicator( |
158 | | - self.config |
159 | | - ) |
160 | | - |
161 | | - def _make_data_input_predicator(self, config): |
162 | | - module = load_module(config["data_input_predicator_filepath"]) |
163 | | - cls = getattr(module, config["data_input_predicator_class_name"]) |
164 | | - return cls(config["data_input_predicator_config"]) |
165 | | - |
166 | | - def _make_model_runnable_predicator(self, config): |
167 | | - module = load_module(config["model_runnable_predicator_filepath"]) |
168 | | - cls = getattr(module, config["model_runnable_predicator_class_name"]) |
169 | | - return cls(config["model_runnable_predicator_config"]) |
170 | | - |
171 | | - def _make_config( |
172 | | - self, |
173 | | - resume: bool = False, |
174 | | - data_input_predicator_filepath=None, |
175 | | - model_runnable_predicator_filepath=None, |
176 | | - output_dir="./tmp/graph_variable_renamer_dir", |
177 | | - filter_path=None, |
178 | | - filter_config=None, |
179 | | - data_input_predicator_class_name="DataInputPredicator", |
180 | | - model_runnable_predicator_class_name="ModelRunner", |
181 | | - data_input_predicator_config=None, |
182 | | - model_runnable_predicator_config=None, |
183 | | - model_path_prefix="", |
184 | | - **kwargs, |
185 | | - ): |
186 | | - if data_input_predicator_config is None: |
187 | | - data_input_predicator_config = {} |
188 | | - if model_runnable_predicator_config is None: |
189 | | - model_runnable_predicator_config = {} |
190 | | - return { |
191 | | - "resume": resume, |
192 | | - "output_dir": output_dir, |
193 | | - "filter_path": filter_path, |
194 | | - "filter_config": filter_config if filter_config is not None else {}, |
195 | | - "data_input_predicator_filepath": data_input_predicator_filepath, |
196 | | - "data_input_predicator_class_name": data_input_predicator_class_name, |
197 | | - "data_input_predicator_config": data_input_predicator_config, |
198 | | - "model_runnable_predicator_filepath": model_runnable_predicator_filepath, |
199 | | - "model_runnable_predicator_class_name": model_runnable_predicator_class_name, |
200 | | - "model_runnable_predicator_config": model_runnable_predicator_config, |
201 | | - "model_path_prefix": model_path_prefix, |
202 | | - } |
203 | | - |
204 | | - def __call__(self, rel_model_path): |
205 | | - torch.cuda.empty_cache() |
206 | | - |
207 | | - dst_model_path = os.path.realpath( |
208 | | - os.path.join(self.config["output_dir"], rel_model_path) |
209 | | - ) |
210 | | - if self.config["resume"] and os.path.exists( |
211 | | - os.path.join(dst_model_path, "model.py") |
212 | | - ): |
213 | | - return |
214 | | - |
215 | | - src_model_path = os.path.join(self.config["model_path_prefix"], rel_model_path) |
216 | | - Path(dst_model_path).parent.mkdir(parents=True, exist_ok=True) |
217 | | - graph_module_class = load_class_from_file( |
218 | | - os.path.join(src_model_path, "model.py"), class_name="GraphModule" |
219 | | - ) |
220 | | - input_arg_names, weight_arg_names = self._get_input_and_weight_arg_names( |
221 | | - graph_module_class, src_model_path |
222 | | - ) |
223 | | - |
224 | | - rename_map = {} |
225 | | - for idx, name in enumerate(input_arg_names): |
226 | | - rename_map[name] = f"in_{idx}" |
227 | | - for idx, name in enumerate(weight_arg_names): |
228 | | - rename_map[name] = f"w_{idx}" |
229 | | - |
230 | | - with tempfile.TemporaryDirectory(prefix="graph_variable_renamer_") as temp_dir: |
231 | | - temp_model_path = os.path.join(temp_dir, os.path.basename(dst_model_path)) |
232 | | - shutil.copytree(src_model_path, temp_model_path, dirs_exist_ok=True) |
233 | | - self._update_model_py_file( |
234 | | - temp_model_path, rename_map, input_arg_names, weight_arg_names |
235 | | - ) |
236 | | - self._update_meta_file(temp_model_path, "weight_meta.py", rename_map) |
237 | | - self._update_meta_file(temp_model_path, "input_meta.py", rename_map) |
238 | | - self._try_run(temp_model_path) |
239 | | - shutil.copytree(temp_model_path, dst_model_path, dirs_exist_ok=True) |
240 | | - |
241 | | - def _get_input_and_weight_arg_names(self, graph_module, model_path): |
242 | | - input_arg_names = [] |
243 | | - weight_arg_names = [] |
244 | | - sig = inspect.signature(graph_module.forward) |
245 | | - for name, param in sig.parameters.items(): |
246 | | - if name == "self": |
247 | | - continue |
248 | | - is_not_data_input = not self.data_input_predicator(model_path, name) |
249 | | - if is_not_data_input: |
250 | | - weight_arg_names.append(name) |
251 | | - else: |
252 | | - input_arg_names.append(name) |
253 | | - return input_arg_names, weight_arg_names |
254 | | - |
255 | | - def _update_model_py_file( |
256 | | - self, model_path, rename_map, input_arg_names, weight_arg_names |
257 | | - ): |
258 | | - model_file = Path(model_path) / "model.py" |
259 | | - source = model_file.read_text(encoding="utf-8") |
260 | | - tree = ast.parse(source) |
261 | | - node = self._get_graph_module_ast(tree) |
262 | | - graph_renamer = AstGraphRenamer(rename_map, input_arg_names, weight_arg_names) |
263 | | - graph_renamer.visit(node) |
264 | | - py_code = ast.unparse(tree) |
265 | | - model_file.write_text(py_code, encoding="utf-8") |
266 | | - file_hash = get_sha256_hash(py_code) |
267 | | - (Path(model_path) / "graph_hash.txt").write_text(file_hash) |
268 | | - |
269 | | - def _get_graph_module_ast(self, tree): |
270 | | - for node in tree.body: |
271 | | - if isinstance(node, ast.ClassDef) and node.name == "GraphModule": |
272 | | - return node |
273 | | - return None |
274 | | - |
275 | | - def _update_meta_file(self, model_path, meta_filename, rename_map): |
276 | | - meta_file = Path(model_path) / meta_filename |
277 | | - tensor_metas = TensorMeta.unserialize_from_py_file(str(meta_file)) |
278 | | - for meta in tensor_metas: |
279 | | - assert ( |
280 | | - meta.name in rename_map |
281 | | - ), f"[Warning] {meta.name} in {meta_filename} not found in rename_map." |
282 | | - if meta.original_name is None: |
283 | | - meta.original_name = meta.name |
284 | | - meta.name = rename_map[meta.name] |
285 | | - |
286 | | - py_code = "\n\n".join([meta.serialize_to_py_str() for meta in tensor_metas]) |
287 | | - meta_file.write_text(py_code) |
288 | | - |
289 | | - def _try_run(self, model_path): |
290 | | - (f"[AstGraphVariableRenamer] Try to run {model_path}") |
291 | | - assert self.model_runnable_predicator( |
292 | | - model_path |
293 | | - ), f"{model_path} is not a runnable model" |
0 commit comments