Skip to content

Commit 4de48ca

Browse files
authored
Merge pull request #1451 from yt605155624/fix_ci_waveflow
[TTS]Fix ci for waveflow
2 parents cd23be7 + 67ec624 commit 4de48ca

File tree

4 files changed

+268
-7
lines changed

4 files changed

+268
-7
lines changed
Lines changed: 261 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,261 @@
1+
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
import six
15+
from paddle.io import Dataset
16+
17+
__all__ = [
18+
"split",
19+
"TransformDataset",
20+
"CacheDataset",
21+
"TupleDataset",
22+
"DictDataset",
23+
"SliceDataset",
24+
"SubsetDataset",
25+
"FilterDataset",
26+
"ChainDataset",
27+
]
28+
29+
30+
def split(dataset, first_size):
    """Split a dataset into two consecutive parts.

    Args:
        dataset (Dataset): the base dataset to split.
        first_size (int): the number of examples in the first part.

    Returns:
        Tuple[SliceDataset, SliceDataset]: views over the examples
        ``[0, first_size)`` and ``[first_size, len(dataset))`` of the base
        dataset, in that order.
    """
    return (
        SliceDataset(dataset, 0, first_size),
        SliceDataset(dataset, first_size, len(dataset)),
    )
35+
36+
37+
class TransformDataset(Dataset):
    """Dataset whose examples are the transformed examples of another one."""

    def __init__(self, dataset, transform):
        """Wrap ``dataset`` so every example is passed through ``transform``.

        Args:
            dataset (Dataset): the base dataset.
            transform (callable): takes an example of the base dataset and
                returns a new example.
        """
        self._transform = transform
        self._dataset = dataset

    def __getitem__(self, i):
        # The transform runs on every access; results are not cached here.
        example = self._dataset[i]
        return self._transform(example)

    def __len__(self):
        # Transforming does not change the number of examples.
        return len(self._dataset)
54+
55+
56+
class CacheDataset(Dataset):
    """Lazily cache the examples of a base dataset in memory."""

    def __init__(self, dataset):
        """Create an empty cache in front of ``dataset``.

        Args:
            dataset (Dataset): the base dataset to cache.
        """
        self._cache = {}
        self._dataset = dataset

    def __getitem__(self, i):
        # Serve from the cache when possible; otherwise fetch the example
        # from the base dataset once and remember it. Entries are never
        # evicted.
        if i in self._cache:
            return self._cache[i]
        example = self._dataset[i]
        self._cache[i] = example
        return example

    def __len__(self):
        return len(self._dataset)
73+
74+
75+
class TupleDataset(Dataset):
    """Compound dataset whose examples are tuples of constituent examples."""

    def __init__(self, *datasets):
        """A compound dataset made from several datasets of the same length.

        An example of the `TupleDataset` is a tuple of examples from the
        constituent datasets.

        Args:
            datasets: tuple[Dataset], the constituent datasets.

        Raises:
            ValueError: if no dataset is given, or if the datasets do not all
                have the same length.
        """
        if not datasets:
            raise ValueError("no datasets are given")
        length = len(datasets[0])
        for i, dataset in enumerate(datasets):
            if len(dataset) != length:
                # The two sentences are separated explicitly; adjacent string
                # literals are concatenated without any separator.
                raise ValueError(
                    "all the datasets should have the same length. "
                    "dataset {} has a different length".format(i))
        self._datasets = datasets
        self._length = length

    def __getitem__(self, index):
        """Return the example(s) at ``index``.

        Args:
            index (int or slice): forwarded as is to every constituent
                dataset.

        Returns:
            A tuple with one entry per constituent dataset when ``index`` is
            not a slice; a list of such tuples when it is a slice.
        """
        # Structure of arrays: one result per constituent dataset.
        batches = [dataset[index] for dataset in self._datasets]
        if not isinstance(index, slice):
            return tuple(batches)

        # Array of structures: regroup the sliced results example by
        # example, using the first constituent's result as the length.
        length = len(batches[0])
        return [
            tuple(batch[i] for batch in batches) for i in range(length)
        ]

    def __len__(self):
        return self._length
