Skip to content

Commit df56c31

Browse files
authored
Add try_run for graph variable renamer (#491)
1 parent e6b5544 commit df56c31

File tree

4 files changed

+10
-3
lines changed

4 files changed

+10
-3
lines changed

graph_net/sample_pass/ast_graph_variable_renamer.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ def declare_config(
3636
output_dir: str,
3737
device: str,
3838
resume: bool = False,
39+
try_run: bool = False,
3940
limits_handled_models: int = None,
4041
data_input_predicator_filepath: str = None,
4142
data_input_predicator_class_name: str = None,
@@ -73,7 +74,8 @@ def resume(self, rel_model_path: str):
7374
)
7475
self._update_meta_file(temp_model_path, "weight_meta.py", rename_map)
7576
self._update_meta_file(temp_model_path, "input_meta.py", rename_map)
76-
self._try_run(temp_model_path)
77+
if self.config["try_run"]:
78+
self._try_run(temp_model_path)
7779
shutil.copytree(temp_model_path, dst_model_path, dirs_exist_ok=True)
7880

7981
def _get_input_and_weight_arg_names(self, graph_module, model_path):
@@ -133,7 +135,7 @@ def _update_meta_file(self, model_path, meta_filename, rename_map):
133135
meta_file.write_text(py_code)
134136

135137
def _try_run(self, model_path):
136-
(f"[AstGraphVariableRenamer] Try to run {model_path}")
138+
print(f"[AstGraphVariableRenamer] Try to run {model_path}")
137139
assert self.model_runnable_predicator(
138140
model_path
139141
), f"{model_path} is not a runnable model"

graph_net/test/ast_graph_variable_rename_test.sh

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ python3 -m graph_net.model_path_handler \
1515
"handler_config": {
1616
"device": "cuda",
1717
"resume": true,
18+
"try_run": true,
1819
"model_path_prefix": "$GRAPH_NET_ROOT/",
1920
"data_input_predicator_filepath": "$GRAPH_NET_ROOT/graph_net/torch/constraint_util.py",
2021
"data_input_predicator_class_name": "NaiveDataInputPredicator",

graph_net/test/graph_variable_rename_test.sh

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ python3 -m graph_net.model_path_handler \
1515
"handler_config": {
1616
"device": "cuda",
1717
"resume": true,
18+
"try_run": true,
1819
"model_path_prefix": "$GRAPH_NET_ROOT/",
1920
"data_input_predicator_filepath": "$GRAPH_NET_ROOT/graph_net/torch/constraint_util.py",
2021
"data_input_predicator_class_name": "NaiveDataInputPredicator",

graph_net/torch/graph_variable_renamer.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ def _make_model_runnable_predicator(self, config):
3939
def _make_config(
4040
self,
4141
resume: bool = False,
42+
try_run: bool = False,
4243
data_input_predicator_filepath=None,
4344
model_runnable_predicator_filepath=None,
4445
output_dir="./tmp/graph_variable_renamer_dir",
@@ -58,6 +59,7 @@ def _make_config(
5859
return {
5960
"resume": resume,
6061
"output_dir": output_dir,
62+
"try_run": try_run,
6163
"filter_path": filter_path,
6264
"filter_config": filter_config if filter_config is not None else {},
6365
"data_input_predicator_filepath": data_input_predicator_filepath,
@@ -94,7 +96,8 @@ def __call__(self, rel_model_path):
9496
src_model_path, temp_model_path, rename_map
9597
)
9698
self._update_input_meta_py_file(src_model_path, temp_model_path, rename_map)
97-
self._try_run(temp_model_path)
99+
if self.config["try_run"]:
100+
self._try_run(temp_model_path)
98101
shutil.copytree(temp_model_path, dst_model_path)
99102

100103
def _try_run(self, model_path):

0 commit comments

Comments
 (0)