@@ -573,7 +573,7 @@ def call(self, inputs, training=False):
573573class TFFastSpeechLengthRegulator (tf .keras .layers .Layer ):
574574 """FastSpeech lengthregulator module."""
575575
576- def __init__ (self , config , enable_tflite_convertible = False , ** kwargs ):
576+ def __init__ (self , config , enable_tflite_convertible = False , ** kwargs ):
577577 """Init variables."""
578578 super ().__init__ (** kwargs )
579579 self .config = config
@@ -614,7 +614,7 @@ def _length_regulator(self, encoder_hidden_states, durations_gt):
614614 )
615615 repeat_encoder_hidden_states = tf .expand_dims (
616616 tf .pad (repeat_encoder_hidden_states , [[0 , pad_size ], [0 , 0 ]]), 0
617- ) # [1, max_durations, hidden_size]
617+ ) # [1, max_durations, hidden_size]
618618
619619 outputs = repeat_encoder_hidden_states
620620 encoder_masks = masks
@@ -695,7 +695,7 @@ def body(
695695class TFFastSpeech (tf .keras .Model ):
696696 """TF Fastspeech module."""
697697
698- def __init__ (self , config , ** kwargs ):
698+ def __init__ (self , config , enable_tflite_convertible = False , ** kwargs ):
699699 """Init layers for fastspeech."""
700700 super ().__init__ (** kwargs )
701701 self .embeddings = TFFastSpeechEmbeddings (config , name = "embeddings" )
@@ -704,12 +704,16 @@ def __init__(self, config, **kwargs):
704704 config , name = "duration_predictor"
705705 )
706706 self .length_regulator = TFFastSpeechLengthRegulator (
707- config , name = "length_regulator"
707+ config ,
708+ enable_tflite_convertible = enable_tflite_convertible ,
709+ name = "length_regulator"
708710 )
709711 self .decoder = TFFastSpeechDecoder (config , name = "decoder" )
710712 self .mel_dense = tf .keras .layers .Dense (units = config .num_mels , name = "mel_before" )
711713 self .postnet = TFTacotronPostnet (config = config , name = "postnet" )
712714
715+ self .enable_tflite_convertible = enable_tflite_convertible
716+
713717 def _build (self ):
714718 """Dummy input for building model."""
715719 # fake inputs
@@ -767,8 +771,8 @@ def call(
767771 input_signature = [
768772 tf .TensorSpec (shape = [None , None ], dtype = tf .int32 ),
769773 tf .TensorSpec (shape = [None , None ], dtype = tf .bool ),
770- tf .TensorSpec (shape = [None ,], dtype = tf .int32 ),
771- tf .TensorSpec (shape = [None ,], dtype = tf .float32 ),
774+ tf .TensorSpec (shape = [None , ], dtype = tf .int32 ),
775+ tf .TensorSpec (shape = [None , ], dtype = tf .float32 ),
772776 ],
773777 )
774778 def inference (self , input_ids , attention_mask , speaker_ids , speed_ratios ):
@@ -823,8 +827,8 @@ def inference(self, input_ids, attention_mask, speaker_ids, speed_ratios):
823827 input_signature = [
824828 tf .TensorSpec (shape = [1 , None ], dtype = tf .int32 ),
825829 tf .TensorSpec (shape = [1 , None ], dtype = tf .bool ),
826- tf .TensorSpec (shape = [1 ,], dtype = tf .int32 ),
827- tf .TensorSpec (shape = [1 ,], dtype = tf .float32 ),
830+ tf .TensorSpec (shape = [1 , ], dtype = tf .int32 ),
831+ tf .TensorSpec (shape = [1 , ], dtype = tf .float32 ),
828832 ],
829833 )
830834 def inference_tflite (self , input_ids , attention_mask , speaker_ids , speed_ratios ):
0 commit comments