3333import json
3434
3535import tensorflow_tts
36- from examples .fastspeech2_libritts .fastspeech2_dataset import \
37- CharactorDurationF0EnergyMelDataset
36+ from examples .fastspeech2_libritts .fastspeech2_dataset import (
37+ CharactorDurationF0EnergyMelDataset ,
38+ )
3839from tensorflow_tts .configs import FastSpeech2Config
3940from tensorflow_tts .models import TFFastSpeech2
4041from tensorflow_tts .optimizers import AdamWeightDecay , WarmUp
4142from tensorflow_tts .trainers import Seq2SeqBasedTrainer
42- from tensorflow_tts .utils import (calculate_2d_loss , calculate_3d_loss ,
43- return_strategy , TFGriffinLim )
43+ from tensorflow_tts .utils import (
44+ calculate_2d_loss ,
45+ calculate_3d_loss ,
46+ return_strategy ,
47+ TFGriffinLim ,
48+ )
4449
4550
4651class FastSpeech2Trainer (Seq2SeqBasedTrainer ):
4752 """FastSpeech2 Trainer class based on FastSpeechTrainer."""
4853
4954 def __init__ (
50- self , config , strategy , steps = 0 , epochs = 0 , is_mixed_precision = False , stats_path : str = "" ,
51- dataset_config : str = ""
55+ self ,
56+ config ,
57+ strategy ,
58+ steps = 0 ,
59+ epochs = 0 ,
60+ is_mixed_precision = False ,
61+ stats_path : str = "" ,
62+ dataset_config : str = "" ,
5263 ):
5364 """Initialize trainer.
5465 Args:
@@ -78,7 +89,9 @@ def __init__(
7889 self .use_griffin = config .get ("use_griffin" , False )
7990 self .griffin_lim_tf = None
8091 if self .use_griffin :
81- logging .info (f"Load griff stats from { stats_path } and config from { dataset_config } " )
92+ logging .info (
93+ f"Load griff stats from { stats_path } and config from { dataset_config } "
94+ )
8295 self .griff_conf = yaml .load (open (dataset_config ), Loader = yaml .Loader )
8396 self .prepare_grim (stats_path , self .griff_conf )
8497
@@ -160,7 +173,9 @@ def generate_and_save_intermediate_result(self, batch):
160173
161174 # check directory
162175 if self .use_griffin :
163- griff_dir_name = os .path .join (self .config ["outdir" ], f"predictions/{ self .steps } _wav" )
176+ griff_dir_name = os .path .join (
177+ self .config ["outdir" ], f"predictions/{ self .steps } _wav"
178+ )
164179 if not os .path .exists (griff_dir_name ):
165180 os .makedirs (griff_dir_name )
166181
@@ -171,23 +186,31 @@ def generate_and_save_intermediate_result(self, batch):
171186 for idx , (mel_gt , mel_before , mel_after ) in enumerate (
172187 zip (mel_gts , mels_before , mels_after ), 0
173188 ):
174-
175-
189+
176190 if self .use_griffin :
177191 utt_id = utt_ids [idx ]
178- grif_before = self .griffin_lim_tf (tf .reshape (mel_before , [- 1 , 80 ])[tf .newaxis , :], n_iter = 32 )
179- grif_after = self .griffin_lim_tf (tf .reshape (mel_after , [- 1 , 80 ])[tf .newaxis , :], n_iter = 32 )
180- grif_gt = self .griffin_lim_tf (tf .reshape (mel_gt , [- 1 , 80 ])[tf .newaxis , :], n_iter = 32 )
181- self .griffin_lim_tf .save_wav (grif_before , griff_dir_name , f"{ utt_id } _before" )
182- self .griffin_lim_tf .save_wav (grif_after , griff_dir_name , f"{ utt_id } _after" )
192+ grif_before = self .griffin_lim_tf (
193+ tf .reshape (mel_before , [- 1 , 80 ])[tf .newaxis , :], n_iter = 32
194+ )
195+ grif_after = self .griffin_lim_tf (
196+ tf .reshape (mel_after , [- 1 , 80 ])[tf .newaxis , :], n_iter = 32
197+ )
198+ grif_gt = self .griffin_lim_tf (
199+ tf .reshape (mel_gt , [- 1 , 80 ])[tf .newaxis , :], n_iter = 32
200+ )
201+ self .griffin_lim_tf .save_wav (
202+ grif_before , griff_dir_name , f"{ utt_id } _before"
203+ )
204+ self .griffin_lim_tf .save_wav (
205+ grif_after , griff_dir_name , f"{ utt_id } _after"
206+ )
183207 self .griffin_lim_tf .save_wav (grif_gt , griff_dir_name , f"{ utt_id } _gt" )
184-
208+
185209 utt_id = utt_ids [idx ]
186210 mel_gt = tf .reshape (mel_gt , (- 1 , 80 )).numpy () # [length, 80]
187211 mel_before = tf .reshape (mel_before , (- 1 , 80 )).numpy () # [length, 80]
188212 mel_after = tf .reshape (mel_after , (- 1 , 80 )).numpy () # [length, 80]
189213
190-
191214 # plit figure and save it
192215 figname = os .path .join (dirname , f"{ utt_id } .png" )
193216 fig = plt .figure (figsize = (10 , 8 ))
@@ -229,10 +252,7 @@ def main():
229252 "--use-norm" , default = 1 , type = int , help = "usr norm-mels for train or raw."
230253 )
231254 parser .add_argument (
232- "--f0-stat" ,
233- default = "./dump/stats_f0.npy" ,
234- type = str ,
235- help = "f0-stat path." ,
255+ "--f0-stat" , default = "./dump/stats_f0.npy" , type = str , help = "f0-stat path." ,
236256 )
237257 parser .add_argument (
238258 "--energy-stat" ,
@@ -266,26 +286,20 @@ def main():
266286 help = "using mixed precision for generator or not." ,
267287 )
268288 parser .add_argument (
269- "--dataset_config" ,
270- default = "preprocess/libritts_preprocess.yaml" ,
271- type = str ,
289+ "--dataset_config" , default = "preprocess/libritts_preprocess.yaml" , type = str ,
272290 )
273291 parser .add_argument (
274- "--dataset_stats" ,
275- default = "dump/stats.npy" ,
276- type = str ,
292+ "--dataset_stats" , default = "dump/stats.npy" , type = str ,
277293 )
278294 parser .add_argument (
279- "--dataset_mapping" ,
280- default = "dump/libritts_mapper.npy" ,
281- type = str ,
295+ "--dataset_mapping" , default = "dump/libritts_mapper.npy" , type = str ,
282296 )
283297 parser .add_argument (
284298 "--pretrained" ,
285299 default = "" ,
286300 type = str ,
287301 nargs = "?" ,
288- help = ' pretrained weights .h5 file to load weights from. Auto-skips non-matching layers' ,
302+ help = " pretrained weights .h5 file to load weights from. Auto-skips non-matching layers" ,
289303 )
290304 args = parser .parse_args ()
291305
@@ -362,7 +376,9 @@ def main():
362376
363377 # Check n_speakers matches number of speakers in speakers_map
364378 n_speakers = config ["fastspeech2_params" ]["n_speakers" ]
365- assert n_speakers == len (speakers_map ), f"Number of speakers in dataset does not match n_speakers in config"
379+ assert n_speakers == len (
380+ speakers_map
381+ ), f"Number of speakers in dataset does not match n_speakers in config"
366382
367383 # define train/valid dataset
368384 train_dataset = CharactorDurationF0EnergyMelDataset (
@@ -375,11 +391,13 @@ def main():
375391 f0_stat = args .f0_stat ,
376392 energy_stat = args .energy_stat ,
377393 mel_length_threshold = mel_length_threshold ,
378- speakers_map = speakers_map
394+ speakers_map = speakers_map ,
379395 ).create (
380396 is_shuffle = config ["is_shuffle" ],
381397 allow_cache = config ["allow_cache" ],
382- batch_size = config ["batch_size" ] * STRATEGY .num_replicas_in_sync ,
398+ batch_size = config ["batch_size" ]
399+ * STRATEGY .num_replicas_in_sync
400+ * config ["gradient_accumulation_steps" ],
383401 )
384402
385403 valid_dataset = CharactorDurationF0EnergyMelDataset (
@@ -392,7 +410,7 @@ def main():
392410 f0_stat = args .f0_stat ,
393411 energy_stat = args .energy_stat ,
394412 mel_length_threshold = mel_length_threshold ,
395- speakers_map = speakers_map
413+ speakers_map = speakers_map ,
396414 ).create (
397415 is_shuffle = config ["is_shuffle" ],
398416 allow_cache = config ["allow_cache" ],
@@ -407,7 +425,7 @@ def main():
407425 epochs = 0 ,
408426 is_mixed_precision = args .mixed_precision ,
409427 stats_path = args .dataset_stats ,
410- dataset_config = args .dataset_config
428+ dataset_config = args .dataset_config ,
411429 )
412430
413431 with STRATEGY .scope ():
@@ -417,11 +435,12 @@ def main():
417435 )
418436 fastspeech ._build ()
419437 fastspeech .summary ()
420-
438+
421439 if len (args .pretrained ) > 1 :
422440 fastspeech .load_weights (args .pretrained , by_name = True , skip_mismatch = True )
423- logging .info (f"Successfully loaded pretrained weight from { args .pretrained } ." )
424-
441+ logging .info (
442+ f"Successfully loaded pretrained weight from { args .pretrained } ."
443+ )
425444
426445 # AdamW for fastspeech
427446 learning_rate_fn = tf .keras .optimizers .schedules .PolynomialDecay (
0 commit comments