Skip to content

Commit dc0a176

Browse files
committed
Fix resume in graph_variable_renamer to early exit
1 parent 47ab6ab commit dc0a176

File tree

1 file changed

+8
-6
lines changed

1 file changed

+8
-6
lines changed

graph_net/torch/graph_variable_renamer.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -80,19 +80,21 @@ def _make_config(
8080
}
8181

8282
def __call__(self, rel_model_path):
83-
src_model_path = os.path.join(self.config["model_path_prefix"], rel_model_path)
84-
with cuda_gc(enabled=self.config["release_gpu_memory"]):
85-
module, inputs = get_torch_module_and_inputs(src_model_path)
86-
gm = parse_sole_graph_module(module, inputs)
87-
gm = self.rename_graph_variables(gm, inputs, src_model_path)
88-
del module, inputs
8983
dst_model_path = os.path.realpath(
9084
os.path.join(self.config["output_dir"], rel_model_path)
9185
)
9286
if self.config["resume"] and os.path.exists(
9387
os.path.join(dst_model_path, "model.py")
9488
):
9589
return
90+
91+
src_model_path = os.path.join(self.config["model_path_prefix"], rel_model_path)
92+
with cuda_gc(enabled=self.config["release_gpu_memory"]):
93+
module, inputs = get_torch_module_and_inputs(src_model_path)
94+
gm = parse_sole_graph_module(module, inputs)
95+
gm = self.rename_graph_variables(gm, inputs, src_model_path)
96+
del module, inputs
97+
9698
Path(dst_model_path).parent.mkdir(parents=True, exist_ok=True)
9799
with tempfile.TemporaryDirectory(prefix="graph_variable_renamer_") as temp_dir:
98100
temp_model_path = os.path.join(temp_dir, os.path.basename(dst_model_path))

0 commit comments

Comments
 (0)