11from dataclasses import asdict , dataclass
22import logging
3+ from pickle import GLOBAL
34import traceback
45from typing import Any , List , Tuple
56import torch .distributed as dist
@@ -120,6 +121,9 @@ class TrainConfig(BaseModel):
120121 content_module : str
121122
122123
124+ GLOBAL_STEP = 0
125+
126+
123127class SovitsTrain :
124128 def _update_hparams (self , hps : TrainConfig , params : SovitsTrainParams ):
125129 hps .train .batch_size = params .batch_size
@@ -157,7 +161,6 @@ def __init__(self, params: SovitsTrainParams):
157161 hps = TrainConfig (** json_data )
158162 self .hps = self ._update_hparams (hps , params )
159163 logger .info (f"train sovits with config: { self .hps } " )
160- self .step = 0
161164 self .device = "cpu"
162165
163166 warnings .filterwarnings ("ignore" )
@@ -207,6 +210,7 @@ def train(self):
207210 return TrainOutput (model_path = self .hps .train .output_dir )
208211
209212 def _run (self , rank , n_gpus , hps : TrainConfig ):
213+ global GLOBAL_STEP
210214 if rank == 0 :
211215 logger .info ("hps for train sovits" , hps )
212216 writer = SummaryWriter (log_dir = get_tensorboard_log_dir (hps .name ))
@@ -333,11 +337,11 @@ def _run(self, rank, n_gpus, hps: TrainConfig):
333337 net_g ,
334338 optim_g ,
335339 )
336- self . step = (epoch_str - 1 ) * len (train_loader )
340+ GLOBAL_STEP = (epoch_str - 1 ) * len (train_loader )
337341 except Exception as e :
338342 logger .warning (f"load failed, exception: { e } , use pretrained instead" )
339343 epoch_str = 1
340- step = 0
344+ GLOBAL_STEP = 0
341345 if hps .train .pretrained_s2G != "" and hps .train .pretrained_s2G != None and os .path .exists (hps .train .pretrained_s2G ):
342346 if rank == 0 :
343347 logger .info ("loaded pretrained %s" % hps .train .pretrained_s2G )
@@ -412,6 +416,7 @@ def _run(self, rank, n_gpus, hps: TrainConfig):
412416 def _train_and_evaluate (
413417 self , rank , epoch , hps : TrainConfig , nets , optims , schedulers , scaler , loaders , logger , writers
414418 ):
419+ global GLOBAL_STEP
415420 connector = MultiProcessOutputConnector ()
416421 device = self .device
417422 net_g , net_d = nets
@@ -523,19 +528,19 @@ def _train_and_evaluate(
523528 scaler .step (optim_g )
524529 scaler .update ()
525530
526- if self . step % 10 == 0 :
531+ if GLOBAL_STEP % 10 == 0 :
527532 connector .write_loss (
528- self . step ,
533+ GLOBAL_STEP ,
529534 loss = convert_tensor_to_python (loss_gen_all ),
530535 other = {
531536 "loss/g/total" : convert_tensor_to_python (loss_gen_all ),
532537 "loss/d/total" : convert_tensor_to_python (loss_disc_all ),
533538 "learning_rate" : convert_tensor_to_python (optim_g .param_groups [0 ]["lr" ]),
534539 })
535- logger .info (f"step: { self . step } , loss: { convert_tensor_to_python (loss_gen_all )} " )
540+ logger .info (f"step: { GLOBAL_STEP } , loss: { convert_tensor_to_python (loss_gen_all )} " )
536541
537542 if rank == 0 :
538- if self . step % hps .train .log_interval == 0 :
543+ if GLOBAL_STEP % hps .train .log_interval == 0 :
539544 lr = optim_g .param_groups [0 ]["lr" ]
540545 losses = [loss_disc , loss_gen , loss_fm , loss_mel , kl_ssl , loss_kl ]
541546 logger .info (
@@ -560,27 +565,12 @@ def _train_and_evaluate(
560565 }
561566 )
562567
563- image_dict = {
564- "slice/mel_org" : helper .plot_spectrogram_to_numpy (
565- y_mel [0 ].data .cpu ().numpy ()
566- ),
567- "slice/mel_gen" : helper .plot_spectrogram_to_numpy (
568- y_hat_mel [0 ].data .cpu ().numpy ()
569- ),
570- "all/mel" : helper .plot_spectrogram_to_numpy (
571- mel [0 ].data .cpu ().numpy ()
572- ),
573- "all/stats_ssl" : helper .plot_spectrogram_to_numpy (
574- stats_ssl [0 ].data .cpu ().numpy ()
575- ),
576- }
577568 helper .summarize (
578569 writer = writer , # pyright: ignore
579- global_step = self .step ,
580- images = image_dict ,
570+ global_step = GLOBAL_STEP ,
581571 scalars = scalar_dict ,
582572 )
583- self . step += 1
573+ GLOBAL_STEP += 1
584574 if epoch % hps .train .save_every_epoch == 0 and rank == 0 :
585575 if not hps .train .if_save_latest :
586576 ckpt .save_checkpoint (
@@ -589,7 +579,7 @@ def _train_and_evaluate(
589579 hps .train .learning_rate ,
590580 epoch ,
591581 os .path .join (
592- hps .train .train_logs_dir , f"G_{ self . step } .pth"
582+ hps .train .train_logs_dir , f"G_{ GLOBAL_STEP } .pth"
593583 ),
594584 )
595585 ckpt .save_checkpoint (
@@ -598,7 +588,7 @@ def _train_and_evaluate(
598588 hps .train .learning_rate ,
599589 epoch ,
600590 os .path .join (
601- hps .train .train_logs_dir , f"D_{ self . step } .pth"
591+ hps .train .train_logs_dir , f"D_{ GLOBAL_STEP } .pth"
602592 ),
603593 )
604594 else :
@@ -627,9 +617,9 @@ def _train_and_evaluate(
627617 ckpts = net_g .state_dict ()
628618 msg = self ._save_epoch (
629619 ckpts ,
630- hps .name + f"_e{ epoch } _s{ self . step } " ,
620+ hps .name + f"_e{ epoch } _s{ GLOBAL_STEP } " ,
631621 epoch ,
632- self . step ,
622+ GLOBAL_STEP ,
633623 hps ,
634624 )
635625 logger .info (f"saving ckpt { hps .name } _e{ epoch } :{ msg } " )
0 commit comments