Skip to content

Commit a80316c

Browse files
authored
Merge pull request #130 from TensorSpeech/dev/dataset
Refactor Datasets with Pure TF Option
2 parents 3a4d4ee + 66a18fd commit a80316c

File tree

18 files changed

+353
-336
lines changed

18 files changed

+353
-336
lines changed

examples/conformer/config.yml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@ model_config:
6060

6161
learning_config:
6262
augmentations:
63+
use_tf: True
6364
after:
6465
time_masking:
6566
num_masks: 10
@@ -77,7 +78,7 @@ learning_config:
7778
- /mnt/Miscellanea/Datasets/Speech/LibriSpeech/dev-other/transcripts.tsv
7879
test_paths:
7980
- /mnt/Miscellanea/Datasets/Speech/LibriSpeech/test-clean/transcripts.tsv
80-
tfrecords_dir: /mnt/Miscellanea/Datasets/Speech/LibriSpeech/tfrecords
81+
tfrecords_dir: /mnt/Miscellanea/Datasets/Speech/LibriSpeech/tfrecords-test
8182

8283
optimizer_config:
8384
warmup_steps: 40000

scripts/create_tfrecords.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,22 +12,29 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
import os
1516
import argparse
17+
from tensorflow_asr.configs.config import Config
1618
from tensorflow_asr.utils.utils import preprocess_paths
1719
from tensorflow_asr.datasets.asr_dataset import ASRTFRecordDataset
20+
from tensorflow_asr.featurizers.text_featurizers import SubwordFeaturizer
1821

1922
modes = ["train", "eval", "test"]
2023

2124
parser = argparse.ArgumentParser(prog="TFRecords Creation")
2225

2326
parser.add_argument("--mode", "-m", type=str, default=None, help=f"Mode in {modes}")
2427

28+
parser.add_argument("--config", type=str, default=None, help="The file path of model configuration file")
29+
2530
parser.add_argument("--tfrecords_dir", type=str, default=None, help="Directory to tfrecords")
2631

2732
parser.add_argument("--tfrecords_shards", type=int, default=16, help="Number of tfrecords shards")
2833

2934
parser.add_argument("--shuffle", default=False, action="store_true", help="Shuffle data or not")
3035

36+
parser.add_argument("--subwords", type=str, default=None, help="Path to file that stores generated subwords")
37+
3138
parser.add_argument("transcripts", nargs="+", type=str, default=None, help="Paths to transcript files")
3239

3340
args = parser.parse_args()
@@ -37,8 +44,15 @@
3744
transcripts = preprocess_paths(args.transcripts)
3845
tfrecords_dir = preprocess_paths(args.tfrecords_dir)
3946

47+
config = Config(args.config)

# The text featurizer is built from a subwords file, so fail early with a
# message that tells apart "flag omitted" from "path does not exist".
if not args.subwords:
    raise ValueError("subwords must be set")
if not os.path.exists(args.subwords):
    raise ValueError(f"subwords file does not exist: {args.subwords}")

print("Loading subwords ...")
text_featurizer = SubwordFeaturizer.load_from_file(config.decoder_config, args.subwords)
53+
4054
ASRTFRecordDataset(
4155
data_paths=transcripts, tfrecords_dir=tfrecords_dir,
42-
speech_featurizer=None, text_featurizer=None,
56+
speech_featurizer=None, text_featurizer=text_featurizer,
4357
stage=args.mode, shuffle=args.shuffle, tfrecords_shards=args.tfrecords_shards
4458
).create_tfrecords()

