Skip to content

Commit 138f11c

Browse files
committed
💉 Add libritts processor.
1 parent 35ca768 commit 138f11c

File tree

5 files changed

+244
-15
lines changed

5 files changed

+244
-15
lines changed

.gitignore

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,4 +33,4 @@ ljspeech
3333
/examples/tacotron2/exp/
3434
/temp/
3535
kss
36-
36+
LibriTTS

tensorflow_tts/bin/preprocess.py

Lines changed: 126 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -32,9 +32,11 @@
3232

3333
from tensorflow_tts.processor import LJSpeechProcessor
3434
from tensorflow_tts.processor import KSSProcessor
35+
from tensorflow_tts.processor import LibriTTSProcessor
3536

3637
from tensorflow_tts.processor.ljspeech import LJSPEECH_SYMBOLS
3738
from tensorflow_tts.processor.kss import KSS_SYMBOLS
39+
from tensorflow_tts.processor.libritts import LIBRITTS_SYMBOLS
3840

3941
from tensorflow_tts.utils import remove_outlier
4042

@@ -65,7 +67,7 @@ def parse_and_config():
6567
"--dataset",
6668
type=str,
6769
default="ljspeech",
68-
choices=["ljspeech", "kss"],
70+
choices=["ljspeech", "kss", "libritts"],
6971
help="Dataset to preprocess.",
7072
)
7173
parser.add_argument(
@@ -107,6 +109,64 @@ def parse_and_config():
107109
return config
108110

109111

112+
def ph_based_trim(
113+
config,
114+
utt_id: str,
115+
text_ids: np.array,
116+
raw_text: str,
117+
audio: np.array,
118+
hop_size: int,
119+
) -> (bool, np.array, np.array):
120+
"""
121+
Args:
122+
config: Parsed yaml config
123+
utt_id: file name
124+
text_ids: array with text ids
125+
raw_text: raw text of file
126+
audio: parsed wav file
127+
hop_size: Hop size
128+
Returns: (bool, np.array, np.array) => if trimmed return True, new text_ids, new audio_array
129+
"""
130+
131+
os.makedirs(os.path.join(config["rootdir"], "trimmed-durations"), exist_ok=True)
132+
duration_path = config.get(
133+
"duration_path", os.path.join(config["rootdir"], "durations")
134+
)
135+
duration_fixed_path = config.get(
136+
"duration_fixed_path", os.path.join(config["rootdir"], "trimmed-durations")
137+
)
138+
sil_ph = ["SIL", "END"] # TODO FIX hardcoded values
139+
text = raw_text.split(" ")
140+
141+
trim_start, trim_end = False, False
142+
143+
if text[0] in sil_ph:
144+
trim_start = True
145+
146+
if text[-1] in sil_ph:
147+
trim_end = True
148+
149+
if not trim_start and not trim_end:
150+
return False, text_ids, audio
151+
152+
idx_start, idx_end = (
153+
0 if not trim_start else 1,
154+
text_ids.__len__() if not trim_end else -1,
155+
)
156+
text_ids = text_ids[idx_start:idx_end]
157+
durations = np.load(os.path.join(duration_path, f"{utt_id}-durations.npy"))
158+
if trim_start:
159+
s_trim = int(durations[0] * hop_size)
160+
audio = audio[s_trim:]
161+
if trim_end:
162+
e_trim = int(durations[-1] * hop_size)
163+
audio = audio[:-e_trim]
164+
165+
durations = durations[idx_start:idx_end]
166+
np.save(os.path.join(duration_fixed_path, f"{utt_id}-durations.npy"), durations)
167+
return True, text_ids, audio
168+
169+
110170
def gen_audio_features(item, config):
111171
"""Generate audio features and transformations
112172
Args:
@@ -132,12 +192,29 @@ def gen_audio_features(item, config):
132192

133193
# trim silence
134194
if config["trim_silence"]:
135-
audio, _ = librosa.effects.trim(
136-
audio,
137-
top_db=config["trim_threshold_in_db"],
138-
frame_length=config["trim_frame_size"],
139-
hop_length=config["trim_hop_size"],
140-
)
195+
if "trim_mfa" in config and config["trim_mfa"]:
196+
_, item["text_ids"], audio = ph_based_trim(
197+
config,
198+
utt_id,
199+
item["text_ids"],
200+
item["raw_text"],
201+
audio,
202+
config["hop_size"],
203+
)
204+
if (
205+
audio.__len__() < 1
206+
): # very short files can get trimmed fully if mfa didnt extract any tokens LibriTTS maybe take only longer files?
207+
logging.warning(
208+
f"File have only silence or MFA didnt extract any token {utt_id}"
209+
)
210+
return False, None, None, None, item
211+
else:
212+
audio, _ = librosa.effects.trim(
213+
audio,
214+
top_db=config["trim_threshold_in_db"],
215+
frame_length=config["trim_frame_size"],
216+
hop_length=config["trim_hop_size"],
217+
)
141218

142219
# resample audio if necessary
143220
if "sampling_rate_for_feats" in config:
@@ -207,7 +284,7 @@ def gen_audio_features(item, config):
207284
item["mel"] = mel
208285
item["f0"] = f0
209286
item["energy"] = energy
210-
return mel, energy, f0, item
287+
return True, mel, energy, f0, item
211288

212289

213290
def save_statistics_to_file(scaler_list, config):
@@ -261,14 +338,20 @@ def preprocess():
261338
dataset_processor = {
262339
"ljspeech": LJSpeechProcessor,
263340
"kss": KSSProcessor,
341+
"libritts": LibriTTSProcessor,
264342
}
265343

266344
dataset_symbol = {
267345
"ljspeech": LJSPEECH_SYMBOLS,
268346
"kss": KSS_SYMBOLS,
347+
"libritts": LIBRITTS_SYMBOLS,
269348
}
270349

271-
dataset_cleaner = {"ljspeech": "english_cleaners", "kss": "korean_cleaners"}
350+
dataset_cleaner = {
351+
"ljspeech": "english_cleaners",
352+
"kss": "korean_cleaners",
353+
"libritts": None,
354+
}
272355

273356
logging.info(f"Selected '{config['dataset']}' processor.")
274357
processor = dataset_processor[config["dataset"]](
@@ -291,9 +374,21 @@ def preprocess():
291374
)
292375

293376
# build train test split
294-
train_split, valid_split = train_test_split(
295-
processor.items, test_size=config["test_size"], random_state=42, shuffle=True,
296-
)
377+
if config["dataset"] == "libritts":
378+
train_split, valid_split, _, _ = train_test_split(
379+
processor.items,
380+
[i[-1] for i in processor.items],
381+
test_size=config["test_size"],
382+
random_state=42,
383+
shuffle=True,
384+
)
385+
else:
386+
train_split, valid_split = train_test_split(
387+
processor.items,
388+
test_size=config["test_size"],
389+
random_state=42,
390+
shuffle=True,
391+
)
297392
logging.info(f"Training items: {len(train_split)}")
298393
logging.info(f"Validation items: {len(valid_split)}")
299394

@@ -327,15 +422,33 @@ def iterator_data(items_list):
327422
scaler_energy = StandardScaler(copy=False)
328423
scaler_f0 = StandardScaler(copy=False)
329424

330-
for mel, energy, f0, features in train_map:
425+
id_to_remove = []
426+
for result, mel, energy, f0, features in train_map:
427+
if not result:
428+
id_to_remove.append(features["utt_id"])
429+
continue
331430
save_features_to_file(features, "train", config)
332431
# remove outliers
333432
energy = remove_outlier(energy)
433+
f0 = remove_outlier(f0)
434+
# partial fitting of scalers
435+
if len(energy[energy != 0]) == 0 or len(f0[f0 != 0]) == 0:
436+
id_to_remove.append(features["utt_id"])
437+
continue
334438
# partial fitting of scalers
335439
scaler_mel.partial_fit(mel)
336440
scaler_energy.partial_fit(energy[energy != 0].reshape(-1, 1))
337441
scaler_f0.partial_fit(f0[f0 != 0].reshape(-1, 1))
338442

443+
if len(id_to_remove) > 0:
444+
np.save(
445+
os.path.join(config["outdir"], "train_utt_ids.npy"),
446+
[i for i in train_utt_ids if i not in id_to_remove],
447+
)
448+
logging.info(
449+
f"removed {len(id_to_remove)} cause of too many outliers or bad mfa extraction"
450+
)
451+
339452
# save statistics to file
340453
logging.info("Saving computed statistics.")
341454
scaler_list = [(scaler_mel, ""), (scaler_energy, "_energy"), (scaler_f0, "_f0")]

tensorflow_tts/processor/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,3 +2,4 @@
22

33
from tensorflow_tts.processor.ljspeech import LJSpeechProcessor
44
from tensorflow_tts.processor.kss import KSSProcessor
5+
from tensorflow_tts.processor.libritts import LibriTTSProcessor
Lines changed: 115 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,115 @@
1+
# -*- coding: utf-8 -*-
2+
# Copyright 2020 TensorFlowTTS Team.
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
"""Perform preprocessing and raw feature extraction for LibriTTS dataset."""
16+
17+
from dataclasses import dataclass
18+
19+
import numpy as np
20+
import soundfile as sf
21+
from g2p_en import g2p as grapheme_to_phonem
22+
23+
from tensorflow_tts.processor.base_processor import BaseProcessor
24+
25+
g2p = grapheme_to_phonem.G2p()
26+
27+
valid_symbols = g2p.phonemes
28+
valid_symbols.append("SIL")
29+
valid_symbols.append("END")
30+
31+
_punctuation = "!'(),.:;? "
32+
_arpabet = ["@" + s for s in valid_symbols]
33+
34+
LIBRITTS_SYMBOLS = _arpabet + list(_punctuation)
35+
36+
37+
@dataclass
38+
class LibriTTSProcessor(BaseProcessor):
39+
40+
mode: str = "train"
41+
train_f_name: str = "train.txt"
42+
positions = {
43+
"file": 0,
44+
"text": 1,
45+
"speaker_name": 2,
46+
} # positions of file,text,speaker_name after split line
47+
f_extension: str = ".wav"
48+
cleaner_names: str = None
49+
50+
def create_items(self):
51+
with open(
52+
os.path.join(self.data_dir, self.train_f_name), mode="r", encoding="utf-8"
53+
) as f:
54+
for line in f:
55+
parts = line.strip().split(self.delimiter)
56+
wav_path = os.path.join(self.data_dir, parts[self.positions["file"]])
57+
wav_path = (
58+
wav_path + self.f_extension
59+
if wav_path[-len(self.f_extension) :] != self.f_extension
60+
else wav_path
61+
)
62+
text = parts[self.positions["text"]]
63+
speaker_name = parts[self.positions["speaker_name"]]
64+
self.items.append([text, wav_path, speaker_name])
65+
66+
def get_one_sample(self, item):
67+
text, wav_path, speaker_name = item
68+
audio, rate = sf.read(wav_path, dtype="float32")
69+
70+
text_ids = np.asarray(self.text_to_sequence(text), np.int32)
71+
72+
sample = {
73+
"raw_text": text,
74+
"text_ids": text_ids,
75+
"audio": audio,
76+
"utt_id": wav_file.split("/")[-1].split(".")[0],
77+
"speaker_name": speaker_name,
78+
"rate": rate,
79+
}
80+
81+
return sample
82+
83+
def text_to_sequence(self, text):
84+
if (
85+
self.mode == "train"
86+
): # in train mode text should be already transformed to phonemes
87+
return self.symbols_to_ids(clean_g2p(text.split(" ")))
88+
else:
89+
return self.inference_text_to_seq(text)
90+
91+
@staticmethod
92+
def inference_text_to_seq(text: str):
93+
return self.symbols_to_ids(self.text_to_ph(text))
94+
95+
def symbols_to_ids(self, symbols_list: list):
96+
return [self.symbol_to_id[s] for s in symbols_list]
97+
98+
def text_to_ph(self, text: str):
99+
return self.clean_g2p(g2p(text))
100+
101+
def clean_g2p(self, g2p_text: list):
102+
data = []
103+
for i, txt in enumerate(g2p_text):
104+
if i == len(g2p_text) - 1:
105+
if txt != " " and txt != "SIL":
106+
data.append("@" + txt)
107+
else:
108+
data.append(
109+
"@END"
110+
) # TODO try learning without end token and compare results
111+
break
112+
data.append("@" + txt) if txt != " " else data.append(
113+
"@SIL"
114+
) # TODO change it in inference
115+
return data

test/test_base_processor.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ def processor(tmpdir):
3333
def mapper_processor(tmpdir):
3434
copyfile("test/files/train.txt", f"{tmpdir}/train.txt")
3535
copyfile("test/files/mapper.json", f"{tmpdir}/mapper.json")
36-
processor = LJ(data_dir=tmpdir, load_mapper=True)
36+
processor = LJ(data_dir=tmpdir, loaded_mapper_path=f"{tmpdir}/mapper.json")
3737
return processor
3838

3939

0 commit comments

Comments
 (0)