@@ -1,5 +1,5 @@
 # -*- coding: utf-8 -*-
-# Copyright 2020 The Tacotron-2 Authors, Minh Nguyen (@dathudeptrai) and Eren Gölge (@erogol)
+# Copyright 2020 The Tacotron-2 Authors, Minh Nguyen (@dathudeptrai), Eren Gölge (@erogol) and Jae Yoo (@jaeyoo)
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -22,8 +22,11 @@

 from tensorflow_addons.seq2seq import Sampler
 from tensorflow_addons.seq2seq import BahdanauAttention
-from tensorflow_addons.seq2seq import dynamic_decode
+# TODO: once https://github.com/tensorflow/addons/pull/1964 is fixed,
+# uncomment this line.
+# from tensorflow_addons.seq2seq import dynamic_decode
 from tensorflow_addons.seq2seq import Decoder
+from tensorflow_tts.utils import dynamic_decode


 def get_initializer(initializer_range=0.02):
@@ -484,10 +487,15 @@ def call(self, inputs, training=False):
 class TFTacotronDecoderCell(tf.keras.layers.AbstractRNNCell):
     """Tacotron-2 custom decoder cell."""

-    def __init__(self, config, training, **kwargs):
+    def __init__(self,
+                 config,
+                 training,
+                 enable_tflite_convertible=False,
+                 **kwargs):
         """Init variables."""
         super().__init__(**kwargs)
         self.training = training
+        self.enable_tflite_convertible = enable_tflite_convertible
         self.prenet = TFTacotronPrenet(config, name="prenet")

         # define lstm cell on decoder.
@@ -563,9 +571,12 @@ def get_initial_state(self, batch_size):
         initial_state = self.attention_layer.get_initial_state(
             batch_size, size=self.alignment_size
         )
-        initial_alignment_history = tf.TensorArray(
-            dtype=tf.float32, size=0, dynamic_size=True
-        )
+        if self.enable_tflite_convertible:
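+            # tf.TensorArray is not TFLite-convertible, so no alignment
+            # history is kept in TFLite mode; an empty tuple stands in for it.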
+            initial_alignment_history = ()
+        else:
+            initial_alignment_history = tf.TensorArray(
+                dtype=tf.float32, size=0, dynamic_size=True
+            )
         return TFTacotronDecoderCellState(
             attention_lstm_state=initial_attention_lstm_cell_states,
             decoder_lstms_state=initial_decoder_lstms_cell_states,
@@ -594,7 +605,8 @@ def call(self, inputs, states):
594605
595606 # 3. compute context, alignment and cumulative alignment.
596607 prev_state = states .state
597- prev_alignment_history = states .alignment_history
608+ if not self .enable_tflite_convertible :
609+ prev_alignment_history = states .alignment_history
598610 prev_max_alignments = states .max_alignments
599611 context , alignments , state = self .attention_layer (
600612 [attention_lstm_output , prev_state , prev_max_alignments ],
@@ -615,7 +627,11 @@ def call(self, inputs, states):
         stop_tokens = self.stop_projection(stop_inputs)

         # 6. save alignment history to visualize.
-        alignment_history = prev_alignment_history.write(states.time, alignments)
+        if self.enable_tflite_convertible:
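+            # No TensorArray in TFLite mode, so the history write is skipped.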
+            alignment_history = ()
+        else:
+            alignment_history = prev_alignment_history.write(states.time,
+                                                             alignments)

         # 7. return new states.
         new_states = TFTacotronDecoderCellState(
@@ -634,11 +650,16 @@ def call(self, inputs, states):
 class TFTacotronDecoder(Decoder):
     """Tacotron-2 Decoder."""

-    def __init__(self, decoder_cell, decoder_sampler, output_layer=None):
+    def __init__(self,
+                 decoder_cell,
+                 decoder_sampler,
+                 output_layer=None,
+                 enable_tflite_convertible=False):
         """Initial variables."""
         self.cell = decoder_cell
         self.sampler = decoder_sampler
         self.output_layer = output_layer
+        self.enable_tflite_convertible = enable_tflite_convertible

     def setup_decoder_init_state(self, decoder_init_state):
         self.initial_state = decoder_init_state
@@ -653,7 +674,9 @@ def output_size(self):
                 lambda shape: tf.TensorShape(shape), self.cell.output_size
             ),
             token_output=tf.TensorShape(self.sampler.reduction_factor),
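+            # TFLite needs a fully-defined shape, so a fixed [1] replaces the
+            # scalar sample_ids_shape when converting.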
-            sample_id=self.sampler.sample_ids_shape,
+            sample_id=tf.TensorShape([1])
+            if self.enable_tflite_convertible
+            else self.sampler.sample_ids_shape  # tf.TensorShape([])
         )

     @property
@@ -688,16 +711,18 @@ def step(self, time, inputs, state, training=False):
 class TFTacotron2(tf.keras.Model):
     """Tensorflow tacotron-2 model."""

-    def __init__(self, config, training, **kwargs):
+    def __init__(self, config, training, enable_tflite_convertible=False, **kwargs):
         """Initialize tacotron-2 layers."""
         super().__init__(self, **kwargs)
         self.encoder = TFTacotronEncoder(config, name="encoder")
         self.decoder_cell = TFTacotronDecoderCell(
-            config, training=training, name="decoder_cell"
+            config, training=training, name="decoder_cell",
+            enable_tflite_convertible=enable_tflite_convertible
         )
         self.decoder = TFTacotronDecoder(
             self.decoder_cell,
             TrainingSampler(config) if training is True else TestingSampler(config),
+            enable_tflite_convertible=enable_tflite_convertible
         )
         self.postnet = TFTacotronPostnet(config, name="post_net")
         self.post_projection = tf.keras.layers.Dense(
@@ -707,6 +732,7 @@ def __init__(self, config, training, **kwargs):
         self.config = config
         self.use_window_mask = False
         self.maximum_iterations = 4000
+        self.enable_tflite_convertible = enable_tflite_convertible

     def setup_window(self, win_front, win_back):
         """Call only for inference."""
@@ -788,7 +814,9 @@ def call(
             (frames_prediction, stop_token_prediction, _),
             final_decoder_state,
             _,
-        ) = dynamic_decode(self.decoder, maximum_iterations=maximum_iterations)
+        ) = dynamic_decode(self.decoder,
+                           maximum_iterations=maximum_iterations,
+                           enable_tflite_convertible=self.enable_tflite_convertible)
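+        # dynamic_decode is the copy bundled in tensorflow_tts.utils; unlike
+        # the addons version it accepts the enable_tflite_convertible flag
+        # (see the TODO at the imports).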

         decoder_output = tf.reshape(
             frames_prediction, [batch_size, -1, self.config.n_mels]
@@ -800,9 +828,20 @@ def call(

         mel_outputs = decoder_output + residual_projection

-        alignment_history = tf.transpose(
-            final_decoder_state.alignment_history.stack(), [1, 2, 0]
-        )
+        if self.enable_tflite_convertible:
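+            # In TFLite mode the decoder output is padded with all-zero
+            # frames; mask them out so only real frames are returned
+            # (the batch dimension stays 1).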
+            mask = tf.math.not_equal(
+                tf.cast(tf.reduce_sum(tf.abs(decoder_output), axis=-1),
+                        dtype=tf.int32),
+                0)
+            decoder_output = tf.expand_dims(
+                tf.boolean_mask(decoder_output, mask), axis=0)
+            mel_outputs = tf.expand_dims(
+                tf.boolean_mask(mel_outputs, mask), axis=0)
+            alignment_history = ()
+        else:
+            alignment_history = tf.transpose(
+                final_decoder_state.alignment_history.stack(), [1, 2, 0]
+            )

         return decoder_output, mel_outputs, stop_token_prediction, alignment_history

@@ -868,8 +907,85 @@ def inference(self, input_ids, input_lengths, speaker_ids):

         mel_outputs = decoder_output + residual_projection

-        alignment_history = tf.transpose(
-            final_decoder_state.alignment_history.stack(), [1, 2, 0]
+        return decoder_output, mel_outputs, stop_token_prediction, alignment_history
+
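+    # The signature fixes the batch size to 1 and leaves the text length
+    # dynamic, matching single-utterance TFLite inference.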
+    @tf.function(
+        experimental_relax_shapes=True,
+        input_signature=[
+            tf.TensorSpec([1, None], dtype=tf.int32),
+            tf.TensorSpec([1,], dtype=tf.int32),
+            tf.TensorSpec([1,], dtype=tf.int32),
+        ],
+    )
+    def inference_tflite(self, input_ids, input_lengths, speaker_ids):
+        """Call logic."""
+        # create input-mask based on input_lengths
+        input_mask = tf.sequence_mask(
+            input_lengths,
+            maxlen=tf.reduce_max(input_lengths),
+            name="input_sequence_masks",
+        )
+
+        # Encoder Step.
+        encoder_hidden_states = self.encoder(
+            [input_ids, speaker_ids, input_mask], training=False
         )

+        batch_size = tf.shape(encoder_hidden_states)[0]
+        alignment_size = tf.shape(encoder_hidden_states)[1]
+
+        # Setup some initial placeholders for decoder step. Include:
+        # 1. batch_size for inference.
+        # 2. alignment_size for attention size.
+        # 3. initial state for decoder cell.
+        # 4. memory (encoder hidden state) for attention mechanism.
+        # 5. window front/back to solve long-sentence synthesis problems. (call after setup memory.)
+        self.decoder.sampler.set_batch_size(batch_size)
+        self.decoder.cell.set_alignment_size(alignment_size)
+        self.decoder.setup_decoder_init_state(
+            self.decoder.cell.get_initial_state(batch_size)
+        )
+        self.decoder.cell.attention_layer.setup_memory(
+            memory=encoder_hidden_states,
+            memory_sequence_length=input_lengths,  # use for mask attention.
+        )
+        if self.use_window_mask:
+            self.decoder.cell.attention_layer.setup_window(
+                win_front=self.win_front, win_back=self.win_back
+            )
+
+        # run decode step.
+        (
+            (frames_prediction, stop_token_prediction, _),
+            final_decoder_state,
+            _,
+        ) = dynamic_decode(self.decoder,
+                           maximum_iterations=self.maximum_iterations,
+                           enable_tflite_convertible=self.enable_tflite_convertible)
+
+        decoder_output = tf.reshape(
+            frames_prediction, [batch_size, -1, self.config.n_mels]
+        )
+        stop_token_prediction = tf.reshape(stop_token_prediction, [batch_size, -1])
+
+        residual = self.postnet(decoder_output, training=False)
+        residual_projection = self.post_projection(residual)
+
+        mel_outputs = decoder_output + residual_projection
+
+        if self.enable_tflite_convertible:
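+            # Same padding trim as in call(): keep only non-zero frames.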
+            mask = tf.math.not_equal(
+                tf.cast(tf.reduce_sum(tf.abs(decoder_output), axis=-1),
+                        dtype=tf.int32),
+                0)
+            decoder_output = tf.expand_dims(
+                tf.boolean_mask(decoder_output, mask), axis=0)
+            mel_outputs = tf.expand_dims(
+                tf.boolean_mask(mel_outputs, mask), axis=0)
+            alignment_history = ()
+        else:
+            alignment_history = tf.transpose(
+                final_decoder_state.alignment_history.stack(), [1, 2, 0]
+            )
+
         return decoder_output, mel_outputs, stop_token_prediction, alignment_history
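
A minimal conversion sketch, not part of the commit above: build the model with
enable_tflite_convertible=True, take the concrete function of inference_tflite,
and hand it to the TFLite converter with select-TF-ops fallback. The `config`
object, the checkpoint path, and the weight-loading step are illustrative
assumptions here, not code from this commit.

import tensorflow as tf

# Hypothetical setup: `config` and the checkpoint come from elsewhere, and the
# variables must exist (e.g. via a dummy forward pass) before loading weights.
tacotron2 = TFTacotron2(config, training=False, enable_tflite_convertible=True)
tacotron2.load_weights("tacotron2.h5")  # assumed checkpoint path

concrete_fn = tacotron2.inference_tflite.get_concrete_function()
converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_fn])
converter.experimental_new_converter = True  # MLIR converter; helps with the decode loop
converter.target_spec.supported_ops = [
    tf.lite.OpsSet.TFLITE_BUILTINS,  # prefer builtin TFLite ops
    tf.lite.OpsSet.SELECT_TF_OPS,    # fall back to TF ops the builtins lack
]
tflite_model = converter.convert()
with open("tacotron2.tflite", "wb") as f:
    f.write(tflite_model)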