Skip to content

Commit 518ee49

Browse files
committed
🐛 Output of melgan/mb-melgan should be float32 to prevent nan in mixed_precision.
1 parent d1789e3 commit 518ee49

File tree

4 files changed

+53
-30
lines changed

4 files changed

+53
-30
lines changed

examples/multiband_melgan/train_multiband_melgan.py

Lines changed: 26 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -36,13 +36,17 @@
3636
import tensorflow_tts
3737
from examples.melgan.audio_mel_dataset import AudioMelDataset
3838
from examples.melgan.train_melgan import MelganTrainer, collater
39-
from tensorflow_tts.configs import (MultiBandMelGANDiscriminatorConfig,
40-
MultiBandMelGANGeneratorConfig)
39+
from tensorflow_tts.configs import (
40+
MultiBandMelGANDiscriminatorConfig,
41+
MultiBandMelGANGeneratorConfig,
42+
)
4143
from tensorflow_tts.losses import TFMultiResolutionSTFT
42-
from tensorflow_tts.models import (TFPQMF, TFMelGANGenerator,
43-
TFMelGANMultiScaleDiscriminator)
44-
from tensorflow_tts.utils import (calculate_2d_loss, calculate_3d_loss,
45-
return_strategy)
44+
from tensorflow_tts.models import (
45+
TFPQMF,
46+
TFMelGANGenerator,
47+
TFMelGANMultiScaleDiscriminator,
48+
)
49+
from tensorflow_tts.utils import calculate_2d_loss, calculate_3d_loss, return_strategy
4650

4751

4852
class MultiBandMelganTrainer(MelganTrainer):
@@ -313,7 +317,7 @@ def main():
313317
default="",
314318
type=str,
315319
nargs="?",
316-
help='path of .h5 mb-melgan generator to load weights from',
320+
help="path of .h5 mb-melgan generator to load weights from",
317321
)
318322
args = parser.parse_args()
319323

@@ -438,28 +442,38 @@ def main():
438442
with STRATEGY.scope():
439443
# define generator and discriminator
440444
generator = TFMelGANGenerator(
441-
MultiBandMelGANGeneratorConfig(**config["multiband_melgan_generator_params"]),
445+
MultiBandMelGANGeneratorConfig(
446+
**config["multiband_melgan_generator_params"]
447+
),
442448
name="multi_band_melgan_generator",
443449
)
444450

445451
discriminator = TFMelGANMultiScaleDiscriminator(
446-
MultiBandMelGANDiscriminatorConfig(**config["multiband_melgan_discriminator_params"]),
452+
MultiBandMelGANDiscriminatorConfig(
453+
**config["multiband_melgan_discriminator_params"]
454+
),
447455
name="multi_band_melgan_discriminator",
448456
)
449457

450458
pqmf = TFPQMF(
451-
MultiBandMelGANGeneratorConfig(**config["multiband_melgan_generator_params"]), name="pqmf"
459+
MultiBandMelGANGeneratorConfig(
460+
**config["multiband_melgan_generator_params"]
461+
),
462+
dtype=tf.float32,
463+
name="pqmf",
452464
)
453465

454466
# dummy input to build model.
455467
fake_mels = tf.random.uniform(shape=[1, 100, 80], dtype=tf.float32)
456468
y_mb_hat = generator(fake_mels)
457469
y_hat = pqmf.synthesis(y_mb_hat)
458470
discriminator(y_hat)
459-
471+
460472
if len(args.pretrained) > 1:
461473
generator.load_weights(args.pretrained)
462-
logging.info(f"Successfully loaded pretrained weight from {args.pretrained}.")
474+
logging.info(
475+
f"Successfully loaded pretrained weight from {args.pretrained}."
476+
)
463477

464478
generator.summary()
465479
discriminator.summary()

examples/multiband_pwgan/train_multiband_pwgan.py

