Skip to content

Commit 4d45d7d

Browse files
committed
🐸 eos_id is now added at the end of the sentence automatically; removed all explicit eos_id additions.
1 parent 8a5f63a commit 4d45d7d

File tree

6 files changed

+14
-20
lines changed

6 files changed

+14
-20
lines changed

‎examples/tacotron2/extract_duration.py‎

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -165,8 +165,8 @@ def main():
165165

166166
for i, alignment in enumerate(alignment_historys):
167167
real_char_length = (
168-
input_lengths[i].numpy() - 1
169-
) # minus 1 because char have eos tokens.
168+
input_lengths[i].numpy()
169+
)
170170
real_mel_length = real_mel_lengths[i].numpy()
171171
alignment_mel_length = int(np.ceil(real_mel_length / config["tacotron2_params"]["reduction_factor"]))
172172
alignment = alignment[:real_char_length, :alignment_mel_length]

‎examples/tacotron2/tacotron_dataset.py‎

Lines changed: 1 addition & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -23,9 +23,6 @@
2323
import tensorflow as tf
2424

2525
from tensorflow_tts.datasets.abstract_dataset import AbstractDataset
26-
from tensorflow_tts.processor.ljspeech import symbols as ljspeech_symbols
27-
from tensorflow_tts.utils.korean import symbols as kss_symbols
28-
from tensorflow_tts.processor.baker import symbols as baker_symbols
2926
from tensorflow_tts.utils import find_files
3027

3128

@@ -103,13 +100,6 @@ def __init__(
103100
suffix = charactor_query[1:]
104101
utt_ids = [os.path.basename(f).replace(suffix, "") for f in charactor_files]
105102

106-
eos_token_dict = {
107-
"ljspeech": len(ljspeech_symbols) - 1,
108-
"kss": len(kss_symbols) - 1,
109-
"baker": len(baker_symbols) - 1
110-
}
111-
self.eos_token_id = eos_token_dict[dataset]
112-
113103
# set global params
114104
self.utt_ids = utt_ids
115105
self.mel_files = mel_files
@@ -125,7 +115,7 @@ def __init__(
125115
self.ga_pad_value = ga_pad_value
126116
self.g = g
127117
self.use_fixed_shapes = use_fixed_shapes
128-
self.max_char_length = np.max(char_lengths) + 1 # +1 for eos
118+
self.max_char_length = np.max(char_lengths)
129119

130120
if np.max(mel_lengths) % self.reduction_factor == 0:
131121
self.max_mel_length = np.max(mel_lengths)
@@ -148,10 +138,6 @@ def generator(self, utt_ids):
148138
mel_length = self.mel_lengths[i]
149139
char_length = self.char_lengths[i]
150140

151-
# add eos token for charactor since charactor is original token.
152-
charactor = np.concatenate([charactor, [self.eos_token_id]], -1)
153-
char_length += 1
154-
155141
# padding mel to make its length is multiple of reduction factor.
156142
real_mel_length = mel_length
157143
remainder = mel_length % self.reduction_factor

‎tensorflow_tts/configs/fastspeech.py‎

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from tensorflow_tts.processor.ljspeech import LJSPEECH_SYMBOLS as lj_symbols
2020
from tensorflow_tts.processor.kss import KSS_SYMBOLS as kss_symbols
2121
from tensorflow_tts.processor.baker import BAKER_SYMBOLS as bk_symbols
22+
from tensorflow_tts.processor.libritts import LIBRITTS_SYMBOLS as lbri_symbols
2223

2324

2425
SelfAttentionParams = collections.namedtuple(
@@ -91,6 +92,8 @@ def __init__(
9192
self.vocab_size = len(kss_symbols)
9293
elif dataset == "baker":
9394
self.vocab_size = len(bk_symbols)
95+
elif dataset == "libritts":
96+
self.vocab_size = len(lbri_symbols)
9497
else:
9598
raise ValueError("No such dataset: {}".format(dataset))
9699
self.initializer_range = initializer_range

‎tensorflow_tts/configs/tacotron2.py‎

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from tensorflow_tts.processor.ljspeech import LJSPEECH_SYMBOLS as lj_symbols
1818
from tensorflow_tts.processor.kss import KSS_SYMBOLS as kss_symbols
1919
from tensorflow_tts.processor.baker import BAKER_SYMBOLS as bk_symbols
20+
from tensorflow_tts.processor.libritts import LIBRITTS_SYMBOLS as lbri_symbols
2021

2122

2223
class Tacotron2Config(object):
@@ -61,6 +62,8 @@ def __init__(
6162
self.vocab_size = len(kss_symbols)
6263
elif dataset == 'baker':
6364
self.vocab_size = len(bk_symbols)
65+
elif dataset == "libritts":
66+
self.vocab_size = len(lbri_symbols)
6467
else:
6568
raise ValueError("No such dataset: {}".format(dataset))
6669
self.embedding_hidden_size = embedding_hidden_size

‎tensorflow_tts/processor/base_processor.py‎

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ def __getattr__(self, name: str) -> Union[str, int]:
7272

7373
def create_speaker_map(self):
7474
"""
75-
Create speaker map for dataset
75+
Create speaker map for dataset.
7676
"""
7777
sp_id = 0
7878
for i in self.items:
@@ -94,7 +94,8 @@ def create_symbols(self):
9494
def create_items(self):
9595
"""
9696
Method used to create items from training file
97-
items struct => text, wav_file_path, speaker_name
97+
items struct example => text, wav_file_path, speaker_name.
98+
Note that the speaker_name should be last.
9899
"""
99100
with open(
100101
os.path.join(self.data_dir, self.train_f_name), mode="r", encoding="utf-8"

‎test/files/mapper.json‎

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,5 +12,6 @@
1212
"0": "a",
1313
"1": "b",
1414
"2": "@ph"
15-
}
15+
},
16+
"processor_name": "TestProcessor"
1617
}

0 commit comments

Comments
 (0)