Skip to content

Commit 0bf228e

Browse files
committed
update save path
1 parent c2a1f7c commit 0bf228e

File tree

2 files changed

+11
-6
lines changed

2 files changed

+11
-6
lines changed

src/train/sovits.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -330,11 +330,11 @@ def _run(self, rank, n_gpus, hps: TrainConfig):
330330
net_g,
331331
optim_g,
332332
)
333-
global_step = (epoch_str - 1) * len(train_loader)
333+
self.step = (epoch_str - 1) * len(train_loader)
334334
except Exception as e:
335335
logger.warning(f"load failed, exception: {e}, use pretrained instead")
336336
epoch_str = 1
337-
global_step = 0
337+
step = 0
338338
if hps.train.pretrained_s2G != "" and hps.train.pretrained_s2G != None and os.path.exists(hps.train.pretrained_s2G):
339339
if rank == 0:
340340
logger.info("loaded pretrained %s" % hps.train.pretrained_s2G)
@@ -570,7 +570,7 @@ def _train_and_evaluate(
570570
hps.train.learning_rate,
571571
epoch,
572572
os.path.join(
573-
hps.train.train_logs_dir, f"sovits_G_epoch{epoch}_step{self.step}.pth"
573+
hps.train.train_logs_dir, f"G_{self.step}.pth"
574574
),
575575
)
576576
ckpt.save_checkpoint(
@@ -579,7 +579,7 @@ def _train_and_evaluate(
579579
hps.train.learning_rate,
580580
epoch,
581581
os.path.join(
582-
hps.train.train_logs_dir, f"sovits_D_epoch{epoch}_step{self.step}.pth"
582+
hps.train.train_logs_dir, f"D_{self.step}.pth"
583583
),
584584
)
585585
else:
@@ -589,7 +589,7 @@ def _train_and_evaluate(
589589
hps.train.learning_rate,
590590
epoch,
591591
os.path.join(
592-
hps.train.train_logs_dir, "sovits_G_latest.pth"
592+
hps.train.train_logs_dir, "G_latest.pth"
593593
),
594594
)
595595
ckpt.save_checkpoint(
@@ -598,7 +598,7 @@ def _train_and_evaluate(
598598
hps.train.learning_rate,
599599
epoch,
600600
os.path.join(
601-
hps.train.train_logs_dir, "sovits_D_latest.pth"
601+
hps.train.train_logs_dir, "D_latest.pth"
602602
),
603603
)
604604
if rank == 0 and hps.train.if_save_every_weights == True:

src/utils/path/ckpt.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,11 @@ def load_checkpoint(checkpoint_path, model, optimizer=None, skip_optimizer=False
5656

5757
def latest_checkpoint_path(dir_path, regex="G_*.pth"):
5858
f_list = glob.glob(os.path.join(dir_path, regex))
59+
latest = list(filter(lambda x: x.endswith("latest.pth"), f_list))
60+
if latest:
61+
logger.info(f"latest checkpoint in dir {dir_path} is: {latest[0]}")
62+
return latest[0]
63+
5964
f_list.sort(key=lambda f: int("".join(filter(str.isdigit, f))))
6065
path = f_list[-1]
6166
logger.info(f"latest checkpoint in dir {dir_path} is: {path}")

0 commit comments

Comments
 (0)