Skip to content

Commit 47f3146

Browse files
committed
🐸 Add speaker embedding for output of encoder features.
1 parent ca9f678 commit 47f3146

File tree

1 file changed

+32
-2
lines changed

1 file changed

+32
-2
lines changed

tensorflow_tts/models/tacotron2.py

Lines changed: 32 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919

2020
import numpy as np
2121
import tensorflow as tf
22+
2223
# TODO: once https://github.com/tensorflow/addons/pull/1964 is fixed,
2324
# uncomment this line.
2425
# from tensorflow_addons.seq2seq import dynamic_decode
@@ -209,6 +210,19 @@ def __init__(self, config, **kwargs):
209210
name="bilstm",
210211
)
211212

213+
if config.n_speakers > 1:
214+
self.encoder_speaker_embeddings = tf.keras.layers.Embedding(
215+
config.n_speakers,
216+
config.embedding_hidden_size,
217+
embeddings_initializer=get_initializer(config.initializer_range),
218+
name="encoder_speaker_embeddings",
219+
)
220+
self.encoder_speaker_fc = tf.keras.layers.Dense(
221+
units=config.encoder_lstm_units * 2, name="encoder_speaker_fc"
222+
)
223+
224+
self.config = config
225+
212226
def call(self, inputs, training=False):
213227
"""Call logic."""
214228
input_ids, speaker_ids, input_mask = inputs
@@ -223,6 +237,18 @@ def call(self, inputs, training=False):
223237
# bi-lstm.
224238
outputs = self.bilstm(conv_outputs, mask=input_mask)
225239

240+
if self.config.n_speakers > 1:
241+
encoder_speaker_embeddings = self.encoder_speaker_embeddings(speaker_ids)
242+
encoder_speaker_features = tf.math.softplus(
243+
self.encoder_speaker_fc(encoder_speaker_embeddings)
244+
)
245+
# extended encoderspeaker embeddings
246+
extended_encoder_speaker_features = encoder_speaker_features[
247+
:, tf.newaxis, :
248+
]
249+
# sum to encoder outputs
250+
outputs += extended_encoder_speaker_features
251+
226252
return outputs
227253

228254

@@ -831,7 +857,9 @@ def call(
831857

832858
if self.enable_tflite_convertible:
833859
mask = tf.math.not_equal(
834-
tf.cast(tf.reduce_sum(tf.abs(decoder_outputs), axis=-1), dtype=tf.int32),
860+
tf.cast(
861+
tf.reduce_sum(tf.abs(decoder_outputs), axis=-1), dtype=tf.int32
862+
),
835863
0,
836864
)
837865
decoder_outputs = tf.expand_dims(
@@ -982,7 +1010,9 @@ def inference_tflite(self, input_ids, input_lengths, speaker_ids, **kwargs):
9821010

9831011
if self.enable_tflite_convertible:
9841012
mask = tf.math.not_equal(
985-
tf.cast(tf.reduce_sum(tf.abs(decoder_outputs), axis=-1), dtype=tf.int32),
1013+
tf.cast(
1014+
tf.reduce_sum(tf.abs(decoder_outputs), axis=-1), dtype=tf.int32
1015+
),
9861016
0,
9871017
)
9881018
decoder_outputs = tf.expand_dims(

0 commit comments

Comments
 (0)