@@ -80,19 +80,21 @@ def _make_config(
8080 }
8181
8282 def __call__ (self , rel_model_path ):
83- src_model_path = os .path .join (self .config ["model_path_prefix" ], rel_model_path )
84- with cuda_gc (enabled = self .config ["release_gpu_memory" ]):
85- module , inputs = get_torch_module_and_inputs (src_model_path )
86- gm = parse_sole_graph_module (module , inputs )
87- gm = self .rename_graph_variables (gm , inputs , src_model_path )
88- del module , inputs
8983 dst_model_path = os .path .realpath (
9084 os .path .join (self .config ["output_dir" ], rel_model_path )
9185 )
9286 if self .config ["resume" ] and os .path .exists (
9387 os .path .join (dst_model_path , "model.py" )
9488 ):
9589 return
90+
91+ src_model_path = os .path .join (self .config ["model_path_prefix" ], rel_model_path )
92+ with cuda_gc (enabled = self .config ["release_gpu_memory" ]):
93+ module , inputs = get_torch_module_and_inputs (src_model_path )
94+ gm = parse_sole_graph_module (module , inputs )
95+ gm = self .rename_graph_variables (gm , inputs , src_model_path )
96+ del module , inputs
97+
9698 Path (dst_model_path ).parent .mkdir (parents = True , exist_ok = True )
9799 with tempfile .TemporaryDirectory (prefix = "graph_variable_renamer_" ) as temp_dir :
98100 temp_model_path = os .path .join (temp_dir , os .path .basename (dst_model_path ))
0 commit comments