|
5 | 5 | from typing import Any, Dict, List |
6 | 6 | import torch |
7 | 7 | import torch.nn as nn |
| 8 | + |
8 | 9 | from graph_net.torch.rp_expr.rp_expr_parser import RpExprParser |
9 | 10 | from graph_net.torch.rp_expr.rp_expr_util import ( |
10 | 11 | MakeNestedIndexRangeFromLetsListTokenRpExpr, |
11 | 12 | ) |
12 | 13 | from graph_net.torch.fx_graph_module_util import get_torch_module_and_inputs |
13 | 14 | from graph_net.torch.fx_graph_parse_util import parse_sole_graph_module_without_varify |
| 15 | +from graph_net.torch.decompose_util import cuda_gc |
14 | 16 |
|
15 | 17 |
|
16 | 18 | class TypicalSequenceExtractor: |
@@ -85,10 +87,12 @@ def _get_output_path(self, rel_model_path: str): |
85 | 87 |
|
86 | 88 | def _extract_ops(self, model_path: str) -> List[str]: |
87 | 89 | extractor = TypicalSequenceExtractor() |
88 | | - model, inputs = get_torch_module_and_inputs(model_path) |
89 | | - compiled_model, _ = parse_sole_graph_module_without_varify(model, inputs) |
90 | | - extractor.extract_compiler(compiled_model, inputs) |
91 | | - ops_info = extractor.extract_node |
| 90 | + with cuda_gc(): |
| 91 | + model, inputs = get_torch_module_and_inputs(model_path) |
| 92 | + compiled_model, _ = parse_sole_graph_module_without_varify(model, inputs) |
| 93 | + extractor.extract_compiler(compiled_model, inputs) |
| 94 | + ops_info = extractor.extract_node |
| 95 | + del model, inputs, compiled_model |
92 | 96 |
|
93 | 97 | return [op["target_name"] for op in ops_info] |
94 | 98 |
|
|
0 commit comments