1919
2020import numpy as np
2121import tensorflow as tf
22+
2223# TODO: once https://github.com/tensorflow/addons/pull/1964 is fixed,
2324# uncomment this line.
2425# from tensorflow_addons.seq2seq import dynamic_decode
@@ -209,6 +210,19 @@ def __init__(self, config, **kwargs):
209210 name = "bilstm" ,
210211 )
211212
213+ if config .n_speakers > 1 :
214+ self .encoder_speaker_embeddings = tf .keras .layers .Embedding (
215+ config .n_speakers ,
216+ config .embedding_hidden_size ,
217+ embeddings_initializer = get_initializer (config .initializer_range ),
218+ name = "encoder_speaker_embeddings" ,
219+ )
220+ self .encoder_speaker_fc = tf .keras .layers .Dense (
221+ units = config .encoder_lstm_units * 2 , name = "encoder_speaker_fc"
222+ )
223+
224+ self .config = config
225+
212226 def call (self , inputs , training = False ):
213227 """Call logic."""
214228 input_ids , speaker_ids , input_mask = inputs
@@ -223,6 +237,18 @@ def call(self, inputs, training=False):
223237 # bi-lstm.
224238 outputs = self .bilstm (conv_outputs , mask = input_mask )
225239
240+ if self .config .n_speakers > 1 :
241+ encoder_speaker_embeddings = self .encoder_speaker_embeddings (speaker_ids )
242+ encoder_speaker_features = tf .math .softplus (
243+ self .encoder_speaker_fc (encoder_speaker_embeddings )
244+ )
245+ # extended encoder speaker features (new axis inserted at position 1)
246+ extended_encoder_speaker_features = encoder_speaker_features [
247+ :, tf .newaxis , :
248+ ]
249+ # add the extended speaker features to the encoder outputs
250+ outputs += extended_encoder_speaker_features
251+
226252 return outputs
227253
228254
@@ -831,7 +857,9 @@ def call(
831857
832858 if self .enable_tflite_convertible :
833859 mask = tf .math .not_equal (
834- tf .cast (tf .reduce_sum (tf .abs (decoder_outputs ), axis = - 1 ), dtype = tf .int32 ),
860+ tf .cast (
861+ tf .reduce_sum (tf .abs (decoder_outputs ), axis = - 1 ), dtype = tf .int32
862+ ),
835863 0 ,
836864 )
837865 decoder_outputs = tf .expand_dims (
@@ -982,7 +1010,9 @@ def inference_tflite(self, input_ids, input_lengths, speaker_ids, **kwargs):
9821010
9831011 if self .enable_tflite_convertible :
9841012 mask = tf .math .not_equal (
985- tf .cast (tf .reduce_sum (tf .abs (decoder_outputs ), axis = - 1 ), dtype = tf .int32 ),
1013+ tf .cast (
1014+ tf .reduce_sum (tf .abs (decoder_outputs ), axis = - 1 ), dtype = tf .int32
1015+ ),
9861016 0 ,
9871017 )
9881018 decoder_outputs = tf .expand_dims (
0 commit comments