Commit 66a18fd

🚀 finish adding option to choose numpy fn or pure tf in dataset

1 parent 9d5f526 commit 66a18fd

10 files changed: +85 -68 lines changed

examples/conformer/train_keras_subword_conformer.py

Lines changed: 3 additions & 3 deletions

@@ -59,7 +59,7 @@
 strategy = setup_strategy(args.devices)
 
 from tensorflow_asr.configs.config import Config
-from tensorflow_asr.datasets.keras import ASRTFRecordDatasetKeras, ASRDatasetKeras
+from tensorflow_asr.datasets.keras import ASRTFRecordDatasetKeras, ASRSliceDatasetKeras
 from tensorflow_asr.featurizers.speech_featurizers import TFSpeechFeaturizer
 from tensorflow_asr.featurizers.text_featurizers import SubwordFeaturizer, SentencePieceFeaturizer
 from tensorflow_asr.models.keras.conformer import Conformer
@@ -103,15 +103,15 @@
         shuffle=True, buffer_size=args.bfs,
     )
 else:
-    train_dataset = ASRDatasetKeras(
+    train_dataset = ASRSliceDatasetKeras(
         data_paths=config.learning_config.dataset_config.train_paths,
         speech_featurizer=speech_featurizer,
         text_featurizer=text_featurizer,
         augmentations=config.learning_config.augmentations,
         stage="train", cache=args.cache,
         shuffle=True, buffer_size=args.bfs,
     )
-    eval_dataset = ASRDatasetKeras(
+    eval_dataset = ASRSliceDatasetKeras(
         data_paths=config.learning_config.dataset_config.eval_paths,
         speech_featurizer=speech_featurizer,
         text_featurizer=text_featurizer,

examples/conformer/train_subword_conformer.py

Lines changed: 3 additions & 3 deletions

@@ -59,7 +59,7 @@
 strategy = setup_strategy(args.devices)
 
 from tensorflow_asr.configs.config import Config
-from tensorflow_asr.datasets.asr_dataset import TFASRTFRecordDataset, ASRSliceDataset
+from tensorflow_asr.datasets.asr_dataset import ASRTFRecordDataset, ASRSliceDataset
 from tensorflow_asr.featurizers.speech_featurizers import TFSpeechFeaturizer
 from tensorflow_asr.featurizers.text_featurizers import SubwordFeaturizer, SentencePieceFeaturizer
 from tensorflow_asr.runners.transducer_runners import TransducerTrainer
@@ -84,7 +84,7 @@
     text_featurizer.save_to_file(args.subwords)
 
 if args.tfrecords:
-    train_dataset = TFASRTFRecordDataset(
+    train_dataset = ASRTFRecordDataset(
         data_paths=config.learning_config.dataset_config.train_paths,
         tfrecords_dir=config.learning_config.dataset_config.tfrecords_dir,
         speech_featurizer=speech_featurizer,
@@ -94,7 +94,7 @@
         stage="train", cache=args.cache,
         shuffle=True, buffer_size=args.bfs,
     )
-    eval_dataset = TFASRTFRecordDataset(
+    eval_dataset = ASRTFRecordDataset(
         data_paths=config.learning_config.dataset_config.eval_paths,
         tfrecords_dir=config.learning_config.dataset_config.tfrecords_dir,
         tfrecords_shards=args.tfrecords_shards,

setup.py

Lines changed: 1 addition & 1 deletion

@@ -36,7 +36,7 @@
 
 setuptools.setup(
     name="TensorFlowASR",
-    version="0.7.0",
+    version="0.7.1",
     author="Huy Le Nguyen",
     author_email="[email protected]",
     description="Almost State-of-the-art Automatic Speech Recognition using Tensorflow 2",

tensorflow_asr/augmentations/augments.py

Lines changed: 6 additions & 6 deletions

@@ -54,13 +54,13 @@ def augment(self, inputs):
 class Augmentation:
     def __init__(self, config: dict = None):
         if not config: config = {}
-        use_tf = config.get("use_tf", False)
-        if use_tf:
-            self.before = self.tf_parse(config.get("before", {}))
-            self.after = self.tf_parse(config.get("after", {}))
+        self.use_tf = config.pop("use_tf", False)
+        if self.use_tf:
+            self.before = self.tf_parse(config.pop("before", {}))
+            self.after = self.tf_parse(config.pop("after", {}))
         else:
-            self.before = self.parse(config.get("before", {}))
-            self.after = self.parse(config.get("after", {}))
+            self.before = self.parse(config.pop("before", {}))
+            self.after = self.parse(config.pop("after", {}))
 
     @staticmethod
     def parse(config: dict) -> list:
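
Using pop instead of get means the flag is consumed before "before"/"after" are parsed, and it is now stored on the instance so datasets can read it back later (see base_dataset.py below). A minimal, self-contained sketch of the new selection logic; the parser bodies here are stand-ins, not the real augmentation parsers:

class AugmentationSketch:
    def __init__(self, config: dict = None):
        if not config: config = {}
        # pop() consumes the flag so it is never treated as an augmentation entry
        self.use_tf = config.pop("use_tf", False)
        parse = self.tf_parse if self.use_tf else self.parse
        self.before = parse(config.pop("before", {}))
        self.after = parse(config.pop("after", {}))

    @staticmethod
    def parse(config: dict) -> list:
        return [("numpy", name) for name in config]  # stand-in body

    @staticmethod
    def tf_parse(config: dict) -> list:
        return [("tf", name) for name in config]  # stand-in body

aug = AugmentationSketch({"use_tf": True, "before": {"time_masking": {}}})
print(aug.use_tf, aug.before)  # True [('tf', 'time_masking')]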

tensorflow_asr/configs/config.py

Lines changed: 1 addition & 0 deletions

@@ -40,6 +40,7 @@ def __init__(self, config: dict = None):
         self.eval_paths = preprocess_paths(config.pop("eval_paths", None))
         self.test_paths = preprocess_paths(config.pop("test_paths", None))
         self.tfrecords_dir = preprocess_paths(config.pop("tfrecords_dir", None))
+        self.use_tf = config.pop("use_tf", False)
         for k, v in config.items(): setattr(self, k, v)
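
Assuming this __init__ belongs to the repo's DatasetConfig class (the class name is not shown in this hunk), the flag can now come straight from the dataset_config section of an experiment config. A hedged sketch with placeholder paths:

from tensorflow_asr.configs.config import DatasetConfig

dataset_config = DatasetConfig({
    "train_paths": ["/data/train/transcripts.tsv"],  # placeholder
    "eval_paths": ["/data/eval/transcripts.tsv"],    # placeholder
    "test_paths": None,
    "tfrecords_dir": None,
    "use_tf": True,  # request the pure-tf pipeline instead of the numpy callback
})
print(dataset_config.use_tf)  # True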

tensorflow_asr/datasets/asr_dataset.py

Lines changed: 39 additions & 39 deletions

@@ -28,21 +28,6 @@
 TFRECORD_SHARDS = 16
 
 
-def write_tfrecord_file(splitted_entries):
-    shard_path, entries = splitted_entries
-    with tf.io.TFRecordWriter(shard_path, options='ZLIB') as out:
-        for path, audio, indices in entries:
-            feature = {
-                "path": bytestring_feature([bytes(path, "utf-8")]),
-                "audio": bytestring_feature([audio]),
-                "indices": bytestring_feature([bytes(indices, "utf-8")])
-            }
-            example = tf.train.Example(features=tf.train.Features(feature=feature))
-            out.write(example.SerializeToString())
-            print_one_line("Processed:", path)
-    print(f"\nCreated {shard_path}")
-
-
 class ASRDataset(BaseDataset):
     """ Dataset for ASR using Generator """
 
@@ -54,40 +39,39 @@ def __init__(self,
                  augmentations: Augmentation = Augmentation(None),
                  cache: bool = False,
                  shuffle: bool = False,
-                 use_tf: bool = False,
                  drop_remainder: bool = True,
                  buffer_size: int = BUFFER_SIZE):
         super(ASRDataset, self).__init__(
             data_paths=data_paths, augmentations=augmentations,
             cache=cache, shuffle=shuffle, stage=stage, buffer_size=buffer_size,
-            use_tf=use_tf, drop_remainder=drop_remainder
+            drop_remainder=drop_remainder
         )
         self.speech_featurizer = speech_featurizer
         self.text_featurizer = text_featurizer
 
     def read_entries(self):
-        self.lines = []
+        self.entries = []
         for file_path in self.data_paths:
             print(f"Reading {file_path} ...")
             with tf.io.gfile.GFile(file_path, "r") as f:
                 temp_lines = f.read().splitlines()
                 # Skip the header of tsv file
-                self.lines += temp_lines[1:]
+                self.entries += temp_lines[1:]
         # The files is "\t" seperated
-        self.lines = [line.split("\t", 2) for line in self.lines]
-        self.lines = np.array(self.lines)
-        for i, line in enumerate(self.lines):
-            self.lines[i][-1] = " ".join([str(x) for x in self.text_featurizer.extract(line[-1]).numpy()])
-        if self.shuffle: np.random.shuffle(self.lines)  # Mix transcripts.tsv
-        self.total_steps = len(self.lines)
+        self.entries = [line.split("\t", 2) for line in self.entries]
+        for i, line in enumerate(self.entries):
+            self.entries[i][-1] = " ".join([str(x) for x in self.text_featurizer.extract(line[-1]).numpy()])
+        self.entries = np.array(self.entries)
+        if self.shuffle: np.random.shuffle(self.entries)  # Mix transcripts.tsv
+        self.total_steps = len(self.entries)
 
     def generator(self):
-        for path, _, indices in self.lines:
-            audio = load_and_convert_to_wav(path)
-            yield path, audio, indices
+        for path, _, indices in self.entries:
+            audio = load_and_convert_to_wav(path).numpy()
+            yield bytes(path, "utf-8"), audio, bytes(indices, "utf-8")
 
-    def preprocess(self, path, audio, indices):
-        def fn(_path, _audio, _indices):
+    def preprocess(self, path: tf.Tensor, audio: tf.Tensor, indices: tf.Tensor):
+        def fn(_path: bytes, _audio: bytes, _indices: bytes):
             with tf.device("/CPU:0"):
                 signal = read_raw_audio(_audio, self.speech_featurizer.sample_rate)
@@ -111,7 +95,7 @@ def fn(_path, _audio, _indices):
             Tout=[tf.string, tf.float32, tf.int32, tf.int32, tf.int32, tf.int32, tf.int32]
         )
 
-    def tf_preprocess(self, path, audio, indices):
+    def tf_preprocess(self, path: tf.Tensor, audio: tf.Tensor, indices: tf.Tensor):
         with tf.device("/CPU:0"):
             signal = tf_read_raw_audio(audio, self.speech_featurizer.sample_rate)
 
@@ -130,7 +114,7 @@ def tf_preprocess(self, path, audio, indices):
 
         return path, features, input_length, label, label_length, prediction, prediction_length
 
-    def process(self, dataset, batch_size):
+    def process(self, dataset: tf.data.Dataset, batch_size: int):
         dataset = dataset.map(self.parse, num_parallel_calls=AUTOTUNE)
 
         if self.cache:
@@ -193,18 +177,34 @@ def __init__(self,
                  tfrecords_shards: int = TFRECORD_SHARDS,
                  cache: bool = False,
                  shuffle: bool = False,
-                 use_tf: bool = False,
+                 drop_remainder: bool = True,
                  buffer_size: int = BUFFER_SIZE):
         super(ASRTFRecordDataset, self).__init__(
             stage=stage, speech_featurizer=speech_featurizer, text_featurizer=text_featurizer,
            data_paths=data_paths, augmentations=augmentations, cache=cache, shuffle=shuffle, buffer_size=buffer_size,
-            use_tf=use_tf
+            drop_remainder=drop_remainder
         )
         self.tfrecords_dir = tfrecords_dir
         if tfrecords_shards <= 0: raise ValueError("tfrecords_shards must be positive")
         self.tfrecords_shards = tfrecords_shards
         if not tf.io.gfile.exists(self.tfrecords_dir): tf.io.gfile.makedirs(self.tfrecords_dir)
 
+    @staticmethod
+    def write_tfrecord_file(splitted_entries):
+        shard_path, entries = splitted_entries
+        with tf.io.TFRecordWriter(shard_path, options='ZLIB') as out:
+            for path, _, indices in entries:
+                audio = load_and_convert_to_wav(path).numpy()
+                feature = {
+                    "path": bytestring_feature([bytes(path, "utf-8")]),
+                    "audio": bytestring_feature([audio]),
+                    "indices": bytestring_feature([bytes(indices, "utf-8")])
+                }
+                example = tf.train.Example(features=tf.train.Features(feature=feature))
+                out.write(example.SerializeToString())
+                print_one_line("Processed:", path)
+        print(f"\nCreated {shard_path}")
+
     def create_tfrecords(self):
         if not tf.io.gfile.exists(self.tfrecords_dir):
             tf.io.gfile.makedirs(self.tfrecords_dir)
@@ -217,16 +217,15 @@ def create_tfrecords(self):
 
         self.read_entries()
         if not self.total_steps or self.total_steps == 0: return False
-        entries = np.fromiter(self.generator(), dtype=str)
 
         def get_shard_path(shard_id):
            return os.path.join(self.tfrecords_dir, f"{self.stage}_{shard_id}.tfrecord")
 
         shards = [get_shard_path(idx) for idx in range(1, self.tfrecords_shards + 1)]
 
-        splitted_entries = np.array_split(entries, self.tfrecords_shards)
+        splitted_entries = np.array_split(self.entries, self.tfrecords_shards)
         with multiprocessing.Pool(self.tfrecords_shards) as pool:
-            pool.map(write_tfrecord_file, zip(shards, splitted_entries))
+            pool.map(self.write_tfrecord_file, zip(shards, splitted_entries))
 
         return True
 
@@ -260,12 +259,13 @@ class ASRSliceDataset(ASRDataset):
 
     @staticmethod
     def load(record: tf.Tensor):
-        audio = load_and_convert_to_wav(record[0])
+        def fn(path: bytes): return load_and_convert_to_wav(path.decode("utf-8")).numpy()
+        audio = tf.numpy_function(fn, inp=[record[0]], Tout=tf.string)
         return record[0], audio, record[2]
 
     def create(self, batch_size: int):
         self.read_entries()
         if not self.total_steps or self.total_steps == 0: return None
-        dataset = tf.data.Dataset.from_tensor_slices(self.lines)
+        dataset = tf.data.Dataset.from_tensor_slices(self.entries)
         dataset = dataset.map(self.load, num_parallel_calls=AUTOTUNE)
         return self.process(dataset, batch_size)
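
The hunks above show both halves of the "numpy fn or pure tf" option: preprocess funnels a numpy-backed inner fn through a callback with an explicit Tout signature, while tf_preprocess stays entirely in graph ops via tf_read_raw_audio; which one parse picks presumably follows self.use_tf. The new ASRSliceDataset.load uses the same tf.numpy_function trick just for audio loading. A stripped-down, runnable sketch of that pattern, where read_wav_bytes is a hypothetical stand-in for load_and_convert_to_wav:

import tensorflow as tf

def read_wav_bytes(path: bytes) -> bytes:
    # stand-in for load_and_convert_to_wav(...).numpy()
    with open(path.decode("utf-8"), "rb") as f:
        return f.read()

def load(record: tf.Tensor):
    # record is one row of self.entries: [path, duration, transcript indices]
    audio = tf.numpy_function(read_wav_bytes, inp=[record[0]], Tout=tf.string)
    return record[0], audio, record[2]

entries = [["a.wav", "1.0", "1 2 3"], ["b.wav", "2.0", "4 5"]]  # hypothetical rows
dataset = tf.data.Dataset.from_tensor_slices(entries)
dataset = dataset.map(load, num_parallel_calls=tf.data.experimental.AUTOTUNE)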

tensorflow_asr/datasets/base_dataset.py

Lines changed: 1 addition & 2 deletions

@@ -27,7 +27,6 @@ def __init__(self,
                  cache: bool = False,
                  shuffle: bool = False,
                  buffer_size: int = BUFFER_SIZE,
-                 use_tf: bool = False,
                  drop_remainder: bool = True,
                  stage: str = "train"):
         self.data_paths = data_paths
@@ -38,7 +37,7 @@ def __init__(self,
             raise ValueError("buffer_size must be positive when shuffle is on")
         self.buffer_size = buffer_size  # shuffle buffer size
         self.stage = stage  # for defining tfrecords files
-        self.use_tf = use_tf  # whether to use only pure tf in the dataset pipeline
+        self.use_tf = self.augmentations.use_tf
         self.drop_remainder = drop_remainder  # whether to drop remainder for multi gpu training
         self.total_steps = None  # for better training visualization
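
The dataset no longer takes use_tf directly; it inherits the mode from its Augmentation, so pipeline and augmentations can't disagree. Assuming the package is importable, the propagation looks like:

from tensorflow_asr.augmentations.augments import Augmentation

augs = Augmentation({"use_tf": True})
print(augs.use_tf)  # True
# Any BaseDataset constructed with augmentations=augs now picks this up:
# self.use_tf = self.augmentations.use_tf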

tensorflow_asr/datasets/keras/asr_dataset.py

Lines changed: 18 additions & 12 deletions

@@ -14,7 +14,7 @@
 
 import tensorflow as tf
 
-from ..asr_dataset import ASRDataset, ASRTFRecordDataset, ASRSliceDataset, AUTOTUNE
+from ..asr_dataset import ASRDataset, ASRTFRecordDataset, ASRSliceDataset, AUTOTUNE, TFRECORD_SHARDS
 from ..base_dataset import BUFFER_SIZE
 from ...featurizers.speech_featurizers import SpeechFeaturizer
 from ...featurizers.text_featurizers import TextFeaturizer
@@ -101,30 +101,37 @@ class ASRTFRecordDatasetKeras(ASRDatasetKeras, ASRTFRecordDataset):
     """ Keras Dataset for ASR using TFRecords """
 
     def __init__(self,
-                 stage: str,
+                 data_paths: list,
+                 tfrecords_dir: str,
                  speech_featurizer: SpeechFeaturizer,
                  text_featurizer: TextFeaturizer,
-                 data_paths: list,
+                 stage: str,
                  augmentations: Augmentation = Augmentation(None),
+                 tfrecords_shards: int = TFRECORD_SHARDS,
                  cache: bool = False,
                  shuffle: bool = False,
-                 use_tf: bool = False,
                  drop_remainder: bool = True,
                  buffer_size: int = BUFFER_SIZE):
         ASRTFRecordDataset.__init__(
             self, stage=stage, speech_featurizer=speech_featurizer, text_featurizer=text_featurizer,
-            data_paths=data_paths, augmentations=augmentations, cache=cache, shuffle=shuffle, use_tf=use_tf,
-            drop_remainder=drop_remainder, buffer_size=buffer_size
+            data_paths=data_paths, tfrecords_dir=tfrecords_dir, augmentations=augmentations, cache=cache, shuffle=shuffle,
+            tfrecords_shards=tfrecords_shards, drop_remainder=drop_remainder, buffer_size=buffer_size
         )
         ASRDatasetKeras.__init__(
             self, stage=stage, speech_featurizer=speech_featurizer, text_featurizer=text_featurizer,
-            data_paths=data_paths, augmentations=augmentations, cache=cache, shuffle=shuffle, use_tf=use_tf,
+            data_paths=data_paths, augmentations=augmentations, cache=cache, shuffle=shuffle,
             drop_remainder=drop_remainder, buffer_size=buffer_size
         )
 
     @tf.function
-    def parse(self, path: tf.Tensor, audio: tf.Tensor, indices: tf.Tensor):
-        return ASRDatasetKeras.parse(self, path, audio, indices)
+    def parse(self, record: tf.Tensor):
+        feature_description = {
+            "path": tf.io.FixedLenFeature([], tf.string),
+            "audio": tf.io.FixedLenFeature([], tf.string),
+            "indices": tf.io.FixedLenFeature([], tf.string)
+        }
+        example = tf.io.parse_single_example(record, feature_description)
+        return ASRDatasetKeras.parse(self, **example)
 
     def process(self, dataset: tf.data.Dataset, batch_size: int):
         return ASRDatasetKeras.process(self, dataset, batch_size)
@@ -141,17 +148,16 @@ def __init__(self,
                  augmentations: Augmentation = Augmentation(None),
                  cache: bool = False,
                  shuffle: bool = False,
-                 use_tf: bool = False,
                  drop_remainder: bool = True,
                  buffer_size: int = BUFFER_SIZE):
         ASRSliceDataset.__init__(
             self, stage=stage, speech_featurizer=speech_featurizer, text_featurizer=text_featurizer,
-            data_paths=data_paths, augmentations=augmentations, cache=cache, shuffle=shuffle, use_tf=use_tf,
+            data_paths=data_paths, augmentations=augmentations, cache=cache, shuffle=shuffle,
            drop_remainder=drop_remainder, buffer_size=buffer_size
        )
         ASRDatasetKeras.__init__(
             self, stage=stage, speech_featurizer=speech_featurizer, text_featurizer=text_featurizer,
-            data_paths=data_paths, augmentations=augmentations, cache=cache, shuffle=shuffle, use_tf=use_tf,
+            data_paths=data_paths, augmentations=augmentations, cache=cache, shuffle=shuffle,
             drop_remainder=drop_remainder, buffer_size=buffer_size
         )
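
With deserialization folded into parse, the TFRecord pipeline can map parse directly over raw serialized records. A self-contained round-trip sketch of the three bytes features involved (the values are hypothetical):

import tensorflow as tf

def bytestring_feature(values):
    # mirrors the helper used by write_tfrecord_file
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=values))

# Build a serialized example the way write_tfrecord_file does...
example = tf.train.Example(features=tf.train.Features(feature={
    "path": bytestring_feature([b"a.wav"]),
    "audio": bytestring_feature([b"RIFF...wav-bytes"]),
    "indices": bytestring_feature([b"1 2 3"]),
}))
serialized = example.SerializeToString()

# ...and parse it back the way the new ASRTFRecordDatasetKeras.parse does
feature_description = {
    "path": tf.io.FixedLenFeature([], tf.string),
    "audio": tf.io.FixedLenFeature([], tf.string),
    "indices": tf.io.FixedLenFeature([], tf.string),
}
parsed = tf.io.parse_single_example(serialized, feature_description)
print(parsed["path"].numpy())  # b'a.wav'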

tensorflow_asr/losses/rnnt_losses.py

Lines changed: 6 additions & 2 deletions

@@ -14,12 +14,16 @@
 # RNNT loss implementation in pure TensorFlow is borrowed from [iamjanvijay's repo](https://github.com/iamjanvijay/rnnt)
 
 import tensorflow as tf
+
+from ..utils.utils import has_gpu_or_tpu
+
+use_cpu = not has_gpu_or_tpu()
+
 try:
     from warprnnt_tensorflow import rnnt_loss as warp_rnnt_loss
     use_warprnnt = True
 except ImportError:
     print("Cannot import RNNT loss in warprnnt. Falls back to RNNT in TensorFlow")
-    print("Note: The RNNT in Tensorflow is not supported for CPU yet")
     from tensorflow.python.ops.gen_array_ops import matrix_diag_part_v2
     use_warprnnt = False
 
@@ -208,7 +212,7 @@ def compute_rnnt_loss_and_grad_helper(logits, labels, label_length, logit_length
     a = tf.tile(tf.reshape(tf.range(target_max_len - 1, dtype=tf.int64), shape=(1, 1, target_max_len - 1, 1)),
                 multiples=[batch_size, 1, 1, 1])
     b = tf.cast(tf.reshape(labels - 1, shape=(batch_size, 1, target_max_len - 1, 1)), dtype=tf.int64)
-    # b = tf.where(tf.equal(b, -1), tf.zeros_like(b), b)  # for cpu testing (index -1 on cpu will raise errors)
+    if use_cpu: b = tf.where(tf.equal(b, -1), tf.zeros_like(b), b)  # for cpu testing (index -1 on cpu will raise errors)
     c = tf.concat([a, b], axis=3)
     d = tf.tile(c, multiples=(1, input_max_len, 1, 1))
     e = tf.tile(tf.reshape(tf.range(input_max_len, dtype=tf.int64), shape=(1, input_max_len, 1, 1)),
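
Uncommenting this guard matters because TensorFlow's gather kernels differ by device: on GPU an out-of-bounds index silently produces zeros, while on CPU it raises an error, and labels - 1 yields -1 wherever a label is the blank (0). A small sketch of the clamp:

import tensorflow as tf

use_cpu = True  # i.e. has_gpu_or_tpu() returned False

b = tf.constant([[0, 3, 7]], dtype=tf.int64) - 1  # blank label produces -1
if use_cpu:
    b = tf.where(tf.equal(b, -1), tf.zeros_like(b), b)
print(b.numpy().tolist())  # [[0, 2, 6]]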

tensorflow_asr/utils/utils.py

Lines changed: 7 additions & 0 deletions

@@ -160,3 +160,10 @@ def get_reduced_length(length, reduction_factor):
 
 def count_non_blank(tensor: tf.Tensor, blank: int or tf.Tensor = 0, axis=None):
     return tf.reduce_sum(tf.where(tf.not_equal(tensor, blank), x=tf.ones_like(tensor), y=tf.zeros_like(tensor)), axis=axis)
+
+
+def has_gpu_or_tpu():
+    gpus = tf.config.list_logical_devices("GPU")
+    tpus = tf.config.list_logical_devices("TPU")
+    if len(gpus) == 0 and len(tpus) == 0: return False
+    return True
