Skip to content

Commit fb7ecd6

Browse files
committed
Add cuda_gc in decomposer and var renamer
1 parent 9a55a74 commit fb7ecd6

File tree

3 files changed

+44
-20
lines changed

3 files changed

+44
-20
lines changed

graph_net/torch/decompose_util.py

Lines changed: 12 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1,10 +1,22 @@
11
import torch
22
import copy
33
import operator
4+
import gc
5+
from contextlib import contextmanager
46
from collections import defaultdict
57
from dataclasses import dataclass
68

79

10+
@contextmanager
def cuda_gc(enabled: bool = True):
    """Context manager that cleans up memory when the ``with`` body exits.

    On exit, whether the body finished normally or raised, and only when
    ``enabled`` is true, this runs ``gc.collect()`` followed by
    ``torch.cuda.empty_cache()``. Any exception raised inside the body is
    propagated unchanged.

    Args:
        enabled: When false, the context manager performs no cleanup at all.
    """
    if not enabled:
        # Nothing to clean up afterwards: just hand control to the body.
        yield
        return

    try:
        yield
    finally:
        # Collect Python garbage first, then ask torch to empty its CUDA cache.
        gc.collect()
        torch.cuda.empty_cache()
18+
19+
820
def convert_to_submodules_graph(
921
gm: torch.fx.GraphModule,
1022
split_positions: list[int],

graph_net/torch/graph_decomposer.py

Lines changed: 25 additions & 17 deletions
Original file line number | Diff line number | Diff line change
@@ -4,23 +4,24 @@
44
import torch
55
import json
66
import sys
7-
from graph_net.torch.decompose_util import convert_to_submodules_graph
7+
8+
from graph_net.torch.decompose_util import convert_to_submodules_graph, cuda_gc
89
from graph_net.torch.extractor import GraphExtractor as BuiltinGraphExtractor
910
import graph_net.imp_util as imp_util
1011
from graph_net.torch.fx_graph_module_util import get_torch_module_and_inputs
1112
from graph_net.torch.fx_graph_cache_util import (
1213
parse_immutable_model_path_into_sole_graph_module,
1314
)
1415
from graph_net.torch.fx_graph_parse_util import parse_sole_graph_module
16+
1517
import logging
1618

1719
logger = logging.getLogger(__name__)
1820

1921

2022
def load_json(file_path):
    """Read the UTF-8 text file at ``file_path`` and return its parsed JSON.

    Args:
        file_path: Path of the JSON file to open for reading.

    Returns:
        The Python object produced by decoding the file's JSON content.
    """
    with open(file_path, mode="r", encoding="utf-8") as stream:
        text = stream.read()
    return json.loads(text)
2425

2526

2627
class GraphExtractor:
@@ -221,20 +222,27 @@ def __call__(self, rel_model_path):
221222
rel_model_path, split_positions
222223
):
223224
return
224-
torch.cuda.empty_cache()
225-
config = {
226-
"split_positions": split_positions,
227-
"group_head_and_tail": self.config.get("group_head_and_tail", False),
228-
"chain_style": self.config.get("chain_style", False),
229-
}
230-
module, inputs = get_torch_module_and_inputs(model_path, use_dummy_inputs=False)
225+
226+
with cuda_gc():
227+
module, inputs = get_torch_module_and_inputs(
228+
model_path, use_dummy_inputs=False
229+
)
231230
gm = parse_sole_graph_module(module, inputs)
232-
rewrited_gm: torch.fx.GraphModule = convert_to_submodules_graph(
233-
gm,
234-
submodule_hook=self.get_naive_decomposer_extractor(rel_model_path),
235-
**config,
236-
)
237-
rewrited_gm(*inputs)
231+
del module
232+
233+
with cuda_gc():
234+
rewrited_gm: torch.fx.GraphModule = convert_to_submodules_graph(
235+
gm,
236+
submodule_hook=self.get_naive_decomposer_extractor(rel_model_path),
237+
split_positions=split_positions,
238+
group_head_and_tail=self.config.get("group_head_and_tail", False),
239+
chain_style=self.config.get("chain_style", False),
240+
)
241+
rewrited_gm(*inputs)
242+
del inputs, rewrited_gm
243+
244+
with cuda_gc():
245+
pass
238246

239247
def get_naive_decomposer_extractor(self, rel_model_path):
240248
def fn(submodule, seq_no):

graph_net/torch/graph_variable_renamer.py

Lines changed: 7 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -3,13 +3,15 @@
33
import shutil
44
import inspect
55
import tempfile
6+
67
from graph_net.torch.fx_graph_module_util import get_torch_module_and_inputs
78
from graph_net.torch.fx_graph_parse_util import parse_sole_graph_module
89
from graph_net.tensor_meta import TensorMeta
910
from pathlib import Path
1011
from graph_net.torch.utils import apply_templates
1112
from graph_net.imp_util import load_module
1213
from graph_net.hash_util import get_sha256_hash
14+
from graph_net.torch.decompose_util import cuda_gc
1315

1416

1517
class GraphVariableRenamer:
@@ -79,9 +81,11 @@ def _make_config(
7981

8082
def __call__(self, rel_model_path):
8183
src_model_path = os.path.join(self.config["model_path_prefix"], rel_model_path)
82-
module, inputs = get_torch_module_and_inputs(src_model_path)
83-
gm = parse_sole_graph_module(module, inputs)
84-
gm = self.rename_graph_variables(gm, inputs, src_model_path)
84+
with cuda_gc(enabled=self.config["release_gpu_memory"]):
85+
module, inputs = get_torch_module_and_inputs(src_model_path)
86+
gm = parse_sole_graph_module(module, inputs)
87+
gm = self.rename_graph_variables(gm, inputs, src_model_path)
88+
del module, inputs
8589
dst_model_path = os.path.realpath(
8690
os.path.join(self.config["output_dir"], rel_model_path)
8791
)

0 commit comments

Comments
 (0)