Commit 27004f4: update chinese example
Parent: 93ec250

6 files changed, 46 additions and 7 deletions

examples/multiband_melgan/decode_mb_melgan.py

Lines changed: 2 additions & 2 deletions
@@ -110,14 +110,14 @@ def main():
 
     # define model and load checkpoint
     mb_melgan = TFMelGANGenerator(
-        config=MultiBandMelGANGeneratorConfig(**config["multiband_melgan_generator"]),
+        config=MultiBandMelGANGeneratorConfig(**config["multiband_melgan_generator_params"]),
         name="multiband_melgan_generator",
     )
     mb_melgan._build()
     mb_melgan.load_weights(args.checkpoint)
 
     pqmf = TFPQMF(
-        config=MultiBandMelGANGeneratorConfig(**config["multiband_melgan_generator"]), name="pqmf"
+        config=MultiBandMelGANGeneratorConfig(**config["multiband_melgan_generator_params"]), name="pqmf"
     )
 
     for data in tqdm(dataset, desc="[Decoding]"):
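The only change here is the YAML key the decode script reads the generator hyperparameters from: "multiband_melgan_generator_params" instead of "multiband_melgan_generator", presumably to match the key used in the training YAML. A minimal sketch of how that key is consumed, assuming the config path is illustrative and that MultiBandMelGANGeneratorConfig is importable from tensorflow_tts.configs as in the decode script:

    import yaml

    from tensorflow_tts.configs import MultiBandMelGANGeneratorConfig

    # illustrative path; real configs live under examples/multiband_melgan/conf/
    with open("examples/multiband_melgan/conf/multiband_melgan.v1.yaml") as f:
        config = yaml.load(f, Loader=yaml.Loader)

    # after this commit the generator hyperparameters are looked up under
    # "multiband_melgan_generator_params"
    generator_config = MultiBandMelGANGeneratorConfig(
        **config["multiband_melgan_generator_params"]
    )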

examples/tacotron2/conf/tacotron2.v1.baker.yaml

Lines changed: 1 addition & 1 deletion
@@ -32,7 +32,7 @@ tacotron2_params:
     prenet_activation: 'relu'
     prenet_dropout_rate: 0.5
     n_lstm_decoder: 1
-    reduction_factor: 1
+    reduction_factor: 2
     decoder_lstm_units: 1024
     attention_dim: 128
     attention_filters: 32
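Raising reduction_factor from 1 to 2 makes the Tacotron-2 decoder emit two mel frames per decoder step for the Baker (Chinese) config, halving the number of decoder and attention steps; this is why the dataset and duration-extraction changes below track both the padded and the real mel length. A quick illustration with made-up numbers:

    import numpy as np

    reduction_factor = 2       # value set by this commit for the Baker config
    real_mel_length = 745      # illustrative frame count for one utterance

    # decoder/attention steps when r frames are produced per step
    decoder_steps = int(np.ceil(real_mel_length / reduction_factor))   # 373

    # the data loader pads the target mel so its length is a multiple of r
    padded_mel_length = decoder_steps * reduction_factor               # 746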

examples/tacotron2/decode_tacotron2.py

Lines changed: 2 additions & 0 deletions
@@ -110,11 +110,13 @@ def main():
 
     # define data-loader
     dataset = CharactorMelDataset(
+        dataset=config["tacotron2_params"]["dataset"],
         root_dir=args.rootdir,
         charactor_query=char_query,
         mel_query=mel_query,
         charactor_load_fn=char_load_fn,
         mel_load_fn=mel_load_fn,
+        reduction_factor=config["tacotron2_params"]["reduction_factor"]
     )
     dataset = dataset.create(allow_cache=True, batch_size=args.batch_size)
 
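Both new arguments are read from tacotron2_params in the training YAML, so the decode script stays in sync with whatever corpus and reduction factor the checkpoint was trained with. A hedged usage sketch; the import path, dump directory, and np.load loaders below are illustrative assumptions, not taken from the commit:

    import yaml
    import numpy as np

    # assumed import path when running from the repository root
    from examples.tacotron2.tacotron_dataset import CharactorMelDataset

    with open("examples/tacotron2/conf/tacotron2.v1.baker.yaml") as f:
        config = yaml.load(f, Loader=yaml.Loader)

    dataset = CharactorMelDataset(
        dataset=config["tacotron2_params"]["dataset"],       # e.g. "baker"
        root_dir="./dump/valid/",                            # illustrative dump dir
        charactor_query="*-ids.npy",
        mel_query="*-norm-feats.npy",
        charactor_load_fn=np.load,
        mel_load_fn=np.load,
        reduction_factor=config["tacotron2_params"]["reduction_factor"],
    ).create(allow_cache=True, batch_size=1)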

examples/tacotron2/extract_duration.py

Lines changed: 22 additions & 2 deletions
@@ -124,11 +124,13 @@ def main():
 
     # define data-loader
     dataset = CharactorMelDataset(
+        dataset=config["tacotron2_params"]["dataset"],
         root_dir=args.rootdir,
         charactor_query=char_query,
         mel_query=mel_query,
         charactor_load_fn=char_load_fn,
         mel_load_fn=mel_load_fn,
+        reduction_factor=config["tacotron2_params"]["reduction_factor"]
     )
     dataset = dataset.create(allow_cache=True, batch_size=args.batch_size)
 
@@ -146,6 +148,8 @@ def main():
         input_lengths = data["input_lengths"]
         mel_lengths = data["mel_lengths"]
         utt_ids = utt_ids.numpy()
+        real_mel_lengths = data["real_mel_lengths"]
+        del data["real_mel_lengths"]
 
         # tacotron2 inference.
         mel_outputs, post_mel_outputs, stop_outputs, alignment_historys = tacotron2(
@@ -163,10 +167,26 @@ def main():
             real_char_length = (
                 input_lengths[i].numpy() - 1
             )  # minus 1 because char have eos tokens.
-            real_mel_length = mel_lengths[i].numpy()
-            alignment = alignment[:real_char_length, :real_mel_length]
+            real_mel_length = real_mel_lengths[i].numpy()
+            alignment_mel_length = int(np.ceil(real_mel_length / config["tacotron2_params"]["reduction_factor"]))
+            alignment = alignment[:real_char_length, :alignment_mel_length]
             d = get_duration_from_alignment(alignment)  # [max_char_len]
 
+            d = d * config["tacotron2_params"]["reduction_factor"]
+            assert np.sum(d) >= real_mel_length, f"{d}, {np.sum(d)}, {alignment_mel_length}, {real_mel_length}"
+            if np.sum(d) > real_mel_length:
+                rest = np.sum(d) - real_mel_length
+                # print(d, np.sum(d), real_mel_length)
+                if d[-1] > rest:
+                    d[-1] -= rest
+                elif d[0] > rest:
+                    d[0] -= rest
+                else:
+                    d[-1] -= rest // 2
+                    d[0] -= (rest - rest // 2)
+
+            assert d[-1] > 0 and d[0] > 0, f'{d}, {np.sum(d)}, {real_mel_length}'
+
             saved_name = utt_ids[i].decode("utf-8")
 
             # check a length compatible
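Because the alignment is now computed at decoder resolution (one attention step per reduction_factor mel frames), the extracted durations are scaled back up to frame units and any padding frames are trimmed from the last or first token so the durations sum exactly to the real mel length. A small worked example of the same adjustment, with made-up numbers:

    import numpy as np

    # illustrative values, not from the commit
    reduction_factor = 2
    real_mel_length = 9                       # un-padded frame count for the utterance
    d = np.array([2, 1, 2])                   # per-character durations in decoder steps

    d = d * reduction_factor                  # -> [4, 2, 4], now counted in mel frames
    assert np.sum(d) >= real_mel_length       # padding can only make the sum too large

    rest = np.sum(d) - real_mel_length        # 10 - 9 = 1 frame of padding to remove
    if rest > 0:
        if d[-1] > rest:
            d[-1] -= rest                     # prefer trimming the last token
        elif d[0] > rest:
            d[0] -= rest                      # otherwise the first
        else:
            d[-1] -= rest // 2                # split the correction as a last resort
            d[0] -= rest - rest // 2

    assert np.sum(d) == real_mel_length       # durations now match the real mel length
    print(d)                                  # [4 2 3]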

examples/tacotron2/tacotron_dataset.py

Lines changed: 17 additions & 2 deletions
@@ -23,7 +23,9 @@
 import tensorflow as tf
 
 from tensorflow_tts.datasets.abstract_dataset import AbstractDataset
-from tensorflow_tts.processor.ljspeech import symbols
+from tensorflow_tts.processor.ljspeech import symbols as ljspeech_symbols
+from tensorflow_tts.utils.korean import symbols as kss_symbols
+from tensorflow_tts.processor.baker import symbols as baker_symbols
 from tensorflow_tts.utils import find_files
 
 
@@ -51,6 +53,7 @@ class CharactorMelDataset(AbstractDataset):
 
     def __init__(
         self,
+        dataset,
         root_dir,
         charactor_query="*-ids.npy",
         mel_query="*-norm-feats.npy",
@@ -100,6 +103,13 @@ def __init__(
         suffix = charactor_query[1:]
         utt_ids = [os.path.basename(f).replace(suffix, "") for f in charactor_files]
 
+        eos_token_dict = {
+            "ljspeech": len(ljspeech_symbols) - 1,
+            "kss": len(kss_symbols) - 1,
+            "baker": len(baker_symbols) - 1
+        }
+        self.eos_token_id = eos_token_dict[dataset]
+
         # set global params
         self.utt_ids = utt_ids
         self.mel_files = mel_files
@@ -139,10 +149,11 @@ def generator(self, utt_ids):
             char_length = self.char_lengths[i]
 
             # add eos token for charactor since charactor is original token.
-            charactor = np.concatenate([charactor, [len(symbols) - 1]], -1)
+            charactor = np.concatenate([charactor, [self.eos_token_id]], -1)
             char_length += 1
 
             # padding mel to make its length is multiple of reduction factor.
+            real_mel_length = mel_length
             remainder = mel_length % self.reduction_factor
             if remainder != 0:
                 new_mel_length = mel_length + self.reduction_factor - remainder
@@ -169,6 +180,7 @@ def generator(self, utt_ids):
                 "speaker_ids": 0,
                 "mel_gts": mel,
                 "mel_lengths": mel_length,
+                "real_mel_lengths": real_mel_length,
                 "g_attentions": g_attention,
             }
 
@@ -209,6 +221,7 @@ def create(
             "speaker_ids": 0,
             "mel_gts": self.mel_pad_value,
             "mel_lengths": 0,
+            "real_mel_lengths": 0,
            "g_attentions": self.ga_pad_value,
         }
 
@@ -224,6 +237,7 @@ def create(
             if self.use_fixed_shapes is False
             else [self.max_mel_length, 80],
             "mel_lengths": [],
+            "real_mel_lengths": [],
             "g_attentions": [None, None]
             if self.use_fixed_shapes is False
             else [self.max_char_length, self.max_mel_length // self.reduction_factor],
@@ -243,6 +257,7 @@ def get_output_dtypes(self):
             "speaker_ids": tf.int32,
             "mel_gts": tf.float32,
             "mel_lengths": tf.int32,
+            "real_mel_lengths": tf.int32,
             "g_attentions": tf.float32,
         }
         return output_types
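The generator now remembers the un-padded length as real_mel_lengths before rounding the mel up to a multiple of the reduction factor, and the EOS id comes from the symbol table of the chosen corpus instead of being hard-coded to LJSpeech. A minimal sketch of the padding step; the np.pad call is an assumption about how the mel is extended, since the commit only shows the surrounding lines:

    import numpy as np

    reduction_factor = 2
    mel = np.random.randn(9, 80).astype(np.float32)   # [mel_length, num_mels], illustrative
    mel_length = mel.shape[0]

    real_mel_length = mel_length                      # emitted as "real_mel_lengths"
    remainder = mel_length % reduction_factor
    if remainder != 0:
        new_mel_length = mel_length + reduction_factor - remainder
        mel = np.pad(mel, [[0, new_mel_length - mel_length], [0, 0]])  # assumed padding call
        mel_length = new_mel_length

    assert mel.shape[0] % reduction_factor == 0       # decoder sees a multiple of r
    print(real_mel_length, mel_length)                # 9 10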

examples/tacotron2/train_tacotron2.py

Lines changed: 2 additions & 0 deletions
@@ -368,6 +368,7 @@ def main():
         raise ValueError("Only npy are supported.")
 
     train_dataset = CharactorMelDataset(
+        dataset=config["tacotron2_params"]["dataset"],
         root_dir=args.train_dir,
         charactor_query=charactor_query,
         mel_query=mel_query,
@@ -394,6 +395,7 @@ def main():
     )
 
     valid_dataset = CharactorMelDataset(
+        dataset=config["tacotron2_params"]["dataset"],
         root_dir=args.dev_dir,
         charactor_query=charactor_query,
         mel_query=mel_query,
