Skip to content

Commit 3c347e3

Browse files
committed
Use with + NamedTempFile
1 parent 40bd56a commit 3c347e3

File tree

1 file changed

+2
-9
lines changed

1 file changed

+2
-9
lines changed

graph_net/torch/graph_variable_renamer.py

Lines changed: 2 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -90,22 +90,15 @@ def __call__(self, rel_model_path):
9090
):
9191
return
9292
Path(dst_model_path).parent.mkdir(parents=True, exist_ok=True)
93-
temp_dir = tempfile.mkdtemp(prefix="graph_variable_renamer_")
94-
temp_model_path = os.path.join(temp_dir, os.path.basename(dst_model_path))
95-
try:
93+
with tempfile.TemporaryDirectory(prefix="graph_variable_renamer_") as temp_dir:
94+
temp_model_path = os.path.join(temp_dir, os.path.basename(dst_model_path))
9695
shutil.copytree(src_model_path, temp_model_path, dirs_exist_ok=True)
9796
self._update_model_py_file(gm, temp_model_path)
9897
self._update_weight_meta_py_file(src_model_path, temp_model_path)
9998
self._update_input_meta_py_file(src_model_path, temp_model_path)
10099
print("Try to run renamed model...")
101100
self._try_run(temp_model_path)
102-
if os.path.exists(dst_model_path):
103-
shutil.rmtree(dst_model_path)
104101
shutil.copytree(temp_model_path, dst_model_path)
105-
except Exception as e:
106-
raise RuntimeError(f"Failed to handle {src_model_path}: {e}")
107-
finally:
108-
shutil.rmtree(temp_dir, ignore_errors=True)
109102

110103
def _try_run(self, model_path):
111104
assert self.model_runnable_predicator(

0 commit comments

Comments
 (0)