@@ -573,14 +573,14 @@ def call(self, inputs, training=False):
 class TFFastSpeechLengthRegulator(tf.keras.layers.Layer):
     """FastSpeech lengthregulator module."""
 
-    def __init__(self, config, **kwargs):
+    def __init__(self, config, enable_tflite_convertible=False, **kwargs):
         """Init variables."""
         super().__init__(**kwargs)
         self.config = config
+        self.enable_tflite_convertible = enable_tflite_convertible
 
     def call(self, inputs, training=False):
         """Call logic.
-
         Args:
             1. encoder_hidden_states, Tensor (float32) shape [batch_size, length, hidden_size]
             2. durations_gt, Tensor (float32/int32) shape [batch_size, length]
@@ -601,83 +601,101 @@ def _length_regulator(self, encoder_hidden_states, durations_gt):
         hidden_size = input_shape[-1]
 
         # initialize output hidden states and encoder masking.
-        outputs = tf.zeros(shape=[0, max_durations, hidden_size], dtype=tf.float32)
-        encoder_masks = tf.zeros(shape=[0, max_durations], dtype=tf.int32)
-
-        def condition(
-            i,
-            batch_size,
-            outputs,
-            encoder_masks,
-            encoder_hidden_states,
-            durations_gt,
-            max_durations,
-        ):
-            return tf.less(i, batch_size)
-
-        def body(
-            i,
-            batch_size,
-            outputs,
-            encoder_masks,
-            encoder_hidden_states,
-            durations_gt,
-            max_durations,
-        ):
-            repeats = durations_gt[i]
+        if self.enable_tflite_convertible:
+            # There is only 1 batch in inference, so we don't have to use
+            # `tf.While` op with 3-D output tensor.
+            repeats = durations_gt[0]
             real_length = tf.reduce_sum(repeats)
             pad_size = max_durations - real_length
+            # masks : [max_durations]
             masks = tf.sequence_mask([real_length], max_durations, dtype=tf.int32)
             repeat_encoder_hidden_states = tf.repeat(
-                encoder_hidden_states[i], repeats=repeats, axis=0
+                encoder_hidden_states[0], repeats=repeats, axis=0
             )
             repeat_encoder_hidden_states = tf.expand_dims(
                 tf.pad(repeat_encoder_hidden_states, [[0, pad_size], [0, 0]]), 0
             )  # [1, max_durations, hidden_size]
-            outputs = tf.concat([outputs, repeat_encoder_hidden_states], axis=0)
-            encoder_masks = tf.concat([encoder_masks, masks], axis=0)
-            return [
-                i + 1,
+
+            outputs = repeat_encoder_hidden_states
+            encoder_masks = masks
+        else:
+            outputs = tf.zeros(shape=[0, max_durations, hidden_size], dtype=tf.float32)
+            encoder_masks = tf.zeros(shape=[0, max_durations], dtype=tf.int32)
+
+            def condition(
+                i,
                 batch_size,
                 outputs,
                 encoder_masks,
                 encoder_hidden_states,
                 durations_gt,
                 max_durations,
-            ]
+            ):
+                return tf.less(i, batch_size)
 
-        # initialize iteration i.
-        i = tf.constant(0, dtype=tf.int32)
-        _, _, outputs, encoder_masks, _, _, _, = tf.while_loop(
-            condition,
-            body,
-            [
+            def body(
                 i,
                 batch_size,
                 outputs,
                 encoder_masks,
                 encoder_hidden_states,
                 durations_gt,
                 max_durations,
-            ],
-            shape_invariants=[
-                i.get_shape(),
-                batch_size.get_shape(),
-                tf.TensorShape([None, None, self.config.hidden_size]),
-                tf.TensorShape([None, None]),
-                encoder_hidden_states.get_shape(),
-                durations_gt.get_shape(),
-                max_durations.get_shape(),
-            ],
-        )
+            ):
+                repeats = durations_gt[i]
+                real_length = tf.reduce_sum(repeats)
+                pad_size = max_durations - real_length
+                masks = tf.sequence_mask([real_length], max_durations, dtype=tf.int32)
+                repeat_encoder_hidden_states = tf.repeat(
+                    encoder_hidden_states[i], repeats=repeats, axis=0
+                )
+                repeat_encoder_hidden_states = tf.expand_dims(
+                    tf.pad(repeat_encoder_hidden_states, [[0, pad_size], [0, 0]]), 0
+                )  # [1, max_durations, hidden_size]
+                outputs = tf.concat([outputs, repeat_encoder_hidden_states], axis=0)
+                encoder_masks = tf.concat([encoder_masks, masks], axis=0)
+                return [
+                    i + 1,
+                    batch_size,
+                    outputs,
+                    encoder_masks,
+                    encoder_hidden_states,
+                    durations_gt,
+                    max_durations,
+                ]
+
+            # initialize iteration i.
+            i = tf.constant(0, dtype=tf.int32)
+            _, _, outputs, encoder_masks, _, _, _, = tf.while_loop(
+                condition,
+                body,
+                [
+                    i,
+                    batch_size,
+                    outputs,
+                    encoder_masks,
+                    encoder_hidden_states,
+                    durations_gt,
+                    max_durations,
+                ],
+                shape_invariants=[
+                    i.get_shape(),
+                    batch_size.get_shape(),
+                    tf.TensorShape([None, None, self.config.hidden_size]),
+                    tf.TensorShape([None, None]),
+                    encoder_hidden_states.get_shape(),
+                    durations_gt.get_shape(),
+                    max_durations.get_shape(),
+                ],
+            )
 
         return outputs, encoder_masks
 
 
 class TFFastSpeech(tf.keras.Model):
     """TF Fastspeech module."""
 
-    def __init__(self, config, **kwargs):
+    def __init__(self, config, enable_tflite_convertible=False, **kwargs):
         """Init layers for fastspeech."""
         super().__init__(**kwargs)
         self.embeddings = TFFastSpeechEmbeddings(config, name="embeddings")
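
For reference, here is a minimal standalone sketch (not part of this commit) of what the new `enable_tflite_convertible` branch above computes for a single utterance: each encoder frame is repeated by its predicted duration with `tf.repeat`, then the result is padded to `max_durations` and masked. Shapes and values below are made up for illustration.

```python
import tensorflow as tf

# Toy inputs: batch of 1, 4 encoder frames, hidden size 8 (illustrative only).
hidden = tf.random.normal([1, 4, 8])               # [batch=1, length, hidden_size]
durations = tf.constant([[2, 1, 3, 1]], tf.int32)  # [batch=1, length]

sum_durations = tf.reduce_sum(durations, axis=-1)  # [batch]
max_durations = tf.reduce_max(sum_durations)

# Batch-1 path: no tf.while_loop, just repeat + pad + mask.
repeats = durations[0]
real_length = tf.reduce_sum(repeats)
pad_size = max_durations - real_length             # always 0 when batch_size == 1
masks = tf.sequence_mask([real_length], max_durations, dtype=tf.int32)
expanded = tf.repeat(hidden[0], repeats=repeats, axis=0)  # [real_length, hidden_size]
outputs = tf.expand_dims(tf.pad(expanded, [[0, pad_size], [0, 0]]), 0)

print(outputs.shape)  # (1, 7, 8): one frame slot per predicted duration step
```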
@@ -686,12 +704,16 @@ def __init__(self, config, **kwargs):
             config, name="duration_predictor"
         )
         self.length_regulator = TFFastSpeechLengthRegulator(
-            config, name="length_regulator"
+            config,
+            enable_tflite_convertible=enable_tflite_convertible,
+            name="length_regulator"
         )
         self.decoder = TFFastSpeechDecoder(config, name="decoder")
         self.mel_dense = tf.keras.layers.Dense(units=config.num_mels, name="mel_before")
         self.postnet = TFTacotronPostnet(config=config, name="postnet")
 
+        self.enable_tflite_convertible = enable_tflite_convertible
+
     def _build(self):
         """Dummy input for building model."""
         # fake inputs
@@ -749,8 +771,8 @@ def call(
         input_signature=[
             tf.TensorSpec(shape=[None, None], dtype=tf.int32),
             tf.TensorSpec(shape=[None, None], dtype=tf.bool),
-            tf.TensorSpec(shape=[None,], dtype=tf.int32),
-            tf.TensorSpec(shape=[None,], dtype=tf.float32),
+            tf.TensorSpec(shape=[None, ], dtype=tf.int32),
+            tf.TensorSpec(shape=[None, ], dtype=tf.float32),
         ],
     )
     def inference(self, input_ids, attention_mask, speaker_ids, speed_ratios):
@@ -799,3 +821,59 @@ def inference(self, input_ids, attention_mask, speaker_ids, speed_ratios):
 
         outputs = (mel_before, mel_after, duration_outputs)
         return outputs
+
+    @tf.function(
+        experimental_relax_shapes=True,
+        input_signature=[
+            tf.TensorSpec(shape=[1, None], dtype=tf.int32),
+            tf.TensorSpec(shape=[1, None], dtype=tf.bool),
+            tf.TensorSpec(shape=[1, ], dtype=tf.int32),
+            tf.TensorSpec(shape=[1, ], dtype=tf.float32),
+        ],
+    )
+    def inference_tflite(self, input_ids, attention_mask, speaker_ids, speed_ratios):
+        """Call logic."""
+        embedding_output = self.embeddings([input_ids, speaker_ids], training=False)
+        encoder_output = self.encoder(
+            [embedding_output, attention_mask], training=False
+        )
+        last_encoder_hidden_states = encoder_output[0]
+
+        # duration predictor, here use last_encoder_hidden_states, u can use more hidden_states layers
+        # rather than just use last_hidden_states of encoder for duration_predictor.
+        duration_outputs = self.duration_predictor(
+            [last_encoder_hidden_states, attention_mask]
+        )  # [batch_size, length]
+        duration_outputs = tf.math.exp(duration_outputs) - 1.0
+
+        if speed_ratios is None:
+            speed_ratios = tf.convert_to_tensor(np.array([1.0]), dtype=tf.float32)
+
+        duration_outputs = tf.cast(
+            tf.math.round(duration_outputs * speed_ratios), tf.int32
+        )
+
+        length_regulator_outputs, encoder_masks = self.length_regulator(
+            [last_encoder_hidden_states, duration_outputs], training=False
+        )
+
+        # create decoder positional embedding
+        decoder_pos = tf.range(
+            1, tf.shape(length_regulator_outputs)[1] + 1, dtype=tf.int32
+        )
+        masked_decoder_pos = tf.expand_dims(decoder_pos, 0) * encoder_masks
+
+        decoder_output = self.decoder(
+            [length_regulator_outputs, speaker_ids, encoder_masks, masked_decoder_pos],
+            training=False,
+        )
+        last_decoder_hidden_states = decoder_output[0]
+
+        # here u can use sum or concat more than 1 hidden states layers from decoder.
+        mel_before = self.mel_dense(last_decoder_hidden_states)
+        mel_after = (
+            self.postnet([mel_before, encoder_masks], training=False) + mel_before
+        )
+
+        outputs = (mel_before, mel_after, duration_outputs)
+        return outputs
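
With the fixed batch-1 signature above, the model can be exported by handing the `inference_tflite` concrete function to the TFLite converter. The sketch below shows one possible way to do that; it is not part of this commit, and the `config` object, checkpoint path, and output filename are placeholders.

```python
import tensorflow as tf

# Hypothetical setup: a TFFastSpeech built with the new flag and restored weights.
fastspeech = TFFastSpeech(config, enable_tflite_convertible=True, name="fastspeech")
fastspeech._build()
fastspeech.load_weights("fastspeech.h5")  # placeholder checkpoint path

# Concrete function for the [1, None] / [1, ] signature declared above.
concrete_fn = fastspeech.inference_tflite.get_concrete_function()

converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_fn])
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.target_spec.supported_ops = [
    tf.lite.OpsSet.TFLITE_BUILTINS,  # standard TFLite ops
    tf.lite.OpsSet.SELECT_TF_OPS,    # fall back to TF ops the builtins cannot cover
]
tflite_model = converter.convert()

with open("fastspeech.tflite", "wb") as f:
    f.write(tflite_model)
```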