setup.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
requirements = [
2121
"tensorflow-datasets>=3.2.1,<4.0.0",
2222
"tensorflow-addons>=0.10.0",
23+
"tensorflow-io>=0.17.0",
2324
"setuptools>=47.1.1",
2425
"librosa>=0.8.0",
2526
"soundfile>=0.10.3",
@@ -35,7 +36,7 @@
3536

3637
setuptools.setup(
3738
name="TensorFlowASR",
38-
version="0.7.0",
39+
version="0.7.1",
3940
author="Huy Le Nguyen",
4041
author_email="[email protected]",
4142
description="Almost State-of-the-art Automatic Speech Recognition using Tensorflow 2",

tensorflow_asr/augmentations/augments.py

Lines changed: 38 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,12 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
import tensorflow as tf
1516
import nlpaug.flow as naf
1617

1718
from .signal_augment import SignalCropping, SignalLoudness, SignalMask, SignalNoise, \
1819
SignalPitch, SignalShift, SignalSpeed, SignalVtlp
19-
from .spec_augment import FreqMasking, TimeMasking
20+
from .spec_augment import FreqMasking, TimeMasking, TFFreqMasking, TFTimeMasking
2021

2122

2223
AUGMENTATIONS = {
@@ -32,12 +33,34 @@
3233
"vtlp": SignalVtlp
3334
}
3435

36+
TFAUGMENTATIONS = {
37+
"freq_masking": TFFreqMasking,
38+
"time_masking": TFTimeMasking,
39+
}
40+
41+
42+
class TFAugmentationExecutor:
    """Chain of TF augmentations applied one after another.

    Each element of ``augmentations`` must expose an ``augment(tensor)``
    method; the output of one step is the input of the next.
    """

    def __init__(self, augmentations: list):
        self.augmentations = augmentations

    @tf.function
    def augment(self, inputs):
        """Pass ``inputs`` through every augmentation in order and return the result."""
        result = inputs
        for step in self.augmentations:
            result = step.augment(result)
        return result
52+
3553

3654
class Augmentation:
3755
def __init__(self, config: dict = None):
3856
if not config: config = {}
39-
self.before = self.parse(config.get("before", {}))
40-
self.after = self.parse(config.get("after", {}))
57+
self.use_tf = config.pop("use_tf", False)
58+
if self.use_tf:
59+
self.before = self.tf_parse(config.pop("before", {}))
60+
self.after = self.tf_parse(config.pop("after", {}))
61+
else:
62+
self.before = self.parse(config.pop("before", {}))
63+
self.after = self.parse(config.pop("after", {}))
4164

4265
@staticmethod
4366
def parse(config: dict) -> list:
@@ -50,3 +73,15 @@ def parse(config: dict) -> list:
5073
aug = au(**value) if value is not None else au()
5174
augmentations.append(aug)
5275
return naf.Sometimes(augmentations)
76+
77+
@staticmethod
def tf_parse(config: dict) -> list:
    """Instantiate the TF augmentations named in ``config``.

    Args:
        config: mapping of augmentation name to its keyword arguments
            (or None to use the augmentation's defaults)
    Returns:
        a TFAugmentationExecutor wrapping the instances in config order
    Raises:
        KeyError: if a name is not registered in TFAUGMENTATIONS
    """
    instances = []
    for key, params in config.items():
        factory = TFAUGMENTATIONS.get(key)
        if factory is None:
            raise KeyError(f"No tf augmentation named: {key}\n"
                           f"Available tf augmentations: {TFAUGMENTATIONS.keys()}")
        if params is None:
            instances.append(factory())
        else:
            instances.append(factory(**params))
    return TFAugmentationExecutor(instances)

tensorflow_asr/augmentations/spec_augment.py

Lines changed: 64 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,14 +11,19 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14+
1415
""" Augmentation on spectrogram: http://arxiv.org/abs/1904.08779 """
16+
1517
import numpy as np
18+
import tensorflow as tf
1619

1720
from nlpaug.flow import Sequential
1821
from nlpaug.util import Action
1922
from nlpaug.model.spectrogram import Spectrogram
2023
from nlpaug.augmenter.spectrogram import SpectrogramAugmenter
2124

25+
from ..utils.utils import shape_list
26+
2227
# ---------------------------- FREQ MASKING ----------------------------
2328

2429

@@ -75,6 +80,35 @@ def __init__(self,
7580
def substitute(self, data):
7681
return self.flow.augment(data)
7782

83+
84+
class TFFreqMasking:
    """Frequency masking on a spectrogram implemented with TensorFlow ops.

    Args:
        num_masks: number of masks applied per call to ``augment``
        mask_factor: exclusive upper bound of a single mask's width; it is
            passed as ``maxval`` of an int32 uniform draw
    """

    def __init__(self, num_masks: int = 1, mask_factor: float = 27):
        self.num_masks = num_masks
        self.mask_factor = mask_factor

    @tf.function
    def augment(self, spectrogram: tf.Tensor):
        """
        Masking the frequency channels (shape[1])
        Args:
            spectrogram: shape (T, num_feature_bins, V)
        Returns:
            frequency masked spectrogram
        """
        T, F, V = shape_list(spectrogram, out_type=tf.int32)
        for _ in range(self.num_masks):
            f = tf.random.uniform([], minval=0, maxval=self.mask_factor, dtype=tf.int32)
            f = tf.minimum(f, F)
            # maxval is exclusive and an integer draw needs maxval > minval.
            # With "F - f" the draw is invalid once f has been clipped to F,
            # and the last valid offset (F - f) could never be selected.
            f0 = tf.random.uniform([], minval=0, maxval=(F - f + 1), dtype=tf.int32)
            # ones | zeros | ones along the frequency axis; widths sum to F
            mask = tf.concat([
                tf.ones([T, f0, V], dtype=spectrogram.dtype),
                tf.zeros([T, f, V], dtype=spectrogram.dtype),
                tf.ones([T, F - f0 - f, V], dtype=spectrogram.dtype)
            ], axis=1)
            spectrogram = spectrogram * mask
        return spectrogram
110+
111+
78112
# ---------------------------- TIME MASKING ----------------------------
79113

80114

@@ -101,9 +135,8 @@ def mask(self, data: np.ndarray) -> np.ndarray:
101135
"""
102136
spectrogram = data.copy()
103137
time = np.random.randint(0, self.mask_factor + 1)
104-
time = min(time, spectrogram.shape[0])
105-
time0 = np.random.randint(0, spectrogram.shape[0] - time + 1)
106138
time = min(time, int(self.p_upperbound * spectrogram.shape[0]))
139+
time0 = np.random.randint(0, spectrogram.shape[0] - time + 1)
107140
spectrogram[time0:time0 + time, :, :] = 0
108141
return spectrogram
109142

@@ -139,3 +172,32 @@ def __init__(self,
139172

140173
def substitute(self, data):
141174
return self.flow.augment(data)
175+
176+
177+
class TFTimeMasking:
    """Time masking on a spectrogram implemented with TensorFlow ops.

    Args:
        num_masks: number of masks applied per call to ``augment``
        mask_factor: exclusive upper bound of a single mask's length; it is
            passed as ``maxval`` of an int32 uniform draw
        p_upperbound: a mask is additionally limited to this fraction of the
            number of time steps of the input
    """

    def __init__(self, num_masks: int = 1, mask_factor: float = 100, p_upperbound: float = 1.0):
        self.num_masks = num_masks
        self.mask_factor = mask_factor
        self.p_upperbound = p_upperbound

    @tf.function
    def augment(self, spectrogram: tf.Tensor):
        """
        Masking the time channel (shape[0])
        Args:
            spectrogram: shape (T, num_feature_bins, V)
        Returns:
            time masked spectrogram
        """
        T, F, V = shape_list(spectrogram, out_type=tf.int32)
        for _ in range(self.num_masks):
            t = tf.random.uniform([], minval=0, maxval=self.mask_factor, dtype=tf.int32)
            t = tf.minimum(t, tf.cast(tf.cast(T, dtype=tf.float32) * self.p_upperbound, dtype=tf.int32))
            # maxval is exclusive and an integer draw needs maxval > minval.
            # With "T - t" the draw is invalid whenever the mask spans the
            # whole input (t == T, e.g. a short utterance with
            # p_upperbound = 1.0), and the last valid offset was unreachable.
            t0 = tf.random.uniform([], minval=0, maxval=(T - t + 1), dtype=tf.int32)
            # ones | zeros | ones along the time axis; lengths sum to T
            mask = tf.concat([
                tf.ones([t0, F, V], dtype=spectrogram.dtype),
                tf.zeros([t, F, V], dtype=spectrogram.dtype),
                tf.ones([T - t0 - t, F, V], dtype=spectrogram.dtype)
            ], axis=0)
            spectrogram = spectrogram * mask
        return spectrogram

tensorflow_asr/configs/config.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ def __init__(self, config: dict = None):
4040
self.eval_paths = preprocess_paths(config.pop("eval_paths", None))
4141
self.test_paths = preprocess_paths(config.pop("test_paths", None))
4242
self.tfrecords_dir = preprocess_paths(config.pop("tfrecords_dir", None))
43+
self.use_tf = config.pop("use_tf", False)
4344
for k, v in config.items(): setattr(self, k, v)
4445

4546

tensorflow_asr/datasets/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,5 +13,5 @@
1313
# limitations under the License.
1414

1515
from .base_dataset import BaseDataset
16-
from .asr_dataset import ASRTFRecordDataset, ASRSliceDataset, ASRTFRecordTestDataset, ASRSliceTestDataset
17-
__all__ = ['BaseDataset', 'ASRTFRecordDataset', 'ASRSliceDataset', 'ASRTFRecordTestDataset', 'ASRSliceTestDataset']
16+
from .asr_dataset import ASRDataset, ASRTFRecordDataset, ASRSliceDataset
17+
__all__ = ['BaseDataset', 'ASRDataset', 'ASRTFRecordDataset', 'ASRSliceDataset']

0 commit comments

Comments
 (0)