-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathdataset.py
More file actions
97 lines (77 loc) · 3.19 KB
/
dataset.py
File metadata and controls
97 lines (77 loc) · 3.19 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
import torch, numpy as np, torchaudio, librosa
from torch.utils.data import Dataset
from g2p_en import G2p
# ARPABET phoneme inventory (as emitted by g2p_en) plus stress digits.
ARPABET_STRESS = ["0", "1", "2"]
ARPABET_VOWELS = "AA AE AH AO AW AY EH ER EY IH IY OW OY UH UW".split()
ARPABET_CONSONANTS = (
    "B CH D DH F G HH JH K L M N NG P R S SH T TH V W Y Z ZH".split()
)
# Blank token (index 0) plus literal punctuation tokens kept in the vocabulary.
TOKEN_BLANK = "<blk>"
TOKEN_PUNCT = [" ", "!", "'", ",", "-", ".", "..", "?"]
# Full vocabulary: blank, punctuation, every vowel+stress combination,
# then the consonants.  Order defines each token's integer id.
TOKENS = (
    [TOKEN_BLANK]
    + TOKEN_PUNCT
    + [vowel + stress for vowel in ARPABET_VOWELS for stress in ARPABET_STRESS]
    + ARPABET_CONSONANTS
)
def tokenize(g2p: G2p, text: str) -> torch.Tensor:
    """Phonemize *text* with *g2p* and return a 1-D tensor of token ids.

    The sequence is wrapped in a leading and trailing ``TOKEN_BLANK``.

    Raises:
        ValueError: if g2p emits a token that is not in ``TOKENS``.
    """
    # Build the lookup once: dict access is O(1) per token, whereas the
    # previous TOKENS.index(...) linearly scanned the vocabulary each time.
    token_to_id = {token: i for i, token in enumerate(TOKENS)}
    blank = token_to_id[TOKEN_BLANK]
    indices = [blank]
    for token in g2p(text):
        try:
            indices.append(token_to_id[token])
        except KeyError:
            # Preserve the original exception type (list.index raised ValueError).
            raise ValueError(f"{token!r} is not in the token vocabulary") from None
    indices.append(blank)
    return torch.tensor(indices)
def preprocess_ljspeech_dataset(path: str, out: str, mel_transform: torch.nn.Module):
    """Preprocess an LJSpeech-layout dataset into paired ``.npy`` files.

    Reads ``{path}/metadata.csv`` (pipe-separated: name|text|normalized text),
    phonemizes each normalized transcript, trims silence from the matching
    ``{path}/wavs/{name}.wav``, applies *mel_transform*, and writes
    ``{out}/{name}-text.npy`` / ``{out}/{name}-mels.npy`` for every sample.
    Finally writes ``{out}/index.txt`` listing all sample names, one per line.
    """
    g2p = G2p()
    index: list[str] = []
    # LJSpeech metadata is UTF-8; be explicit so the platform-default
    # encoding (e.g. cp1252 on Windows) cannot garble transcripts.
    with open(f"{path}/metadata.csv", encoding="utf-8") as file:
        for line in file:
            name, _text, norm_text = line.strip().split("|")
            index.append(name)
            encoded = tokenize(g2p, norm_text)
            encoded_np = encoded.numpy()
            waveform, _sample_rate = torchaudio.load(f"{path}/wavs/{name}.wav")
            waveform_np = waveform.squeeze(0).numpy()
            # Trim leading/trailing audio more than 20 dB below the peak.
            trimmed_np, _ = librosa.effects.trim(waveform_np, top_db=20)
            trimmed = torch.from_numpy(trimmed_np).unsqueeze(0)
            mels = mel_transform(trimmed).squeeze(0)
            mels_np = mels.numpy()
            np.save(f"{out}/{name}-text.npy", encoded_np, allow_pickle=False)
            np.save(f"{out}/{name}-mels.npy", mels_np, allow_pickle=False)
    # "w" rather than "w+": the index is only written here, never read back.
    with open(f"{out}/index.txt", "w", encoding="utf-8") as index_file:
        index_file.writelines(name + "\n" for name in index)
class ProcessedDataset(Dataset):
    """Loads (token-ids, mel-spectrogram) tensor pairs from a directory of
    ``{name}-text.npy`` / ``{name}-mels.npy`` files indexed by ``index.txt``."""

    # Directory holding index.txt and the per-sample .npy files.
    path: str
    # Sample names, one per line of index.txt.
    index: list[str]

    def __init__(self, path: str):
        self.path = path
        with open(f"{path}/index.txt") as index_file:
            names = index_file.read().splitlines()
        self.index = [name.strip() for name in names]

    def __getitem__(self, index: int) -> tuple[torch.Tensor, torch.Tensor]:
        name = self.index[index]
        text_np = np.load(f"{self.path}/{name}-text.npy", allow_pickle=False)
        mels_np = np.load(f"{self.path}/{name}-mels.npy", allow_pickle=False)
        return torch.from_numpy(text_np), torch.from_numpy(mels_np)

    def __len__(self) -> int:
        return len(self.index)
def collate_samples(
    batch: list[tuple[torch.Tensor, torch.Tensor]]
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Collate (text, mels) samples into padded batch tensors.

    Returns ``(text, text_lengths, mels, mels_lengths)`` with samples ordered
    by descending text length.  ``text`` is (batch, max_text_len) padded with
    zeros; ``mels`` is (batch, n_mels, max_time) padded along time.
    """
    # Sort a copy — the original mutated the caller's list in place.
    ordered = sorted(batch, key=lambda pair: pair[0].shape[0], reverse=True)
    texts = [text for text, _ in ordered]
    mel_list = [mels for _, mels in ordered]
    text = torch.nn.utils.rnn.pad_sequence(texts, batch_first=True)
    text_lengths = torch.tensor([t.shape[0] for t in texts])
    # pad_sequence pads along dim 0 of each sample, so move the time axis
    # first for padding, then swap back to (batch, n_mels, time).
    mels = torch.nn.utils.rnn.pad_sequence(
        [m.transpose(0, 1) for m in mel_list],
        batch_first=True,
    ).transpose(1, 2)
    mels_lengths = torch.tensor([m.shape[1] for m in mel_list])
    return text, text_lengths, mels, mels_lengths