|
4 | 4 | import torch |
5 | 5 | import json |
6 | 6 | import sys |
7 | | -from graph_net.torch.decompose_util import convert_to_submodules_graph |
| 7 | + |
| 8 | +from graph_net.torch.decompose_util import convert_to_submodules_graph, cuda_gc |
8 | 9 | from graph_net.torch.extractor import GraphExtractor as BuiltinGraphExtractor |
9 | 10 | import graph_net.imp_util as imp_util |
10 | 11 | from graph_net.torch.fx_graph_module_util import get_torch_module_and_inputs |
11 | 12 | from graph_net.torch.fx_graph_cache_util import ( |
12 | 13 | parse_immutable_model_path_into_sole_graph_module, |
13 | 14 | ) |
14 | 15 | from graph_net.torch.fx_graph_parse_util import parse_sole_graph_module |
| 16 | + |
15 | 17 | import logging |
16 | 18 |
|
17 | 19 | logger = logging.getLogger(__name__) |
18 | 20 |
|
19 | 21 |
|
20 | 22 | def load_json(file_path): |
21 | | - with open(file_path, "r", encoding="utf-8") as file: |
22 | | - data_dict = json.load(file) |
23 | | - return data_dict |
| 23 | + with open(file_path, "r", encoding="utf-8") as f: |
| 24 | + return json.load(f) |
24 | 25 |
|
25 | 26 |
|
26 | 27 | class GraphExtractor: |
@@ -221,20 +222,27 @@ def __call__(self, rel_model_path): |
221 | 222 | rel_model_path, split_positions |
222 | 223 | ): |
223 | 224 | return |
224 | | - torch.cuda.empty_cache() |
225 | | - config = { |
226 | | - "split_positions": split_positions, |
227 | | - "group_head_and_tail": self.config.get("group_head_and_tail", False), |
228 | | - "chain_style": self.config.get("chain_style", False), |
229 | | - } |
230 | | - module, inputs = get_torch_module_and_inputs(model_path, use_dummy_inputs=False) |
| 225 | + |
| 226 | + with cuda_gc(): |
| 227 | + module, inputs = get_torch_module_and_inputs( |
| 228 | + model_path, use_dummy_inputs=False |
| 229 | + ) |
231 | 230 | gm = parse_sole_graph_module(module, inputs) |
232 | | - rewrited_gm: torch.fx.GraphModule = convert_to_submodules_graph( |
233 | | - gm, |
234 | | - submodule_hook=self.get_naive_decomposer_extractor(rel_model_path), |
235 | | - **config, |
236 | | - ) |
237 | | - rewrited_gm(*inputs) |
| 231 | + del module |
| 232 | + |
| 233 | + with cuda_gc(): |
| 234 | + rewrited_gm: torch.fx.GraphModule = convert_to_submodules_graph( |
| 235 | + gm, |
| 236 | + submodule_hook=self.get_naive_decomposer_extractor(rel_model_path), |
| 237 | + split_positions=split_positions, |
| 238 | + group_head_and_tail=self.config.get("group_head_and_tail", False), |
| 239 | + chain_style=self.config.get("chain_style", False), |
| 240 | + ) |
| 241 | + rewrited_gm(*inputs) |
| 242 | + del inputs, rewrited_gm |
| 243 | + |
| 244 | + with cuda_gc(): |
| 245 | + pass |
238 | 246 |
|
239 | 247 | def get_naive_decomposer_extractor(self, rel_model_path): |
240 | 248 | def fn(submodule, seq_no): |
|
0 commit comments