
Commit a992aa5

Add torch.cuda.empty_cache() in decompose process (#455)
* Add cuda_gc in decomposer and var renamer
* Add cuda_gc in typical_sequence_split_points
* Fix resume in graph_variable_renamer to early exit
* Minor fix
* Simplify
1 parent 908bfb8 commit a992aa5
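
The individual changes all apply the same two-part pattern: do the resume check and return before any heavy work starts, and call torch.cuda.empty_cache() at the beginning of each per-model step so CUDA memory cached while processing the previous model is released before the next one is loaded. Below is a minimal, self-contained sketch of that pattern; the Step class, its config keys, and the example paths are illustrative stand-ins, and only torch.cuda.empty_cache() and the check-before-load ordering are taken from the diffs that follow.

import os
import torch


class Step:
    """Hypothetical per-model step; the class name and config keys are illustrative."""

    def __init__(self, config):
        self.config = config

    def __call__(self, rel_model_path):
        # Release CUDA memory cached by the previous model before doing anything else
        # (a no-op on CPU-only machines or when CUDA is not yet initialized).
        torch.cuda.empty_cache()

        # Resume check first, so an already-processed model exits early and never
        # loads weights or allocates GPU memory again.
        dst_model_path = os.path.join(self.config["output_dir"], rel_model_path)
        if self.config["resume"] and os.path.exists(os.path.join(dst_model_path, "model.py")):
            return

        # ... only now load the module, parse the graph, and do the real work ...


step = Step({"resume": False, "output_dir": "out"})  # illustrative config
step("models/toy_model")                             # runs the guarded body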


3 files changed: +24 -16 lines changed


graph_net/torch/graph_decomposer.py

Lines changed: 11 additions & 10 deletions
@@ -4,6 +4,7 @@
 import torch
 import json
 import sys
+
 from graph_net.torch.decompose_util import convert_to_submodules_graph
 from graph_net.torch.extractor import GraphExtractor as BuiltinGraphExtractor
 import graph_net.imp_util as imp_util
@@ -12,15 +13,15 @@
     parse_immutable_model_path_into_sole_graph_module,
 )
 from graph_net.torch.fx_graph_parse_util import parse_sole_graph_module
+
 import logging

 logger = logging.getLogger(__name__)


 def load_json(file_path):
-    with open(file_path, "r", encoding="utf-8") as file:
-        data_dict = json.load(file)
-    return data_dict
+    with open(file_path, "r", encoding="utf-8") as f:
+        return json.load(f)


 class GraphExtractor:
@@ -242,19 +243,19 @@ def __call__(self, rel_model_path):
             rel_model_path, split_positions, subgraph_ranges
         ):
             return
+
         torch.cuda.empty_cache()
-        config = {
-            "split_positions": split_positions,
-            "subgraph_ranges": subgraph_ranges,
-            "group_head_and_tail": self.config.get("group_head_and_tail", False),
-            "chain_style": self.config.get("chain_style", False),
-        }
         module, inputs = get_torch_module_and_inputs(model_path, use_dummy_inputs=False)
         gm = parse_sole_graph_module(module, inputs)
+
+        torch.cuda.empty_cache()
         rewrited_gm: torch.fx.GraphModule = convert_to_submodules_graph(
             gm,
             submodule_hook=self.get_naive_decomposer_extractor(rel_model_path),
-            **config,
+            split_positions=split_positions,
+            subgraph_ranges=subgraph_ranges,
+            group_head_and_tail=self.config.get("group_head_and_tail", False),
+            chain_style=self.config.get("chain_style", False),
         )
         rewrited_gm(*inputs)
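The "Simplify" commit in this file replaces the intermediate config dict that was splatted with **config by explicit keyword arguments to convert_to_submodules_graph. A self-contained toy comparison of the two call styles follows; decompose is a stand-in for the real function, and the option values are made up.

def decompose(graph, *, split_positions, subgraph_ranges,
              group_head_and_tail=False, chain_style=False):
    # Stand-in for convert_to_submodules_graph: just report what it was given.
    return (graph, split_positions, subgraph_ranges, group_head_and_tail, chain_style)


graph = "toy-graph"

# Before: options gathered into a dict and splatted into the call.
config = {
    "split_positions": [2, 5],
    "subgraph_ranges": [(0, 2), (2, 5)],
    "group_head_and_tail": False,
    "chain_style": False,
}
out_before = decompose(graph, **config)

# After: the same options passed as explicit keyword arguments, keeping the
# accepted option names visible at the call site.
out_after = decompose(
    graph,
    split_positions=[2, 5],
    subgraph_ranges=[(0, 2), (2, 5)],
    group_head_and_tail=False,
    chain_style=False,
)

assert out_before == out_after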
graph_net/torch/graph_variable_renamer.py

Lines changed: 11 additions & 6 deletions
@@ -2,6 +2,7 @@
 import torch
 import shutil
 import tempfile
+
 from graph_net.torch.fx_graph_module_util import get_torch_module_and_inputs
 from graph_net.torch.fx_graph_parse_util import parse_sole_graph_module
 from graph_net.tensor_meta import TensorMeta
@@ -77,17 +78,21 @@ def _make_config(
         }

     def __call__(self, rel_model_path):
-        src_model_path = os.path.join(self.config["model_path_prefix"], rel_model_path)
-        module, inputs = get_torch_module_and_inputs(src_model_path)
-        gm = parse_sole_graph_module(module, inputs)
-        gm, rename_map = self.rename_graph_variables(gm, inputs, src_model_path)
+        torch.cuda.empty_cache()
+
         dst_model_path = os.path.realpath(
             os.path.join(self.config["output_dir"], rel_model_path)
         )
         if self.config["resume"] and os.path.exists(
             os.path.join(dst_model_path, "model.py")
         ):
             return
+
+        src_model_path = os.path.join(self.config["model_path_prefix"], rel_model_path)
+        module, inputs = get_torch_module_and_inputs(src_model_path)
+        gm = parse_sole_graph_module(module, inputs)
+        gm, rename_map = self.rename_graph_variables(gm, inputs, src_model_path)
+
         Path(dst_model_path).parent.mkdir(parents=True, exist_ok=True)
         with tempfile.TemporaryDirectory(prefix="graph_variable_renamer_") as temp_dir:
             temp_model_path = os.path.join(temp_dir, os.path.basename(dst_model_path))
@@ -97,8 +102,8 @@ def __call__(self, rel_model_path):
                 src_model_path, temp_model_path, rename_map
             )
             self._update_input_meta_py_file(src_model_path, temp_model_path, rename_map)
-            print("Try to run renamed model...")
-            self._try_run(temp_model_path)
+            # print("Try to run renamed model...")
+            # self._try_run(temp_model_path)
             shutil.copytree(temp_model_path, dst_model_path)

     def _try_run(self, model_path):

graph_net/torch/typical_sequence_split_points.py

Lines changed: 2 additions & 0 deletions
@@ -5,6 +5,7 @@
 from typing import Any, Dict, List
 import torch
 import torch.nn as nn
+
 from graph_net.torch.rp_expr.rp_expr_parser import RpExprParser
 from graph_net.torch.rp_expr.rp_expr_util import (
     MakeNestedIndexRangeFromLetsListTokenRpExpr,
@@ -69,6 +70,7 @@ def _make_config(
         }

     def __call__(self, rel_model_path: str):
+        torch.cuda.empty_cache()
         model_path = os.path.join(self.config["model_path_prefix"], rel_model_path)
         output_path = self._get_output_path(rel_model_path)
         if self.config["resume"] and output_path.exists():
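
A quick way to see what these torch.cuda.empty_cache() calls buy between models: once tensors from the previous step are dead, empty_cache() returns the allocator's cached blocks to the driver, so reserved GPU memory drops. The probe below is illustrative only, not part of the commit, and it needs a CUDA device.

import torch

if torch.cuda.is_available():
    x = torch.randn(1024, 1024, device="cuda")  # allocate something on the GPU
    del x                                       # the tensor is gone, but the caching allocator
    before = torch.cuda.memory_reserved()       # still holds the pages it grabbed
    torch.cuda.empty_cache()                    # hand cached, unused blocks back to the driver
    after = torch.cuda.memory_reserved()
    print(f"reserved before empty_cache: {before} bytes, after: {after} bytes")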
