Skip to content

Commit ca985cb

Browse files
committed
Skip the error model for graph variable rename
1 parent 51e558f commit ca985cb

File tree

1 file changed

+19
-13
lines changed

1 file changed

+19
-13
lines changed

graph_net/torch/graph_variable_renamer.py

Lines changed: 19 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -75,19 +75,25 @@ def _make_config(
7575
}
7676

7777
def __call__(self, rel_model_path):
78-
src_model_path = os.path.join(self.config["model_path_prefix"], rel_model_path)
79-
module, inputs = get_torch_module_and_inputs(src_model_path)
80-
gm = parse_sole_graph_module(module, inputs)
81-
gm = self.rename_graph_variables(gm, inputs, src_model_path)
82-
dst_model_path = os.path.realpath(
83-
os.path.join(self.config["output_dir"], rel_model_path)
84-
)
85-
Path(dst_model_path).parent.mkdir(parents=True, exist_ok=True)
86-
shutil.copytree(src_model_path, dst_model_path, dirs_exist_ok=True)
87-
self._update_model_py_file(gm, dst_model_path)
88-
self._update_weight_meta_py_file(src_model_path, dst_model_path)
89-
self._update_input_meta_py_file(src_model_path, dst_model_path)
90-
self._try_run(dst_model_path)
78+
try:
79+
src_model_path = os.path.join(
80+
self.config["model_path_prefix"], rel_model_path
81+
)
82+
module, inputs = get_torch_module_and_inputs(src_model_path)
83+
gm = parse_sole_graph_module(module, inputs)
84+
gm = self.rename_graph_variables(gm, inputs, src_model_path)
85+
dst_model_path = os.path.realpath(
86+
os.path.join(self.config["output_dir"], rel_model_path)
87+
)
88+
Path(dst_model_path).parent.mkdir(parents=True, exist_ok=True)
89+
shutil.copytree(src_model_path, dst_model_path, dirs_exist_ok=True)
90+
self._update_model_py_file(gm, dst_model_path)
91+
self._update_weight_meta_py_file(src_model_path, dst_model_path)
92+
self._update_input_meta_py_file(src_model_path, dst_model_path)
93+
self._try_run(dst_model_path)
94+
except Exception:
95+
print("Failed to rename variables of ", src_model_path)
96+
print("Skipping this model and continuing...\n")
9197

9298
def _try_run(self, model_path):
9399
assert self.model_runnable_predicator(

0 commit comments

Comments
 (0)