Skip to content

Commit 3bc5eb0

Browse files
committed
Fix errors
1 parent 833add5 commit 3bc5eb0

File tree

1 file changed

+12
-7
lines changed

1 file changed

+12
-7
lines changed

examples/multiband_pwgan/train_multiband_pwgan.py

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,10 @@
4444
from tensorflow_tts.utils import (calculate_2d_loss, calculate_3d_loss,
4545
return_strategy)
4646

47+
from tensorflow_tts.configs import ParallelWaveGANDiscriminatorConfig
48+
49+
from tensorflow_tts.models import TFParallelWaveGANDiscriminator
50+
4751

4852
class MultiBandMelganTrainer(MelganTrainer):
4953
"""Multi-Band MelGAN Trainer class based on MelganTrainer."""
@@ -158,14 +162,13 @@ def compute_per_example_generator_losses(self, batch, outputs):
158162
p_hat = self._discriminator(y_hat)
159163
p = self._discriminator(tf.expand_dims(audios, 2))
160164
adv_loss = 0.0
161-
for i in range(len(p_hat)):
162-
adv_loss += calculate_3d_loss(
163-
tf.ones_like(p_hat[i][-1]), p_hat[i][-1], loss_fn=self.mse_loss
164-
)
165-
adv_loss /= i + 1
165+
adv_loss += calculate_3d_loss(
166+
tf.ones_like(p_hat), p_hat, loss_fn=self.mse_loss
167+
)
166168
gen_loss += self.config["lambda_adv"] * adv_loss
167169

168-
dict_metrics_losses.update({"adversarial_loss": adv_loss},)
170+
# update dict_metrics_losses
171+
dict_metrics_losses.update({"adversarial_loss": adv_loss})
169172

170173
dict_metrics_losses.update({"gen_loss": gen_loss})
171174
dict_metrics_losses.update({"subband_spectral_convergence_loss": sub_sc_loss})
@@ -178,7 +181,9 @@ def compute_per_example_generator_losses(self, batch, outputs):
178181

179182
def compute_per_example_discriminator_losses(self, batch, gen_outputs):
180183
audios = batch["audios"]
181-
y_hat = gen_outputs
184+
y_mb_hat = gen_outputs
185+
186+
y_hat = self.pqmf.synthesis(y_mb_hat)
182187

183188
y = tf.expand_dims(audios, 2)
184189
p = self._discriminator(y)

0 commit comments

Comments
 (0)