2323import tensorflow as tf
2424
2525from tensorflow_tts .datasets .abstract_dataset import AbstractDataset
26- from tensorflow_tts .processor .ljspeech import symbols
26+ from tensorflow_tts .processor .ljspeech import symbols as ljspeech_symbols
27+ from tensorflow_tts .utils .korean import symbols as kss_symbols
28+ from tensorflow_tts .processor .baker import symbols as baker_symbols
2729from tensorflow_tts .utils import find_files
2830
2931
@@ -51,6 +53,7 @@ class CharactorMelDataset(AbstractDataset):
5153
5254 def __init__ (
5355 self ,
56+ dataset ,
5457 root_dir ,
5558 charactor_query = "*-ids.npy" ,
5659 mel_query = "*-norm-feats.npy" ,
@@ -100,6 +103,13 @@ def __init__(
100103 suffix = charactor_query [1 :]
101104 utt_ids = [os .path .basename (f ).replace (suffix , "" ) for f in charactor_files ]
102105
106+ eos_token_dict = {
107+ "ljspeech" : len (ljspeech_symbols ) - 1 ,
108+ "kss" : len (kss_symbols ) - 1 ,
109+ "baker" : len (baker_symbols ) - 1
110+ }
111+ self .eos_token_id = eos_token_dict [dataset ]
112+
103113 # set global params
104114 self .utt_ids = utt_ids
105115 self .mel_files = mel_files
@@ -139,10 +149,11 @@ def generator(self, utt_ids):
139149 char_length = self .char_lengths [i ]
140150
141151 # append the eos token to charactor, since charactor holds only the original tokens.
142- charactor = np .concatenate ([charactor , [len ( symbols ) - 1 ]], - 1 )
152+ charactor = np .concatenate ([charactor , [self . eos_token_id ]], - 1 )
143153 char_length += 1
144154
145155 # pad mel so that its length is a multiple of the reduction factor.
156+ real_mel_length = mel_length
146157 remainder = mel_length % self .reduction_factor
147158 if remainder != 0 :
148159 new_mel_length = mel_length + self .reduction_factor - remainder
@@ -169,6 +180,7 @@ def generator(self, utt_ids):
169180 "speaker_ids" : 0 ,
170181 "mel_gts" : mel ,
171182 "mel_lengths" : mel_length ,
183+ "real_mel_lengths" : real_mel_length ,
172184 "g_attentions" : g_attention ,
173185 }
174186
@@ -209,6 +221,7 @@ def create(
209221 "speaker_ids" : 0 ,
210222 "mel_gts" : self .mel_pad_value ,
211223 "mel_lengths" : 0 ,
224+ "real_mel_lengths" : 0 ,
212225 "g_attentions" : self .ga_pad_value ,
213226 }
214227
@@ -224,6 +237,7 @@ def create(
224237 if self .use_fixed_shapes is False
225238 else [self .max_mel_length , 80 ],
226239 "mel_lengths" : [],
240+ "real_mel_lengths" : [],
227241 "g_attentions" : [None , None ]
228242 if self .use_fixed_shapes is False
229243 else [self .max_char_length , self .max_mel_length // self .reduction_factor ],
@@ -243,6 +257,7 @@ def get_output_dtypes(self):
243257 "speaker_ids" : tf .int32 ,
244258 "mel_gts" : tf .float32 ,
245259 "mel_lengths" : tf .int32 ,
260+ "real_mel_lengths" : tf .int32 ,
246261 "g_attentions" : tf .float32 ,
247262 }
248263 return output_types
0 commit comments