107+
108+
109+
class DictDataset(Dataset):
    """Compound dataset whose examples are dicts of constituent examples."""

    def __init__(self, **datasets):
        """
        A compound dataset made from several datasets of the same length. An
        example of the `DictDataset` is a dict of examples from the constituent
        datasets.

        WARNING: paddle does not have good support for DictDataset, because
        every batch yielded from a DataLoader is a list, and it cannot be a
        dict. So you have to provide a collate function because you cannot use
        the default one.

        Args:
            datasets: Dict[Dataset], the constituent datasets, given as
                keyword arguments.

        Raises:
            ValueError: if no dataset is given, or if the datasets do not all
                have the same length.
        """
        if not datasets:
            raise ValueError("no datasets are given")
        length = None
        for key, dataset in datasets.items():
            if length is None:
                # The first dataset seen fixes the reference length.
                length = len(dataset)
            elif len(dataset) != length:
                # The two sentences are separated explicitly; adjacent string
                # literals are concatenated without any separator.
                raise ValueError(
                    "all the datasets should have the same length. "
                    "dataset {} has a different length".format(key))
        self._datasets = datasets
        self._length = length

    def __getitem__(self, index):
        """Return the example(s) at ``index``.

        Args:
            index (int or slice): forwarded as is to every constituent
                dataset.

        Returns:
            A dict mapping each key to the corresponding constituent's
            example when ``index`` is not a slice; a list of such dicts when
            it is a slice.
        """
        batches = {
            key: dataset[index]
            for key, dataset in self._datasets.items()
        }
        if not isinstance(index, slice):
            return batches

        # Regroup the sliced results example by example, using the first
        # constituent's result as the length.
        length = len(next(iter(batches.values())))
        return [
            {key: batch[i] for key, batch in batches.items()}
            for i in range(length)
        ]

    def __len__(self):
        return self._length
152+
153+
154+
class SliceDataset(Dataset):
    """A view over the examples ``[start, finish)`` of a base dataset."""

    def __init__(self, dataset, start, finish, order=None):
        """A Dataset which is a slice of the base dataset.

        Args:
            dataset (Dataset): the base dataset.
            start (int): the start of the slice.
            finish (int): the end of the slice, not inclusive.
            order (List[int], optional): the order, it is a permutation of the
                valid example ids of the base dataset. If `order` is provided,
                the slice is taken in `order`. Defaults to None.

        Raises:
            ValueError: if the slice is not inside the base dataset, if
                ``start`` is greater than ``finish``, or if ``order`` does not
                have the same length as the base dataset.
        """
        if start < 0 or finish > len(dataset):
            raise ValueError("subset overruns the dataset.")
        if start > finish:
            # A reversed range would yield a negative size, and `len()`
            # rejects a negative `__len__` result; fail early instead.
            raise ValueError(
                "start should not be greater than finish, "
                "got start = {} and finish = {}".format(start, finish))
        self._dataset = dataset
        self._start = start
        self._finish = finish
        self._size = finish - start

        if order is not None and len(order) != len(dataset):
            raise ValueError(
                "order should have the same length as the dataset, "
                "len(order) = {} which does not equal len(dataset) = {}".
                format(len(order), len(dataset)))
        self._order = order

    def __len__(self):
        return self._size

    def __getitem__(self, i):
        """Return the ``i``-th example of the slice.

        Negative indices count from the end of the slice.

        Raises:
            IndexError: if ``i`` is outside ``[-len(self), len(self))``.
        """
        if i >= 0:
            if i >= self._size:
                raise IndexError('dataset index out of range')
            index = self._start + i
        else:
            if i < -self._size:
                raise IndexError('dataset index out of range')
            index = self._finish + i

        # With an explicit order, the slice position is first mapped to an
        # example id of the base dataset.
        if self._order is not None:
            index = self._order[index]
        return self._dataset[index]
194+
195+
196+
class SubsetDataset(Dataset):
    """A view over selected examples of a base dataset."""

    def __init__(self, dataset, indices):
        """Pick the examples of ``dataset`` listed in ``indices``.

        Args:
            dataset (Dataset): the base dataset.
            indices (Iterable[int]): the indices of the examples to pick.
                ``len()`` and indexing are applied to it.

        Raises:
            ValueError: if there are more indices than examples in the base
                dataset.
        """
        self._dataset = dataset
        size = len(indices)
        if size > len(dataset):
            raise ValueError("subset's size larger that dataset's size!")
        self._size = size
        self._indices = indices

    def __getitem__(self, i):
        # Map the position within the subset to a base dataset index.
        return self._dataset[self._indices[i]]

    def __len__(self):
        return self._size
