Skip to content

Commit 67e123a

Browse files
authored
Merge branch 'master' into master
2 parents 1c554e3 + 47f3146 commit 67e123a

File tree

2 files changed

+28
-3
lines changed

2 files changed

+28
-3
lines changed

examples/multiband_melgan/train_multiband_melgan.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -363,7 +363,7 @@ def main():
363363
if config["remove_short_samples"]:
364364
mel_length_threshold = config["batch_max_steps"] // config[
365365
"hop_size"
366-
] + 2 * config["multiband_melgan_generator"].get("aux_context_window", 0)
366+
] + 2 * config["multiband_melgan_generator_params"].get("aux_context_window", 0)
367367
else:
368368
mel_length_threshold = None
369369

@@ -427,7 +427,7 @@ def main():
427427
with STRATEGY.scope():
428428
# define generator and discriminator
429429
generator = TFMelGANGenerator(
430-
MultiBandMelGANGeneratorConfig(**config["multiband_melgan_generator"]),
430+
MultiBandMelGANGeneratorConfig(**config["multiband_melgan_generator_params"]),
431431
name="multi_band_melgan_generator",
432432
)
433433

@@ -437,7 +437,7 @@ def main():
437437
)
438438

439439
pqmf = TFPQMF(
440-
MultiBandMelGANGeneratorConfig(**config["multiband_melgan_generator"]), name="pqmf"
440+
MultiBandMelGANGeneratorConfig(**config["multiband_melgan_generator_params"]), name="pqmf"
441441
)
442442

443443
# dummy input to build model.

tensorflow_tts/models/tacotron2.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -210,6 +210,19 @@ def __init__(self, config, **kwargs):
210210
name="bilstm",
211211
)
212212

213+
if config.n_speakers > 1:
214+
self.encoder_speaker_embeddings = tf.keras.layers.Embedding(
215+
config.n_speakers,
216+
config.embedding_hidden_size,
217+
embeddings_initializer=get_initializer(config.initializer_range),
218+
name="encoder_speaker_embeddings",
219+
)
220+
self.encoder_speaker_fc = tf.keras.layers.Dense(
221+
units=config.encoder_lstm_units * 2, name="encoder_speaker_fc"
222+
)
223+
224+
self.config = config
225+
213226
def call(self, inputs, training=False):
214227
"""Call logic."""
215228
input_ids, speaker_ids, input_mask = inputs
@@ -224,6 +237,18 @@ def call(self, inputs, training=False):
224237
# bi-lstm.
225238
outputs = self.bilstm(conv_outputs, mask=input_mask)
226239

240+
if self.config.n_speakers > 1:
241+
encoder_speaker_embeddings = self.encoder_speaker_embeddings(speaker_ids)
242+
encoder_speaker_features = tf.math.softplus(
243+
self.encoder_speaker_fc(encoder_speaker_embeddings)
244+
)
245+
# extended encoderspeaker embeddings
246+
extended_encoder_speaker_features = encoder_speaker_features[
247+
:, tf.newaxis, :
248+
]
249+
# sum to encoder outputs
250+
outputs += extended_encoder_speaker_features
251+
227252
return outputs
228253

229254

0 commit comments

Comments
 (0)