Commit bd7c4d8

Merge pull request #83 from jaeyoo/patch-1
📝 🚀 Add `inference_tflite()` in `TFTacotron2`
2 parents bf2e4c3 + e58b311

File tree

3 files changed: +443 -18 lines changed

tensorflow_tts/models/tacotron2.py

Lines changed: 134 additions & 18 deletions
@@ -1,5 +1,5 @@
 # -*- coding: utf-8 -*-
-# Copyright 2020 The Tacotron-2 Authors, Minh Nguyen (@dathudeptrai) and Eren Gölge (@erogol)
+# Copyright 2020 The Tacotron-2 Authors, Minh Nguyen (@dathudeptrai), Eren Gölge (@erogol) and Jae Yoo (@jaeyoo)
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -22,8 +22,11 @@

 from tensorflow_addons.seq2seq import Sampler
 from tensorflow_addons.seq2seq import BahdanauAttention
-from tensorflow_addons.seq2seq import dynamic_decode
+# TODO: once https://github.com/tensorflow/addons/pull/1964 is fixed,
+# uncomment this line.
+# from tensorflow_addons.seq2seq import dynamic_decode
 from tensorflow_addons.seq2seq import Decoder
+from tensorflow_tts.utils import dynamic_decode


 def get_initializer(initializer_range=0.02):
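
This import swap is the crux of the change: `tensorflow_addons.seq2seq.dynamic_decode` is not yet TFLite-convertible (hence the reference to addons PR #1964), so the commit vendors a patched copy in `tensorflow_tts/utils/decoder.py`, the third changed file, whose diff is not reproduced on this page. Judging only from the call sites later in this diff, the vendored helper keeps the upstream interface and adds one flag; a sketch of the assumed signature, body elided:

# Assumed interface of tensorflow_tts/utils/decoder.py (not shown in this
# diff): a copy of tensorflow_addons.seq2seq.dynamic_decode that avoids
# TFLite-hostile constructs when enable_tflite_convertible is True.
def dynamic_decode(decoder,
                   output_time_major=False,
                   maximum_iterations=None,
                   enable_tflite_convertible=False,
                   **kwargs):
    ...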
@@ -484,10 +487,15 @@ def call(self, inputs, training=False):
 class TFTacotronDecoderCell(tf.keras.layers.AbstractRNNCell):
     """Tacotron-2 custom decoder cell."""

-    def __init__(self, config, training, **kwargs):
+    def __init__(self,
+                 config,
+                 training,
+                 enable_tflite_convertible=False,
+                 **kwargs):
         """Init variables."""
         super().__init__(**kwargs)
         self.training = training
+        self.enable_tflite_convertible = enable_tflite_convertible
         self.prenet = TFTacotronPrenet(config, name="prenet")

         # define lstm cell on decoder.
@@ -563,9 +571,12 @@ def get_initial_state(self, batch_size):
         initial_state = self.attention_layer.get_initial_state(
             batch_size, size=self.alignment_size
         )
-        initial_alignment_history = tf.TensorArray(
-            dtype=tf.float32, size=0, dynamic_size=True
-        )
+        if self.enable_tflite_convertible:
+            initial_alignment_history = ()
+        else:
+            initial_alignment_history = tf.TensorArray(
+                dtype=tf.float32, size=0, dynamic_size=True
+            )
         return TFTacotronDecoderCellState(
             attention_lstm_state=initial_attention_lstm_cell_states,
             decoder_lstms_state=initial_decoder_lstms_cell_states,
@@ -594,7 +605,8 @@ def call(self, inputs, states):

         # 3. compute context, alignment and cumulative alignment.
         prev_state = states.state
-        prev_alignment_history = states.alignment_history
+        if not self.enable_tflite_convertible:
+            prev_alignment_history = states.alignment_history
         prev_max_alignments = states.max_alignments
         context, alignments, state = self.attention_layer(
             [attention_lstm_output, prev_state, prev_max_alignments],
@@ -615,7 +627,11 @@ def call(self, inputs, states):
         stop_tokens = self.stop_projection(stop_inputs)

         # 6. save alignment history to visualize.
-        alignment_history = prev_alignment_history.write(states.time, alignments)
+        if self.enable_tflite_convertible:
+            alignment_history = ()
+        else:
+            alignment_history = prev_alignment_history.write(states.time,
+                                                             alignments)

         # 7. return new states.
         new_states = TFTacotronDecoderCellState(
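
The last three hunks share one pattern: the `tf.TensorArray` that records per-step alignments lowers to TensorList ops inside the decode loop, which the TFLite converter cannot handle (the same limitation behind the vendored `dynamic_decode`), so the convertible path carries an empty tuple and simply stops recording. What is traded away is the alignment diagnostic: on the normal path the stacked history, transposed further down to [batch, char_len, decode_steps], is what attention plots are made from. A hypothetical visualization snippet, not code from this repo:

import numpy as np
import matplotlib.pyplot as plt

# `alignment` stands in for one batch element of the stacked history,
# i.e. alignment_history[0] with shape [char_len, decode_steps].
alignment = np.random.rand(42, 180)  # made-up sizes for illustration
plt.imshow(alignment, aspect="auto", origin="lower")
plt.xlabel("decoder step")
plt.ylabel("encoder state (input character)")
plt.title("attention alignment")
plt.savefig("alignment.png")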
@@ -634,11 +650,16 @@ def call(self, inputs, states):
 class TFTacotronDecoder(Decoder):
     """Tacotron-2 Decoder."""

-    def __init__(self, decoder_cell, decoder_sampler, output_layer=None):
+    def __init__(self,
+                 decoder_cell,
+                 decoder_sampler,
+                 output_layer=None,
+                 enable_tflite_convertible=False):
         """Initial variables."""
         self.cell = decoder_cell
         self.sampler = decoder_sampler
         self.output_layer = output_layer
+        self.enable_tflite_convertible = enable_tflite_convertible

     def setup_decoder_init_state(self, decoder_init_state):
         self.initial_state = decoder_init_state
@@ -653,7 +674,9 @@ def output_size(self):
                 lambda shape: tf.TensorShape(shape), self.cell.output_size
             ),
             token_output=tf.TensorShape(self.sampler.reduction_factor),
-            sample_id=self.sampler.sample_ids_shape,
+            sample_id=tf.TensorShape([1])
+            if self.enable_tflite_convertible
+            else self.sampler.sample_ids_shape  # tf.TensorShape([])
         )

     @property
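
The `output_size` change serves the same conversion goal: presumably every structure carried through the decode loop needs a fully defined element shape, so the sampler's scalar `sample_ids_shape` is pinned to a rank-1 shape of length 1 on the TFLite path. For reference, the two shapes in question (illustrative only):

import tensorflow as tf

# Element shapes for the per-step sample_id structure:
print(tf.TensorShape([]))   # ()   -- scalar shape the sampler normally reports
print(tf.TensorShape([1]))  # (1,) -- fixed rank-1 shape used for conversion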
@@ -688,16 +711,18 @@ def step(self, time, inputs, state, training=False):
 class TFTacotron2(tf.keras.Model):
     """Tensorflow tacotron-2 model."""

-    def __init__(self, config, training, **kwargs):
+    def __init__(self, config, training, enable_tflite_convertible=False, **kwargs):
         """Initalize tacotron-2 layers."""
         super().__init__(self, **kwargs)
         self.encoder = TFTacotronEncoder(config, name="encoder")
         self.decoder_cell = TFTacotronDecoderCell(
-            config, training=training, name="decoder_cell"
+            config, training=training, name="decoder_cell",
+            enable_tflite_convertible=enable_tflite_convertible
         )
         self.decoder = TFTacotronDecoder(
             self.decoder_cell,
             TrainingSampler(config) if training is True else TestingSampler(config),
+            enable_tflite_convertible=enable_tflite_convertible
         )
         self.postnet = TFTacotronPostnet(config, name="post_net")
         self.post_projection = tf.keras.layers.Dense(
@@ -707,6 +732,7 @@ def __init__(self, config, training, **kwargs):
         self.config = config
         self.use_window_mask = False
         self.maximum_iterations = 4000
+        self.enable_tflite_convertible = enable_tflite_convertible

     def setup_window(self, win_front, win_back):
         """Call only for inference."""
@@ -788,7 +814,9 @@ def call(
             (frames_prediction, stop_token_prediction, _),
             final_decoder_state,
             _,
-        ) = dynamic_decode(self.decoder, maximum_iterations=maximum_iterations)
+        ) = dynamic_decode(self.decoder,
+                           maximum_iterations=maximum_iterations,
+                           enable_tflite_convertible=self.enable_tflite_convertible)

         decoder_output = tf.reshape(
             frames_prediction, [batch_size, -1, self.config.n_mels]
@@ -800,9 +828,20 @@ def call(

         mel_outputs = decoder_output + residual_projection

-        alignment_history = tf.transpose(
-            final_decoder_state.alignment_history.stack(), [1, 2, 0]
-        )
+        if self.enable_tflite_convertible:
+            mask = tf.math.not_equal(
+                tf.cast(tf.reduce_sum(tf.abs(decoder_output), axis=-1),
+                        dtype=tf.int32),
+                0)
+            decoder_output = tf.expand_dims(
+                tf.boolean_mask(decoder_output, mask), axis=0)
+            mel_outputs = tf.expand_dims(
+                tf.boolean_mask(mel_outputs, mask), axis=0)
+            alignment_history = ()
+        else:
+            alignment_history = tf.transpose(
+                final_decoder_state.alignment_history.stack(), [1, 2, 0]
+            )

         return decoder_output, mel_outputs, stop_token_prediction, alignment_history

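The new branch trims the trailing all-zero frames that the TFLite-friendly decode path can leave after the stop token: any frame whose absolute sum casts to zero in int32 is treated as padding and dropped via `tf.boolean_mask`, then the batch axis is restored. The identical block appears again in `inference_tflite` below. A self-contained illustration of the trick (values are made up; note that the int32 cast truncates, so a frame whose absolute sum is below 1 would also be treated as padding):

import tensorflow as tf

frames = tf.constant([[[1.5, -0.2],      # real frame, |sum| = 1.7 -> 1
                       [2.0,  0.3],      # real frame, |sum| = 2.3 -> 2
                       [0.0,  0.0],      # padded frame -> 0
                       [0.0,  0.0]]])    # padded frame -> 0; shape [1, 4, 2]
mask = tf.math.not_equal(
    tf.cast(tf.reduce_sum(tf.abs(frames), axis=-1), dtype=tf.int32), 0)
trimmed = tf.expand_dims(tf.boolean_mask(frames, mask), axis=0)
print(trimmed.shape)  # (1, 2, 2): the two padded frames are gone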
@@ -868,8 +907,85 @@ def inference(self, input_ids, input_lengths, speaker_ids):

         mel_outputs = decoder_output + residual_projection

-        alignment_history = tf.transpose(
-            final_decoder_state.alignment_history.stack(), [1, 2, 0]
+        return decoder_output, mel_outputs, stop_token_prediction, alignment_history
+
+    @tf.function(
+        experimental_relax_shapes=True,
+        input_signature=[
+            tf.TensorSpec([1, None], dtype=tf.int32),
+            tf.TensorSpec([1,], dtype=tf.int32),
+            tf.TensorSpec([1,], dtype=tf.int32),
+        ],
+    )
+    def inference_tflite(self, input_ids, input_lengths, speaker_ids):
+        """Inference logic for the TFLite-convertible path."""
+        # create input-mask based on input_lengths
+        input_mask = tf.sequence_mask(
+            input_lengths,
+            maxlen=tf.reduce_max(input_lengths),
+            name="input_sequence_masks",
+        )
+
+        # Encoder Step.
+        encoder_hidden_states = self.encoder(
+            [input_ids, speaker_ids, input_mask], training=False
         )

+        batch_size = tf.shape(encoder_hidden_states)[0]
+        alignment_size = tf.shape(encoder_hidden_states)[1]
+
+        # Set up some initial placeholders for the decoder step. These include:
+        # 1. batch_size for inference.
+        # 2. alignment_size for attention size.
+        # 3. initial state for decoder cell.
+        # 4. memory (encoder hidden state) for attention mechanism.
+        # 5. window front/back to solve long-sentence synthesis problems
+        #    (call after setting up memory).
+        self.decoder.sampler.set_batch_size(batch_size)
+        self.decoder.cell.set_alignment_size(alignment_size)
+        self.decoder.setup_decoder_init_state(
+            self.decoder.cell.get_initial_state(batch_size)
+        )
+        self.decoder.cell.attention_layer.setup_memory(
+            memory=encoder_hidden_states,
+            memory_sequence_length=input_lengths,  # used to mask attention.
+        )
+        if self.use_window_mask:
+            self.decoder.cell.attention_layer.setup_window(
+                win_front=self.win_front, win_back=self.win_back
+            )
+
+        # run decode step.
+        (
+            (frames_prediction, stop_token_prediction, _),
+            final_decoder_state,
+            _,
+        ) = dynamic_decode(self.decoder,
+                           maximum_iterations=self.maximum_iterations,
+                           enable_tflite_convertible=self.enable_tflite_convertible)
+
+        decoder_output = tf.reshape(
+            frames_prediction, [batch_size, -1, self.config.n_mels]
+        )
+        stop_token_prediction = tf.reshape(stop_token_prediction, [batch_size, -1])
+
+        residual = self.postnet(decoder_output, training=False)
+        residual_projection = self.post_projection(residual)
+
+        mel_outputs = decoder_output + residual_projection
+
+        if self.enable_tflite_convertible:
+            mask = tf.math.not_equal(
+                tf.cast(tf.reduce_sum(tf.abs(decoder_output), axis=-1),
+                        dtype=tf.int32),
+                0)
+            decoder_output = tf.expand_dims(
+                tf.boolean_mask(decoder_output, mask), axis=0)
+            mel_outputs = tf.expand_dims(
+                tf.boolean_mask(mel_outputs, mask), axis=0)
+            alignment_history = ()
+        else:
+            alignment_history = tf.transpose(
+                final_decoder_state.alignment_history.stack(), [1, 2, 0]
+            )
+
         return decoder_output, mel_outputs, stop_token_prediction, alignment_history
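
With the flag threaded through cell, decoder, and model, a `TFTacotron2` built with `enable_tflite_convertible=True` can be converted from the concrete function of `inference_tflite`. A minimal end-to-end sketch under stated assumptions: `config` and the `model.h5` weights file are placeholders rather than names from this diff, the interpreter's input/output ordering should be verified against `get_input_details()`/`get_output_details()`, and SELECT_TF_OPS is enabled for any op without a builtin TFLite kernel:

import numpy as np
import tensorflow as tf

tacotron2 = TFTacotron2(config, training=False, enable_tflite_convertible=True)
# Trace once with dummy inputs so the variables exist, then load real weights.
tacotron2.inference_tflite(tf.zeros([1, 9], tf.int32),
                           tf.constant([9], tf.int32),
                           tf.zeros([1], tf.int32))
tacotron2.load_weights("model.h5")  # placeholder checkpoint name

# Convert the @tf.function's concrete function to a TFLite flatbuffer.
concrete_fn = tacotron2.inference_tflite.get_concrete_function()
converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_fn])
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS,
                                       tf.lite.OpsSet.SELECT_TF_OPS]
tflite_model = converter.convert()

# Run it: input_ids is [1, None], so resize that input before allocating.
interpreter = tf.lite.Interpreter(model_content=tflite_model)
input_details = interpreter.get_input_details()  # verify ordering by "name"
input_ids = np.zeros([1, 9], dtype=np.int32)     # real phoneme/char ids here
interpreter.resize_tensor_input(input_details[0]["index"], input_ids.shape)
interpreter.allocate_tensors()
interpreter.set_tensor(input_details[0]["index"], input_ids)
interpreter.set_tensor(input_details[1]["index"], np.array([9], np.int32))
interpreter.set_tensor(input_details[2]["index"], np.array([0], np.int32))
interpreter.invoke()
outputs = [interpreter.get_tensor(d["index"])
           for d in interpreter.get_output_details()]
# outputs: decoder_output, mel_outputs, stop_token_prediction
# (alignment_history is an empty tuple on the convertible path).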

tensorflow_tts/utils/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -12,6 +12,8 @@
 from tensorflow_tts.utils.cleaners import transliteration_cleaners
 from tensorflow_tts.utils.cleaners import english_cleaners

+from tensorflow_tts.utils.decoder import dynamic_decode
+
 from tensorflow_tts.utils.number_norm import normalize_numbers

 from tensorflow_tts.utils.outliers import remove_outlier
