4444from tensorflow_tts .utils import (calculate_2d_loss , calculate_3d_loss ,
4545 return_strategy )
4646
47+ from tensorflow_tts .configs import ParallelWaveGANDiscriminatorConfig
48+
49+ from tensorflow_tts .models import TFParallelWaveGANDiscriminator
50+
4751
4852class MultiBandMelganTrainer (MelganTrainer ):
4953 """Multi-Band MelGAN Trainer class based on MelganTrainer."""
@@ -158,14 +162,13 @@ def compute_per_example_generator_losses(self, batch, outputs):
158162 p_hat = self ._discriminator (y_hat )
159163 p = self ._discriminator (tf .expand_dims (audios , 2 ))
160164 adv_loss = 0.0
161- for i in range (len (p_hat )):
162- adv_loss += calculate_3d_loss (
163- tf .ones_like (p_hat [i ][- 1 ]), p_hat [i ][- 1 ], loss_fn = self .mse_loss
164- )
165- adv_loss /= i + 1
165+ adv_loss += calculate_3d_loss (
166+ tf .ones_like (p_hat ), p_hat , loss_fn = self .mse_loss
167+ )
166168 gen_loss += self .config ["lambda_adv" ] * adv_loss
167169
168- dict_metrics_losses .update ({"adversarial_loss" : adv_loss },)
170+ # update dict_metrics_losses
171+ dict_metrics_losses .update ({"adversarial_loss" : adv_loss })
169172
170173 dict_metrics_losses .update ({"gen_loss" : gen_loss })
171174 dict_metrics_losses .update ({"subband_spectral_convergence_loss" : sub_sc_loss })
@@ -178,7 +181,9 @@ def compute_per_example_generator_losses(self, batch, outputs):
178181
179182 def compute_per_example_discriminator_losses (self , batch , gen_outputs ):
180183 audios = batch ["audios" ]
181- y_hat = gen_outputs
184+ y_mb_hat = gen_outputs
185+
186+ y_hat = self .pqmf .synthesis (y_mb_hat )
182187
183188 y = tf .expand_dims (audios , 2 )
184189 p = self ._discriminator (y )
0 commit comments