Skip to content

Commit 55b6cc2

Browse files
committed
update sovits tensorboard info
1 parent e85e87d commit 55b6cc2

File tree

2 files changed

+21
-35
lines changed

2 files changed

+21
-35
lines changed

src/train/sovits.py

Lines changed: 18 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from dataclasses import asdict, dataclass
22
import logging
3+
from pickle import GLOBAL
34
import traceback
45
from typing import Any, List, Tuple
56
import torch.distributed as dist
@@ -120,6 +121,9 @@ class TrainConfig(BaseModel):
120121
content_module: str
121122

122123

124+
GLOBAL_STEP = 0
125+
126+
123127
class SovitsTrain:
124128
def _update_hparams(self, hps: TrainConfig, params: SovitsTrainParams):
125129
hps.train.batch_size = params.batch_size
@@ -157,7 +161,6 @@ def __init__(self, params: SovitsTrainParams):
157161
hps = TrainConfig(**json_data)
158162
self.hps = self._update_hparams(hps, params)
159163
logger.info(f"train sovits with config: {self.hps}")
160-
self.step = 0
161164
self.device = "cpu"
162165

163166
warnings.filterwarnings("ignore")
@@ -207,6 +210,7 @@ def train(self):
207210
return TrainOutput(model_path=self.hps.train.output_dir)
208211

209212
def _run(self, rank, n_gpus, hps: TrainConfig):
213+
global GLOBAL_STEP
210214
if rank == 0:
211215
logger.info("hps for train sovits", hps)
212216
writer = SummaryWriter(log_dir=get_tensorboard_log_dir(hps.name))
@@ -333,11 +337,11 @@ def _run(self, rank, n_gpus, hps: TrainConfig):
333337
net_g,
334338
optim_g,
335339
)
336-
self.step = (epoch_str - 1) * len(train_loader)
340+
GLOBAL_STEP = (epoch_str - 1) * len(train_loader)
337341
except Exception as e:
338342
logger.warning(f"load failed, exception: {e}, use pretrained instead")
339343
epoch_str = 1
340-
step = 0
344+
GLOBAL_STEP = 0
341345
if hps.train.pretrained_s2G != "" and hps.train.pretrained_s2G != None and os.path.exists(hps.train.pretrained_s2G):
342346
if rank == 0:
343347
logger.info("loaded pretrained %s" % hps.train.pretrained_s2G)
@@ -412,6 +416,7 @@ def _run(self, rank, n_gpus, hps: TrainConfig):
412416
def _train_and_evaluate(
413417
self, rank, epoch, hps: TrainConfig, nets, optims, schedulers, scaler, loaders, logger, writers
414418
):
419+
global GLOBAL_STEP
415420
connector = MultiProcessOutputConnector()
416421
device = self.device
417422
net_g, net_d = nets
@@ -523,19 +528,19 @@ def _train_and_evaluate(
523528
scaler.step(optim_g)
524529
scaler.update()
525530

526-
if self.step % 10 == 0:
531+
if GLOBAL_STEP % 10 == 0:
527532
connector.write_loss(
528-
self.step,
533+
GLOBAL_STEP,
529534
loss=convert_tensor_to_python(loss_gen_all),
530535
other={
531536
"loss/g/total": convert_tensor_to_python(loss_gen_all),
532537
"loss/d/total": convert_tensor_to_python(loss_disc_all),
533538
"learning_rate": convert_tensor_to_python(optim_g.param_groups[0]["lr"]),
534539
})
535-
logger.info(f"step: {self.step}, loss: {convert_tensor_to_python(loss_gen_all)}")
540+
logger.info(f"step: {GLOBAL_STEP}, loss: {convert_tensor_to_python(loss_gen_all)}")
536541

537542
if rank == 0:
538-
if self.step % hps.train.log_interval == 0:
543+
if GLOBAL_STEP % hps.train.log_interval == 0:
539544
lr = optim_g.param_groups[0]["lr"]
540545
losses = [loss_disc, loss_gen, loss_fm, loss_mel, kl_ssl, loss_kl]
541546
logger.info(
@@ -560,27 +565,12 @@ def _train_and_evaluate(
560565
}
561566
)
562567

563-
image_dict = {
564-
"slice/mel_org": helper.plot_spectrogram_to_numpy(
565-
y_mel[0].data.cpu().numpy()
566-
),
567-
"slice/mel_gen": helper.plot_spectrogram_to_numpy(
568-
y_hat_mel[0].data.cpu().numpy()
569-
),
570-
"all/mel": helper.plot_spectrogram_to_numpy(
571-
mel[0].data.cpu().numpy()
572-
),
573-
"all/stats_ssl": helper.plot_spectrogram_to_numpy(
574-
stats_ssl[0].data.cpu().numpy()
575-
),
576-
}
577568
helper.summarize(
578569
writer=writer, # pyright: ignore
579-
global_step=self.step,
580-
images=image_dict,
570+
global_step=GLOBAL_STEP,
581571
scalars=scalar_dict,
582572
)
583-
self.step += 1
573+
GLOBAL_STEP += 1
584574
if epoch % hps.train.save_every_epoch == 0 and rank == 0:
585575
if not hps.train.if_save_latest:
586576
ckpt.save_checkpoint(
@@ -589,7 +579,7 @@ def _train_and_evaluate(
589579
hps.train.learning_rate,
590580
epoch,
591581
os.path.join(
592-
hps.train.train_logs_dir, f"G_{self.step}.pth"
582+
hps.train.train_logs_dir, f"G_{GLOBAL_STEP}.pth"
593583
),
594584
)
595585
ckpt.save_checkpoint(
@@ -598,7 +588,7 @@ def _train_and_evaluate(
598588
hps.train.learning_rate,
599589
epoch,
600590
os.path.join(
601-
hps.train.train_logs_dir, f"D_{self.step}.pth"
591+
hps.train.train_logs_dir, f"D_{GLOBAL_STEP}.pth"
602592
),
603593
)
604594
else:
@@ -627,9 +617,9 @@ def _train_and_evaluate(
627617
ckpts = net_g.state_dict()
628618
msg = self._save_epoch(
629619
ckpts,
630-
hps.name + f"_e{epoch}_s{self.step}",
620+
hps.name + f"_e{epoch}_s{GLOBAL_STEP}",
631621
epoch,
632-
self.step,
622+
GLOBAL_STEP,
633623
hps,
634624
)
635625
logger.info(f"saving ckpt {hps.name}_e{epoch}:{msg}")

src/utils/helper/__init__.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ def set_seed(seed: int):
4242
def random_choice():
4343
return ''.join(random.choices(alphabet, k=8))
4444

45+
4546
def load_json(file_path):
4647
with open(file_path, "r") as f:
4748
data = f.read()
@@ -123,18 +124,13 @@ def summarize(
123124
global_step,
124125
scalars={},
125126
histograms={},
126-
images={},
127-
audios={},
128-
audio_sampling_rate=22050,
129127
):
130128
for k, v in scalars.items():
131129
writer.add_scalar(k, v, global_step)
132130
for k, v in histograms.items():
133131
writer.add_histogram(k, v, global_step)
134-
for k, v in images.items():
135-
writer.add_image(k, v, global_step, dataformats="HWC")
136-
for k, v in audios.items():
137-
writer.add_audio(k, v, global_step, audio_sampling_rate)
132+
writer.flush()
133+
138134

139135
def convert_tensor_to_python(obj):
140136
if isinstance(obj, torch.Tensor):

0 commit comments

Comments
 (0)