Lines changed: 24 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -36,13 +36,17 @@
3636
import tensorflow_tts
3737
from examples.melgan.audio_mel_dataset import AudioMelDataset
3838
from examples.melgan.train_melgan import MelganTrainer, collater
39-
from tensorflow_tts.configs import (MultiBandMelGANDiscriminatorConfig,
40-
MultiBandMelGANGeneratorConfig)
39+
from tensorflow_tts.configs import (
40+
MultiBandMelGANDiscriminatorConfig,
41+
MultiBandMelGANGeneratorConfig,
42+
)
4143
from tensorflow_tts.losses import TFMultiResolutionSTFT
42-
from tensorflow_tts.models import (TFPQMF, TFMelGANGenerator,
43-
TFMelGANMultiScaleDiscriminator)
44-
from tensorflow_tts.utils import (calculate_2d_loss, calculate_3d_loss,
45-
return_strategy)
44+
from tensorflow_tts.models import (
45+
TFPQMF,
46+
TFMelGANGenerator,
47+
TFMelGANMultiScaleDiscriminator,
48+
)
49+
from tensorflow_tts.utils import calculate_2d_loss, calculate_3d_loss, return_strategy
4650

4751
from tensorflow_tts.configs import ParallelWaveGANDiscriminatorConfig
4852

@@ -327,7 +331,7 @@ def main():
327331
default="",
328332
type=str,
329333
nargs="?",
330-
help='path of .h5 mb-melgan generator to load weights from',
334+
help="path of .h5 mb-melgan generator to load weights from",
331335
)
332336
args = parser.parse_args()
333337

@@ -452,7 +456,9 @@ def main():
452456
with STRATEGY.scope():
453457
# define generator and discriminator
454458
generator = TFMelGANGenerator(
455-
MultiBandMelGANGeneratorConfig(**config["multiband_melgan_generator_params"]),
459+
MultiBandMelGANGeneratorConfig(
460+
**config["multiband_melgan_generator_params"]
461+
),
456462
name="multi_band_melgan_generator",
457463
)
458464

@@ -464,19 +470,24 @@ def main():
464470
)
465471

466472
pqmf = TFPQMF(
467-
MultiBandMelGANGeneratorConfig(**config["multiband_melgan_generator_params"]), name="pqmf"
473+
MultiBandMelGANGeneratorConfig(
474+
**config["multiband_melgan_generator_params"]
475+
),
476+
dtype=tf.float32,
477+
name="pqmf",
468478
)
469479

470480
# dummy input to build model.
471481
fake_mels = tf.random.uniform(shape=[1, 100, 80], dtype=tf.float32)
472482
y_mb_hat = generator(fake_mels)
473483
y_hat = pqmf.synthesis(y_mb_hat)
474484
discriminator(y_hat)
475-
485+
476486
if len(args.pretrained) > 1:
477487
generator.load_weights(args.pretrained)
478-
logging.info(f"Successfully loaded pretrained weight from {args.pretrained}.")
479-
488+
logging.info(
489+
f"Successfully loaded pretrained weight from {args.pretrained}."
490+
)
480491

481492
generator.summary()
482493
discriminator.summary()
@@ -494,10 +505,7 @@ def main():
494505
learning_rate=generator_lr_fn,
495506
amsgrad=config["generator_optimizer_params"]["amsgrad"],
496507
)
497-
dis_optimizer = RectifiedAdam(
498-
learning_rate=discriminator_lr_fn, amsgrad=False
499-
)
500-
508+
dis_optimizer = RectifiedAdam(learning_rate=discriminator_lr_fn, amsgrad=False)
501509

502510
trainer.compile(
503511
gen_model=generator,

tensorflow_tts/models/mb_melgan.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -161,7 +161,7 @@ class TFMBMelGANGenerator(TFMelGANGenerator):
161161

162162
def __init__(self, config, **kwargs):
163163
super().__init__(config, **kwargs)
164-
self.pqmf = TFPQMF(config=config, name="pqmf")
164+
self.pqmf = TFPQMF(config=config, dtype=tf.float32, name="pqmf")
165165

166166
def call(self, mels, **kwargs):
167167
"""Calculate forward propagation.

tensorflow_tts/models/melgan.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -263,10 +263,11 @@ def __init__(self, config, **kwargs):
263263
kernel_size=config.kernel_size,
264264
use_bias=config.use_bias,
265265
kernel_initializer=get_initializer(config.initializer_seed),
266+
dtype=tf.float32,
266267
),
267268
]
268269
if config.use_final_nolinear_activation:
269-
layers += [tf.keras.layers.Activation("tanh")]
270+
layers += [tf.keras.layers.Activation("tanh", dtype=tf.float32)]
270271

271272
if config.is_weight_norm is True:
272273
self._apply_weightnorm(layers)

0 commit comments

Comments
 (0)