Skip to content

Commit 977fcb5

Browse files
committed
support resume in GraphVariableRenamer
1 parent 8bfaa56 commit 977fcb5

File tree

3 files changed

+24
-70
lines changed

3 files changed

+24
-70
lines changed

graph_net/test/level5_subgraph_dataset_test.sh

Lines changed: 0 additions & 63 deletions
This file was deleted.

graph_net/tools/typical_sequence_decompose.sh

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,7 @@ python3 -m graph_net.model_path_handler \
6868
"handler_path": "$GRAPH_NET_ROOT/graph_net/torch/graph_variable_renamer.py",
6969
"handler_class_name": "GraphVariableRenamer",
7070
"handler_config": {
71+
"resume": true,
7172
"model_path_prefix": "$DECOMPOSE_WORKSPACE",
7273
"data_input_predicator_filepath": "$GRAPH_NET_ROOT/graph_net/torch/constraint_util.py",
7374
"data_input_predicator_class_name": "NaiveDataInputPredicator",

graph_net/torch/graph_variable_renamer.py

Lines changed: 23 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import torch
33
import shutil
44
import inspect
5+
import tempfile
56
from graph_net.torch.fx_graph_module_util import get_torch_module_and_inputs
67
from graph_net.torch.fx_graph_parse_util import parse_sole_graph_module
78
from graph_net.tensor_meta import TensorMeta
@@ -37,8 +38,9 @@ def _make_model_runnable_predicator(self, config):
3738

3839
def _make_config(
3940
self,
40-
data_input_predicator_filepath,
41-
model_runnable_predicator_filepath,
41+
resume: bool = False,
42+
data_input_predicator_filepath=None,
43+
model_runnable_predicator_filepath=None,
4244
output_dir="./tmp/graph_variable_renamer_dir",
4345
filter_path=None,
4446
filter_config=None,
@@ -59,6 +61,7 @@ def _make_config(
5961
if model_runnable_predicator_config is None:
6062
model_runnable_predicator_config = {}
6163
return {
64+
"resume": resume,
6265
"output_dir": output_dir,
6366
"filter_path": filter_path,
6467
"filter_config": filter_config if filter_config is not None else {},
@@ -82,12 +85,25 @@ def __call__(self, rel_model_path):
8285
dst_model_path = os.path.realpath(
8386
os.path.join(self.config["output_dir"], rel_model_path)
8487
)
88+
if self.config["resume"] and os.path.exists(dst_model_path):
89+
return
8590
Path(dst_model_path).parent.mkdir(parents=True, exist_ok=True)
86-
shutil.copytree(src_model_path, dst_model_path, dirs_exist_ok=True)
87-
self._update_model_py_file(gm, dst_model_path)
88-
self._update_weight_meta_py_file(src_model_path, dst_model_path)
89-
self._update_input_meta_py_file(src_model_path, dst_model_path)
90-
self._try_run(dst_model_path)
91+
temp_dir = tempfile.mkdtemp(prefix="graph_variable_renamer_")
92+
temp_model_path = os.path.join(temp_dir, os.path.basename(dst_model_path))
93+
try:
94+
shutil.copytree(src_model_path, temp_model_path, dirs_exist_ok=True)
95+
self._update_model_py_file(gm, temp_model_path)
96+
self._update_weight_meta_py_file(src_model_path, temp_model_path)
97+
self._update_input_meta_py_file(src_model_path, temp_model_path)
98+
print("Try to run renamed model...")
99+
self._try_run(temp_model_path)
100+
if os.path.exists(dst_model_path):
101+
shutil.rmtree(dst_model_path)
102+
shutil.copytree(temp_model_path, dst_model_path)
103+
except Exception as e:
104+
raise RuntimeError(f"Failed to handle {src_model_path}: {e}")
105+
finally:
106+
shutil.rmtree(temp_dir, ignore_errors=True)
91107

92108
def _try_run(self, model_path):
93109
assert self.model_runnable_predicator(

0 commit comments

Comments
 (0)