22import torch
33import shutil
44import inspect
5+ import tempfile
56from graph_net .torch .fx_graph_module_util import get_torch_module_and_inputs
67from graph_net .torch .fx_graph_parse_util import parse_sole_graph_module
78from graph_net .tensor_meta import TensorMeta
@@ -37,8 +38,9 @@ def _make_model_runnable_predicator(self, config):
3738
3839 def _make_config (
3940 self ,
40- data_input_predicator_filepath ,
41- model_runnable_predicator_filepath ,
41+ resume : bool = False ,
42+ data_input_predicator_filepath = None ,
43+ model_runnable_predicator_filepath = None ,
4244 output_dir = "./tmp/graph_variable_renamer_dir" ,
4345 filter_path = None ,
4446 filter_config = None ,
@@ -59,6 +61,7 @@ def _make_config(
5961 if model_runnable_predicator_config is None :
6062 model_runnable_predicator_config = {}
6163 return {
64+ "resume" : resume ,
6265 "output_dir" : output_dir ,
6366 "filter_path" : filter_path ,
6467 "filter_config" : filter_config if filter_config is not None else {},
@@ -82,12 +85,25 @@ def __call__(self, rel_model_path):
8285 dst_model_path = os .path .realpath (
8386 os .path .join (self .config ["output_dir" ], rel_model_path )
8487 )
88+ if self .config ["resume" ] and os .path .exists (dst_model_path ):
89+ return
8590 Path (dst_model_path ).parent .mkdir (parents = True , exist_ok = True )
86- shutil .copytree (src_model_path , dst_model_path , dirs_exist_ok = True )
87- self ._update_model_py_file (gm , dst_model_path )
88- self ._update_weight_meta_py_file (src_model_path , dst_model_path )
89- self ._update_input_meta_py_file (src_model_path , dst_model_path )
90- self ._try_run (dst_model_path )
91+ temp_dir = tempfile .mkdtemp (prefix = "graph_variable_renamer_" )
92+ temp_model_path = os .path .join (temp_dir , os .path .basename (dst_model_path ))
93+ try :
94+ shutil .copytree (src_model_path , temp_model_path , dirs_exist_ok = True )
95+ self ._update_model_py_file (gm , temp_model_path )
96+ self ._update_weight_meta_py_file (src_model_path , temp_model_path )
97+ self ._update_input_meta_py_file (src_model_path , temp_model_path )
98+ print ("Try to run renamed model..." )
99+ self ._try_run (temp_model_path )
100+ if os .path .exists (dst_model_path ):
101+ shutil .rmtree (dst_model_path )
102+ shutil .copytree (temp_model_path , dst_model_path )
103+ except Exception as e :
104+ raise RuntimeError (f"Failed to handle { src_model_path } : { e } " )
105+ finally :
106+ shutil .rmtree (temp_dir , ignore_errors = True )
91107
92108 def _try_run (self , model_path ):
93109 assert self .model_runnable_predicator (
0 commit comments