216+
217+
218+
class FilterDataset(Dataset):
    """A view over the examples of a base dataset that pass a predicate."""

    def __init__(self, dataset, filter_fn):
        """Keep only the examples for which ``filter_fn`` is truthy.

        Every example of the base dataset is read once here to evaluate the
        predicate; only the indices of the kept examples are stored.

        Args:
            dataset (Dataset): the base dataset.
            filter_fn (callable): takes an example of the base dataset and
                returns a boolean.
        """
        self._dataset = dataset
        kept = []
        for idx in range(len(dataset)):
            if filter_fn(dataset[idx]):
                kept.append(idx)
        self._indices = kept
        self._size = len(kept)

    def __getitem__(self, i):
        # Map the position within the filtered view to a base dataset index.
        return self._dataset[self._indices[i]]

    def __len__(self):
        return self._size
238+
239+
240+
class ChainDataset(Dataset):
    """Several datasets presented as a single concatenated dataset."""

    def __init__(self, *datasets):
        """A concatenation of several datasets with the same structure.

        Args:
            datasets (Iterable[Dataset]): datasets to concat, in order.
        """
        self._datasets = datasets

    def __len__(self):
        # The total is recomputed from the constituents on every call.
        total = 0
        for part in self._datasets:
            total += len(part)
        return total

    def __getitem__(self, i):
        """Return the ``i``-th example of the concatenation.

        Raises:
            IndexError: if ``i`` is negative or not smaller than the total
                length.
        """
        if i < 0:
            raise IndexError("ChainDataset doesnot support negative indexing.")

        # Walk the constituents, turning the global index into an index
        # local to the constituent that contains it.
        offset = i
        for part in self._datasets:
            part_size = len(part)
            if offset < part_size:
                return part[offset]
            offset -= part_size

        raise IndexError("dataset index out of range")

paddlespeech/t2s/exps/waveflow/train.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
from paddle.io import DataLoader
2020
from paddle.io import DistributedBatchSampler
2121

22-
from paddlespeech.t2s.data import dataset
22+
from paddlespeech.t2s.datasets import dataset
2323
from paddlespeech.t2s.exps.waveflow.config import get_cfg_defaults
2424
from paddlespeech.t2s.exps.waveflow.ljspeech import LJSpeech
2525
from paddlespeech.t2s.exps.waveflow.ljspeech import LJSpeechClipCollector

paddlespeech/t2s/frontend/zh_normalization/num.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -208,7 +208,7 @@ def verbalize_digit(value_string: str, alt_one=False) -> str:
208208
result_symbols = [DIGITS[digit] for digit in value_string]
209209
result = ''.join(result_symbols)
210210
if alt_one:
211-
result.replace("一", "幺")
211+
result = result.replace("一", "幺")
212212
return result
213213

214214

paddlespeech/t2s/models/waveflow.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -33,11 +33,11 @@ def fold(x, n_group):
3333
"""Fold audio or spectrogram's temporal dimension in to groups.
3434
3535
Args:
36-
x(Tensor): The input tensor. shape=(\*, time_steps)
36+
x(Tensor): The input tensor. shape=(*, time_steps)
3737
n_group(int): The size of a group.
3838
3939
Returns:
40-
Tensor: Folded tensor. shape=(\*, time_steps // n_group, group)
40+
Tensor: Folded tensor. shape=(*, time_steps // n_group, group)
4141
"""
4242
spatial_shape = list(x.shape[:-1])
4343
time_steps = paddle.shape(x)[-1]
@@ -98,11 +98,11 @@ def forward(self, x, trim_conv_artifact=False):
9898
trim_conv_artifact(bool, optional, optional): Trim deconvolution artifact at each layer. Defaults to False.
9999
100100
Returns:
101-
Tensor: The upsampled spectrogram. shape=(batch_size, input_channels, time_steps \* upsample_factor)
101+
Tensor: The upsampled spectrogram. shape=(batch_size, input_channels, time_steps * upsample_factor)
102102
103103
Notes:
104104
If trim_conv_artifact is ``True``, the output time steps is less
105-
than ``time_steps \* upsample_factors``.
105+
than ``time_steps * upsample_factors``.
106106
"""
107107
x = paddle.unsqueeze(x, 1) # (B, C, T) -> (B, 1, C, T)
108108
for layer in self:
@@ -641,7 +641,7 @@ def infer(self, mel):
641641
mel(np.ndarray): Mel spectrogram of an utterance(in log-magnitude). shape=(C_mel, T_mel)
642642
643643
Returns:
644-
Tensor: The synthesized audio, where``T <= T_mel \* upsample_factors``. shape=(B, T)
644+
Tensor: The synthesized audio, where``T <= T_mel * upsample_factors``. shape=(B, T)
645645
"""
646646
start = time.time()
647647
condition = self.encoder(mel, trim_conv_artifact=True) # (B, C, T)

0 commit comments

Comments
 (0)