@@ -90,22 +90,15 @@ def __call__(self, rel_model_path):
9090 ):
9191 return
9292 Path (dst_model_path ).parent .mkdir (parents = True , exist_ok = True )
93- temp_dir = tempfile .mkdtemp (prefix = "graph_variable_renamer_" )
94- temp_model_path = os .path .join (temp_dir , os .path .basename (dst_model_path ))
95- try :
93+ with tempfile .TemporaryDirectory (prefix = "graph_variable_renamer_" ) as temp_dir :
94+ temp_model_path = os .path .join (temp_dir , os .path .basename (dst_model_path ))
9695 shutil .copytree (src_model_path , temp_model_path , dirs_exist_ok = True )
9796 self ._update_model_py_file (gm , temp_model_path )
9897 self ._update_weight_meta_py_file (src_model_path , temp_model_path )
9998 self ._update_input_meta_py_file (src_model_path , temp_model_path )
10099 print ("Try to run renamed model..." )
101100 self ._try_run (temp_model_path )
102- if os .path .exists (dst_model_path ):
103- shutil .rmtree (dst_model_path )
104101 shutil .copytree (temp_model_path , dst_model_path )
105- except Exception as e :
106- raise RuntimeError (f"Failed to handle { src_model_path } : { e } " )
107- finally :
108- shutil .rmtree (temp_dir , ignore_errors = True )
109102
110103 def _try_run (self , model_path ):
111104 assert self .model_runnable_predicator (
0 commit comments