@@ -65,6 +65,7 @@ def mish(x):
 
 class TFEmbedding(tf.keras.layers.Embedding):
     """Faster version of embedding."""
+
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
 
@@ -226,13 +227,17 @@ def call(self, inputs, training=False):
         value_layer = self.transpose_for_scores(mixed_value_layer, batch_size)
 
         attention_scores = tf.matmul(query_layer, key_layer, transpose_b=True)
-        dk = tf.cast(tf.shape(key_layer)[-1], attention_scores.dtype)  # scale attention_scores
+        dk = tf.cast(
+            tf.shape(key_layer)[-1], attention_scores.dtype
+        )  # scale attention_scores
         attention_scores = attention_scores / tf.math.sqrt(dk)
 
         if attention_mask is not None:
             # extended_attention_masks for self attention encoder.
             extended_attention_mask = attention_mask[:, tf.newaxis, tf.newaxis, :]
-            extended_attention_mask = tf.cast(extended_attention_mask, attention_scores.dtype)
+            extended_attention_mask = tf.cast(
+                extended_attention_mask, attention_scores.dtype
+            )
             extended_attention_mask = (1.0 - extended_attention_mask) * -1e9
             attention_scores = attention_scores + extended_attention_mask
 
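The hunk above keeps the standard scaled dot-product attention masking pattern: scores are divided by sqrt(d_k), and padding positions get a large negative additive bias so softmax drives them to roughly zero; casting the mask to attention_scores.dtype keeps this valid when the model runs under mixed precision. A minimal standalone sketch of that pattern (names and shapes here are illustrative, not part of the patch):

```python
import tensorflow as tf

def masked_attention_probs(query, key, attention_mask=None):
    # query, key: [batch, heads, seq_len, head_dim]
    scores = tf.matmul(query, key, transpose_b=True)
    dk = tf.cast(tf.shape(key)[-1], scores.dtype)
    scores = scores / tf.math.sqrt(dk)  # scale by sqrt(head_dim)
    if attention_mask is not None:
        # attention_mask: [batch, seq_len], 1.0 for real tokens, 0.0 for padding.
        extended = tf.cast(attention_mask[:, tf.newaxis, tf.newaxis, :], scores.dtype)
        scores += (1.0 - extended) * -1e9  # padding gets a large negative bias
    return tf.nn.softmax(scores, axis=-1)  # padded positions end up near zero
```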
@@ -481,7 +486,9 @@ def call(self, inputs, training=False):
         hidden_states = self.project_compatible_decoder(hidden_states)
 
         # calculate new hidden states.
-        hidden_states += tf.cast(self.decoder_positional_embeddings(decoder_pos), hidden_states.dtype)
+        hidden_states += tf.cast(
+            self.decoder_positional_embeddings(decoder_pos), hidden_states.dtype
+        )
 
         if self.config.n_speakers > 1:
             speaker_embeddings = self.decoder_speaker_embeddings(speaker_ids)
@@ -580,7 +587,9 @@ def __init__(self, config, **kwargs):
     def call(self, inputs, training=False):
         """Call logic."""
         encoder_hidden_states, attention_mask = inputs
-        attention_mask = tf.cast(tf.expand_dims(attention_mask, 2), encoder_hidden_states.dtype)
+        attention_mask = tf.cast(
+            tf.expand_dims(attention_mask, 2), encoder_hidden_states.dtype
+        )
 
         # mask encoder hidden states
         masked_encoder_hidden_states = encoder_hidden_states * attention_mask
@@ -641,7 +650,9 @@ def _length_regulator(self, encoder_hidden_states, durations_gt):
             outputs = repeat_encoder_hidden_states
             encoder_masks = masks
         else:
-            outputs = tf.zeros(shape=[0, max_durations, hidden_size], dtype=encoder_hidden_states.dtype)
+            outputs = tf.zeros(
+                shape=[0, max_durations, hidden_size], dtype=encoder_hidden_states.dtype
+            )
             encoder_masks = tf.zeros(shape=[0, max_durations], dtype=tf.int32)
 
             def condition(
@@ -732,7 +743,7 @@ def __init__(self, config, **kwargs):
             config.encoder_self_attention_params, name="encoder"
         )
         self.duration_predictor = TFFastSpeechDurationPredictor(
-            config, name="duration_predictor"
+            config, dtype=tf.float32, name="duration_predictor"
         )
         self.length_regulator = TFFastSpeechLengthRegulator(
             config,
@@ -745,8 +756,12 @@ def __init__(self, config, **kwargs):
             == config.decoder_self_attention_params.hidden_size,
             name="decoder",
         )
-        self.mel_dense = tf.keras.layers.Dense(units=config.num_mels, dtype=tf.float32, name="mel_before")
-        self.postnet = TFTacotronPostnet(config=config, dtype=tf.float32, name="postnet")
+        self.mel_dense = tf.keras.layers.Dense(
+            units=config.num_mels, dtype=tf.float32, name="mel_before"
+        )
+        self.postnet = TFTacotronPostnet(
+            config=config, dtype=tf.float32, name="postnet"
+        )
 
         self.setup_inference_fn()
 
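The dtype=tf.float32 arguments given to duration_predictor, mel_dense, and postnet follow the usual Keras mixed-precision convention: under a float16 compute policy, layers that feed losses or are numerically sensitive stay pinned to float32. A hedged sketch of that idea, assuming the tf.keras.mixed_precision API (TF 2.4+); layer sizes below are made up for illustration:

```python
import tensorflow as tf

tf.keras.mixed_precision.set_global_policy("mixed_float16")

hidden = tf.keras.layers.Dense(256)                      # computes in float16 under the policy
mel_head = tf.keras.layers.Dense(80, dtype=tf.float32)   # pinned to float32, like mel_before

x = tf.random.normal([2, 10, 384])
h = hidden(x)      # float16 activations
mel = mel_head(h)  # input is auto-cast up; float32 output for a stable loss
print(h.dtype, mel.dtype)  # float16, float32
```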