Skip to content

Commit 457838e

Browse files
committed
🔧 Fix missing tflite flag on TFFastSpeech
1 parent 48d5f79 commit 457838e

File tree

1 file changed

+12
-8
lines changed

1 file changed

+12
-8
lines changed

tensorflow_tts/models/fastspeech.py

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -573,7 +573,7 @@ def call(self, inputs, training=False):
573573
class TFFastSpeechLengthRegulator(tf.keras.layers.Layer):
574574
"""FastSpeech lengthregulator module."""
575575

576-
def __init__(self, config, enable_tflite_convertible = False, **kwargs):
576+
def __init__(self, config, enable_tflite_convertible=False, **kwargs):
577577
"""Init variables."""
578578
super().__init__(**kwargs)
579579
self.config = config
@@ -614,7 +614,7 @@ def _length_regulator(self, encoder_hidden_states, durations_gt):
614614
)
615615
repeat_encoder_hidden_states = tf.expand_dims(
616616
tf.pad(repeat_encoder_hidden_states, [[0, pad_size], [0, 0]]), 0
617-
) # [1, max_durations, hidden_size]
617+
) # [1, max_durations, hidden_size]
618618

619619
outputs = repeat_encoder_hidden_states
620620
encoder_masks = masks
@@ -695,7 +695,7 @@ def body(
695695
class TFFastSpeech(tf.keras.Model):
696696
"""TF Fastspeech module."""
697697

698-
def __init__(self, config, **kwargs):
698+
def __init__(self, config, enable_tflite_convertible=False, **kwargs):
699699
"""Init layers for fastspeech."""
700700
super().__init__(**kwargs)
701701
self.embeddings = TFFastSpeechEmbeddings(config, name="embeddings")
@@ -704,12 +704,16 @@ def __init__(self, config, **kwargs):
704704
config, name="duration_predictor"
705705
)
706706
self.length_regulator = TFFastSpeechLengthRegulator(
707-
config, name="length_regulator"
707+
config,
708+
enable_tflite_convertible=enable_tflite_convertible,
709+
name="length_regulator"
708710
)
709711
self.decoder = TFFastSpeechDecoder(config, name="decoder")
710712
self.mel_dense = tf.keras.layers.Dense(units=config.num_mels, name="mel_before")
711713
self.postnet = TFTacotronPostnet(config=config, name="postnet")
712714

715+
self.enable_tflite_convertible = enable_tflite_convertible
716+
713717
def _build(self):
714718
"""Dummy input for building model."""
715719
# fake inputs
@@ -767,8 +771,8 @@ def call(
767771
input_signature=[
768772
tf.TensorSpec(shape=[None, None], dtype=tf.int32),
769773
tf.TensorSpec(shape=[None, None], dtype=tf.bool),
770-
tf.TensorSpec(shape=[None,], dtype=tf.int32),
771-
tf.TensorSpec(shape=[None,], dtype=tf.float32),
774+
tf.TensorSpec(shape=[None, ], dtype=tf.int32),
775+
tf.TensorSpec(shape=[None, ], dtype=tf.float32),
772776
],
773777
)
774778
def inference(self, input_ids, attention_mask, speaker_ids, speed_ratios):
@@ -823,8 +827,8 @@ def inference(self, input_ids, attention_mask, speaker_ids, speed_ratios):
823827
input_signature=[
824828
tf.TensorSpec(shape=[1, None], dtype=tf.int32),
825829
tf.TensorSpec(shape=[1, None], dtype=tf.bool),
826-
tf.TensorSpec(shape=[1,], dtype=tf.int32),
827-
tf.TensorSpec(shape=[1,], dtype=tf.float32),
830+
tf.TensorSpec(shape=[1, ], dtype=tf.int32),
831+
tf.TensorSpec(shape=[1, ], dtype=tf.float32),
828832
],
829833
)
830834
def inference_tflite(self, input_ids, attention_mask, speaker_ids, speed_ratios):

0 commit comments

Comments
 (0)