Commit 51cbd2c

feat: Ensure that training resumes from latest checkpoint (#63)
* feat: Ensure that training resumes from latest checkpoint
  Signed-off-by: Pranav Prashant Thombre <pthombre@nvidia.com>
* Fix lint issues
  Signed-off-by: Pranav Prashant Thombre <pthombre@nvidia.com>
* Fix import order
  Signed-off-by: Pranav Prashant Thombre <pthombre@nvidia.com>
---------
Signed-off-by: Pranav Prashant Thombre <pthombre@nvidia.com>
1 parent 51b65fc commit 51cbd2c

3 files changed: 4 additions, 5 deletions

dfm/src/automodel/recipes/train.py

Lines changed: 2 additions & 3 deletions
@@ -306,9 +306,8 @@ def setup(self):
             start_epoch=int(self.start_epoch),
             num_epochs=int(self.num_epochs),
         )
-        # Optional resume only through config-defined restore_from
-        if self.restore_from:
-            self.load_checkpoint(restore_from=self.restore_from)
+
+        self.load_checkpoint(self.restore_from)
 
         if is_main_process():
             os.makedirs(self.checkpoint_config.checkpoint_dir, exist_ok=True)
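The hunk above makes `setup()` call `load_checkpoint` unconditionally, so resuming no longer depends on `restore_from` being set in the config. The body of `load_checkpoint` is not part of this diff, so the following is only a minimal sketch of how a "latest checkpoint" fallback could behave. The helper names `find_latest_checkpoint` and `resolve_restore_path`, and the `epoch_<E>_step_<S>` directory naming, are hypothetical and not taken from the repository.

```python
import os
import re
from typing import Optional

# Hypothetical naming scheme (an assumption): checkpoint_dir/epoch_<E>_step_<S>/
_CKPT_RE = re.compile(r"^epoch_(\d+)_step_(\d+)$")


def find_latest_checkpoint(checkpoint_dir: str) -> Optional[str]:
    """Return the sub-directory with the highest (epoch, step), or None."""
    if not os.path.isdir(checkpoint_dir):
        return None
    best_key, best_path = None, None
    for name in os.listdir(checkpoint_dir):
        match = _CKPT_RE.match(name)
        path = os.path.join(checkpoint_dir, name)
        # Skip anything that is not a directory following the assumed scheme.
        if match is None or not os.path.isdir(path):
            continue
        # Compare numerically so that step_10 sorts after step_9.
        key = (int(match.group(1)), int(match.group(2)))
        if best_key is None or key > best_key:
            best_key, best_path = key, path
    return best_path


def resolve_restore_path(checkpoint_dir: str, restore_from: Optional[str] = None) -> Optional[str]:
    """An explicit restore_from wins; otherwise fall back to the latest checkpoint."""
    if restore_from:
        return restore_from
    return find_latest_checkpoint(checkpoint_dir)
```

Under these assumptions, a fresh run with an empty checkpoint directory resolves to `None` and would start from scratch, while a restarted run picks up the most recent checkpoint without any config change.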

examples/automodel/finetune/finetune.py

Lines changed: 1 addition & 1 deletion
@@ -19,7 +19,7 @@
 from dfm.src.automodel.recipes.train import TrainWan21DiffusionRecipe
 
 
-def main(default_config_path="/opt/DFM/dfm/examples/Automodel/finetune/wan2_1_t2v_flow.yaml"):
+def main(default_config_path="examples/automodel/finetune/wan2_1_t2v_flow.yaml"):
     cfg = parse_args_and_load_config(default_config_path)
     recipe = TrainWan21DiffusionRecipe(cfg)
     recipe.setup()

examples/automodel/pretrain/pretrain.py

Lines changed: 1 addition & 1 deletion
@@ -19,7 +19,7 @@
 from dfm.src.automodel.recipes.train import TrainWan21DiffusionRecipe
 
 
-def main(default_config_path="/opt/DFM/dfm/examples/Automodel/pretrain/wan2_1_t2v_flow.yaml"):
+def main(default_config_path="examples/automodel/pretrain/wan2_1_t2v_flow.yaml"):
     cfg = parse_args_and_load_config(default_config_path)
     recipe = TrainWan21DiffusionRecipe(cfg)
     recipe.setup()
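Both example scripts now default to a repository-relative config path instead of the absolute `/opt/DFM/...` path. Assuming `parse_args_and_load_config` opens the path as given (its implementation is not shown in this diff), a relative default resolves against the current working directory, so the scripts would need to be launched from the repository root. The sketch below illustrates that behavior; `resolve_config_path` and its `base_dir` parameter are hypothetical and not part of the commit.

```python
from pathlib import Path
from typing import Optional, Union

# Default taken from the diff above.
DEFAULT_CONFIG = "examples/automodel/pretrain/wan2_1_t2v_flow.yaml"


def resolve_config_path(
    config_path: str, base_dir: Optional[Union[str, Path]] = None
) -> Path:
    """Hypothetical helper: show where a config path ends up pointing.

    Absolute paths are returned unchanged. Relative paths are joined onto
    base_dir, or onto the current working directory when base_dir is None,
    which mirrors how open() treats a relative path.
    """
    path = Path(config_path)
    if path.is_absolute():
        return path
    base = Path(base_dir) if base_dir is not None else Path.cwd()
    return base / path
```

For example, `resolve_config_path(DEFAULT_CONFIG)` points inside the repository only when the working directory is the repository root; passing an explicit `base_dir` removes that dependency.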
