
Commit c4629b5

🤐 DurationPredictor and f0/energy Predictor layers should use dtype=tf.float32 to prevent mixed_precision NaN.
1 parent 6c073fb commit c4629b5
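
Context for the change, as a hedged sketch (not part of this commit): under tf.keras mixed precision, a global "mixed_float16" policy makes layers compute in float16 by default, and numerically sensitive layers such as the duration/f0/energy predictors can then overflow to NaN. Passing dtype=tf.float32 to an individual layer pins that layer's computation to float32 while the rest of the model keeps the mixed policy:

import tensorflow as tf
from tensorflow.keras import mixed_precision

# Enable mixed precision globally (tf.keras API, TF 2.4+).
mixed_precision.set_global_policy("mixed_float16")

follows_policy = tf.keras.layers.Dense(4, name="follows_policy")
pinned_float32 = tf.keras.layers.Dense(4, dtype=tf.float32, name="pinned_float32")

x = tf.random.normal([2, 8])
print(follows_policy(x).dtype)  # float16: compute dtype follows the global policy
print(pinned_float32(x).dtype)  # float32: the per-layer dtype overrides the policy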

2 files changed (+31, -12 lines)

tensorflow_tts/models/fastspeech.py

Lines changed: 23 additions & 8 deletions
@@ -65,6 +65,7 @@ def mish(x):
 
 class TFEmbedding(tf.keras.layers.Embedding):
     """Faster version of embedding."""
+
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
 
@@ -226,13 +227,17 @@ def call(self, inputs, training=False):
         value_layer = self.transpose_for_scores(mixed_value_layer, batch_size)
 
         attention_scores = tf.matmul(query_layer, key_layer, transpose_b=True)
-        dk = tf.cast(tf.shape(key_layer)[-1], attention_scores.dtype)  # scale attention_scores
+        dk = tf.cast(
+            tf.shape(key_layer)[-1], attention_scores.dtype
+        )  # scale attention_scores
         attention_scores = attention_scores / tf.math.sqrt(dk)
 
         if attention_mask is not None:
             # extended_attention_masks for self attention encoder.
             extended_attention_mask = attention_mask[:, tf.newaxis, tf.newaxis, :]
-            extended_attention_mask = tf.cast(extended_attention_mask, attention_scores.dtype)
+            extended_attention_mask = tf.cast(
+                extended_attention_mask, attention_scores.dtype
+            )
             extended_attention_mask = (1.0 - extended_attention_mask) * -1e9
             attention_scores = attention_scores + extended_attention_mask
 
@@ -481,7 +486,9 @@ def call(self, inputs, training=False):
             hidden_states = self.project_compatible_decoder(hidden_states)
 
         # calculate new hidden states.
-        hidden_states += tf.cast(self.decoder_positional_embeddings(decoder_pos), hidden_states.dtype)
+        hidden_states += tf.cast(
+            self.decoder_positional_embeddings(decoder_pos), hidden_states.dtype
+        )
 
         if self.config.n_speakers > 1:
             speaker_embeddings = self.decoder_speaker_embeddings(speaker_ids)
@@ -580,7 +587,9 @@ def __init__(self, config, **kwargs):
     def call(self, inputs, training=False):
         """Call logic."""
         encoder_hidden_states, attention_mask = inputs
-        attention_mask = tf.cast(tf.expand_dims(attention_mask, 2), encoder_hidden_states.dtype)
+        attention_mask = tf.cast(
+            tf.expand_dims(attention_mask, 2), encoder_hidden_states.dtype
+        )
 
         # mask encoder hidden states
         masked_encoder_hidden_states = encoder_hidden_states * attention_mask
@@ -641,7 +650,9 @@ def _length_regulator(self, encoder_hidden_states, durations_gt):
             outputs = repeat_encoder_hidden_states
             encoder_masks = masks
         else:
-            outputs = tf.zeros(shape=[0, max_durations, hidden_size], dtype=encoder_hidden_states.dtype)
+            outputs = tf.zeros(
+                shape=[0, max_durations, hidden_size], dtype=encoder_hidden_states.dtype
+            )
             encoder_masks = tf.zeros(shape=[0, max_durations], dtype=tf.int32)
 
             def condition(
@@ -732,7 +743,7 @@ def __init__(self, config, **kwargs):
             config.encoder_self_attention_params, name="encoder"
         )
         self.duration_predictor = TFFastSpeechDurationPredictor(
-            config, name="duration_predictor"
+            config, dtype=tf.float32, name="duration_predictor"
         )
         self.length_regulator = TFFastSpeechLengthRegulator(
             config,
@@ -745,8 +756,12 @@ def __init__(self, config, **kwargs):
             == config.decoder_self_attention_params.hidden_size,
             name="decoder",
         )
-        self.mel_dense = tf.keras.layers.Dense(units=config.num_mels, dtype=tf.float32, name="mel_before")
-        self.postnet = TFTacotronPostnet(config=config, dtype=tf.float32, name="postnet")
+        self.mel_dense = tf.keras.layers.Dense(
+            units=config.num_mels, dtype=tf.float32, name="mel_before"
+        )
+        self.postnet = TFTacotronPostnet(
+            config=config, dtype=tf.float32, name="postnet"
+        )
 
         self.setup_inference_fn()
 
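Aside from the dtype=tf.float32 arguments, the hunks above are mostly black-style reflows of one recurring pattern: masks and positional embeddings created in float32 are cast to the dtype of the surrounding activations so the ops stay valid when those activations are float16. A minimal sketch of that pattern (the function name is illustrative, not from the repository):

import tensorflow as tf

def mask_hidden_states(encoder_hidden_states, attention_mask):
    # attention_mask is a [batch, time] padding mask; cast it to the
    # activations' dtype (float16 under mixed precision) before multiplying.
    mask = tf.cast(tf.expand_dims(attention_mask, 2), encoder_hidden_states.dtype)
    return encoder_hidden_states * mask
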
tensorflow_tts/models/fastspeech2.py

Lines changed: 8 additions & 4 deletions
@@ -64,7 +64,9 @@ def __init__(self, config, **kwargs):
     def call(self, inputs, training=False):
         """Call logic."""
         encoder_hidden_states, speaker_ids, attention_mask = inputs
-        attention_mask = tf.cast(tf.expand_dims(attention_mask, 2), encoder_hidden_states.dtype)
+        attention_mask = tf.cast(
+            tf.expand_dims(attention_mask, 2), encoder_hidden_states.dtype
+        )
 
         if self.config.n_speakers > 1:
             speaker_embeddings = self.decoder_speaker_embeddings(speaker_ids)
@@ -91,12 +93,14 @@ class TFFastSpeech2(TFFastSpeech):
     def __init__(self, config, **kwargs):
         """Init layers for fastspeech."""
         super().__init__(config, **kwargs)
-        self.f0_predictor = TFFastSpeechVariantPredictor(config, name="f0_predictor")
+        self.f0_predictor = TFFastSpeechVariantPredictor(
+            config, dtype=tf.float32, name="f0_predictor"
+        )
         self.energy_predictor = TFFastSpeechVariantPredictor(
-            config, name="energy_predictor",
+            config, dtype=tf.float32, name="energy_predictor",
         )
         self.duration_predictor = TFFastSpeechVariantPredictor(
-            config, name="duration_predictor"
+            config, dtype=tf.float32, name="duration_predictor"
         )
 
         # define f0_embeddings and energy_embeddings
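
A hedged usage note: with the variant predictors pinned to float32 as above, the rest of the model can train under a mixed_float16 policy; on the training side the optimizer is then typically wrapped for loss scaling so float16 gradients do not underflow. The optimizer below is illustrative, not taken from this repository's training scripts:

import tensorflow as tf
from tensorflow.keras import mixed_precision

mixed_precision.set_global_policy("mixed_float16")
optimizer = tf.keras.optimizers.Adam(learning_rate=1e-3)
optimizer = mixed_precision.LossScaleOptimizer(optimizer)  # dynamic loss scaling by default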
