Skip to content

Commit 300b401

Browse files
authored
Merge pull request #84 from jaeyoo/patch-2
🚀 📝 Add TFLite-convertible TFFastSpeech
2 parents b8c22a6 + a8fc65c commit 300b401

File tree

1 file changed

+124
-50
lines changed

1 file changed

+124
-50
lines changed

tensorflow_tts/models/fastspeech.py

Lines changed: 124 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -573,14 +573,14 @@ def call(self, inputs, training=False):
573573
class TFFastSpeechLengthRegulator(tf.keras.layers.Layer):
574574
"""FastSpeech lengthregulator module."""
575575

576-
def __init__(self, config, enable_tflite_convertible=False, **kwargs):
    """Initialize the length-regulator layer.

    Args:
        config: Model configuration object; stored on the instance and read
            later (e.g. ``config.hidden_size`` is used by the regulator's
            while-loop shape invariants).
        enable_tflite_convertible (bool): When True, the length regulator
            uses a batch-size-1 code path without ``tf.while_loop`` so the
            graph can be converted to TFLite. Defaults to False, which
            preserves the original multi-batch behavior.
        **kwargs: Forwarded to ``tf.keras.layers.Layer``.
    """
    super().__init__(**kwargs)
    self.config = config
    self.enable_tflite_convertible = enable_tflite_convertible
580581

581582
def call(self, inputs, training=False):
582583
"""Call logic.
583-
584584
Args:
585585
1. encoder_hidden_states, Tensor (float32) shape [batch_size, length, hidden_size]
586586
2. durations_gt, Tensor (float32/int32) shape [batch_size, length]
@@ -601,75 +601,93 @@ def _length_regulator(self, encoder_hidden_states, durations_gt):
601601
hidden_size = input_shape[-1]
602602

603603
# initialize output hidden states and encoder masking.
604-
outputs = tf.zeros(shape=[0, max_durations, hidden_size], dtype=tf.float32)
605-
encoder_masks = tf.zeros(shape=[0, max_durations], dtype=tf.int32)
606-
607-
def condition(
608-
i,
609-
batch_size,
610-
outputs,
611-
encoder_masks,
612-
encoder_hidden_states,
613-
durations_gt,
614-
max_durations,
615-
):
616-
return tf.less(i, batch_size)
617-
618-
def body(
619-
i,
620-
batch_size,
621-
outputs,
622-
encoder_masks,
623-
encoder_hidden_states,
624-
durations_gt,
625-
max_durations,
626-
):
627-
repeats = durations_gt[i]
604+
if self.enable_tflite_convertible:
605+
# There is only 1 batch in inference, so we don't have to use
606+
# `tf.While` op with 3-D output tensor.
607+
repeats = durations_gt[0]
628608
real_length = tf.reduce_sum(repeats)
629609
pad_size = max_durations - real_length
610+
# masks : [max_durations]
630611
masks = tf.sequence_mask([real_length], max_durations, dtype=tf.int32)
631612
repeat_encoder_hidden_states = tf.repeat(
632-
encoder_hidden_states[i], repeats=repeats, axis=0
613+
encoder_hidden_states[0], repeats=repeats, axis=0
633614
)
634615
repeat_encoder_hidden_states = tf.expand_dims(
635616
tf.pad(repeat_encoder_hidden_states, [[0, pad_size], [0, 0]]), 0
636-
) # [1, max_durations, hidden_size]
637-
outputs = tf.concat([outputs, repeat_encoder_hidden_states], axis=0)
638-
encoder_masks = tf.concat([encoder_masks, masks], axis=0)
639-
return [
640-
i + 1,
617+
) # [1, max_durations, hidden_size]
618+
619+
outputs = repeat_encoder_hidden_states
620+
encoder_masks = masks
621+
else:
622+
outputs = tf.zeros(shape=[0, max_durations, hidden_size], dtype=tf.float32)
623+
encoder_masks = tf.zeros(shape=[0, max_durations], dtype=tf.int32)
624+
625+
def condition(
626+
i,
641627
batch_size,
642628
outputs,
643629
encoder_masks,
644630
encoder_hidden_states,
645631
durations_gt,
646632
max_durations,
647-
]
633+
):
634+
return tf.less(i, batch_size)
648635

649-
# initialize iteration i.
650-
i = tf.constant(0, dtype=tf.int32)
651-
_, _, outputs, encoder_masks, _, _, _, = tf.while_loop(
652-
condition,
653-
body,
654-
[
636+
def body(
655637
i,
656638
batch_size,
657639
outputs,
658640
encoder_masks,
659641
encoder_hidden_states,
660642
durations_gt,
661643
max_durations,
662-
],
663-
shape_invariants=[
664-
i.get_shape(),
665-
batch_size.get_shape(),
666-
tf.TensorShape([None, None, self.config.hidden_size]),
667-
tf.TensorShape([None, None]),
668-
encoder_hidden_states.get_shape(),
669-
durations_gt.get_shape(),
670-
max_durations.get_shape(),
671-
],
672-
)
644+
):
645+
repeats = durations_gt[i]
646+
real_length = tf.reduce_sum(repeats)
647+
pad_size = max_durations - real_length
648+
masks = tf.sequence_mask([real_length], max_durations, dtype=tf.int32)
649+
repeat_encoder_hidden_states = tf.repeat(
650+
encoder_hidden_states[i], repeats=repeats, axis=0
651+
)
652+
repeat_encoder_hidden_states = tf.expand_dims(
653+
tf.pad(repeat_encoder_hidden_states, [[0, pad_size], [0, 0]]), 0
654+
) # [1, max_durations, hidden_size]
655+
outputs = tf.concat([outputs, repeat_encoder_hidden_states], axis=0)
656+
encoder_masks = tf.concat([encoder_masks, masks], axis=0)
657+
return [
658+
i + 1,
659+
batch_size,
660+
outputs,
661+
encoder_masks,
662+
encoder_hidden_states,
663+
durations_gt,
664+
max_durations,
665+
]
666+
667+
# initialize iteration i.
668+
i = tf.constant(0, dtype=tf.int32)
669+
_, _, outputs, encoder_masks, _, _, _, = tf.while_loop(
670+
condition,
671+
body,
672+
[
673+
i,
674+
batch_size,
675+
outputs,
676+
encoder_masks,
677+
encoder_hidden_states,
678+
durations_gt,
679+
max_durations,
680+
],
681+
shape_invariants=[
682+
i.get_shape(),
683+
batch_size.get_shape(),
684+
tf.TensorShape([None, None, self.config.hidden_size]),
685+
tf.TensorShape([None, None]),
686+
encoder_hidden_states.get_shape(),
687+
durations_gt.get_shape(),
688+
max_durations.get_shape(),
689+
],
690+
)
673691

674692
return outputs, encoder_masks
675693

@@ -799,3 +817,59 @@ def inference(self, input_ids, attention_mask, speaker_ids, speed_ratios):
799817

800818
outputs = (mel_before, mel_after, duration_outputs)
801819
return outputs
820+
821+
@tf.function(
    experimental_relax_shapes=True,
    input_signature=[
        tf.TensorSpec(shape=[1, None], dtype=tf.int32),
        tf.TensorSpec(shape=[1, None], dtype=tf.bool),
        tf.TensorSpec(shape=[1], dtype=tf.int32),
        tf.TensorSpec(shape=[1], dtype=tf.float32),
    ],
)
def inference_tflite(self, input_ids, attention_mask, speaker_ids, speed_ratios):
    """TFLite-convertible inference pass (fixed batch size of 1).

    Mirrors ``inference`` but with a concrete input signature so the traced
    graph can be exported to TFLite.

    Args:
        input_ids: int32 tensor, shape [1, length] — phoneme/character ids.
        attention_mask: bool tensor, shape [1, length] — True on real tokens.
        speaker_ids: int32 tensor, shape [1] — speaker index for embeddings.
        speed_ratios: float32 tensor, shape [1] — duration scaling factor.

    Returns:
        Tuple of (mel_before, mel_after, duration_outputs) where the mel
        tensors come from ``mel_dense`` / ``postnet`` and duration_outputs
        is the rounded int32 per-token duration prediction.
    """
    # Embed tokens (with speaker conditioning) and encode.
    embeddings = self.embeddings([input_ids, speaker_ids], training=False)
    enc_hidden = self.encoder([embeddings, attention_mask], training=False)[0]

    # Predict log-durations from the final encoder layer only; more hidden
    # layers could be fed here, but this uses just the last one.
    log_durations = self.duration_predictor([enc_hidden, attention_mask])
    durations = tf.math.exp(log_durations) - 1.0  # [batch_size, length]

    # NOTE(review): this branch is unreachable when the tf.function is
    # invoked through its input_signature (speed_ratios is always a tensor);
    # kept for parity with the original implementation.
    if speed_ratios is None:
        speed_ratios = tf.convert_to_tensor(np.array([1.0]), dtype=tf.float32)

    durations = tf.cast(tf.math.round(durations * speed_ratios), tf.int32)

    # Expand encoder states to frame resolution according to the durations.
    regulated, encoder_masks = self.length_regulator(
        [enc_hidden, durations], training=False
    )

    # Positional ids over the regulated length, zeroed out on padding.
    frame_positions = tf.range(1, tf.shape(regulated)[1] + 1, dtype=tf.int32)
    masked_positions = tf.expand_dims(frame_positions, 0) * encoder_masks

    dec_hidden = self.decoder(
        [regulated, speaker_ids, encoder_masks, masked_positions],
        training=False,
    )[0]

    # Coarse mel projection plus residual postnet refinement.
    mel_before = self.mel_dense(dec_hidden)
    mel_after = self.postnet([mel_before, encoder_masks], training=False) + mel_before

    return (mel_before, mel_after, durations)

0 commit comments

Comments (0)