Commit 6d70eab

Merge pull request #82 from TensorSpeech/dev/testing
Update batch for faster testing
2 parents 46edde8 + 288584a commit 6d70eab

15 files changed: +250 −139 lines changed


examples/conformer/masking/masking.py

Lines changed: 3 additions & 3 deletions
@@ -1,5 +1,5 @@
 import tensorflow as tf
-from tensorflow_asr.utils.utils import shape_list
+from tensorflow_asr.utils.utils import shape_list, get_reduced_length
 
 
 def create_padding_mask(features, input_length, time_reduction_factor):
@@ -14,10 +14,10 @@ def create_padding_mask(features, input_length, time_reduction_factor):
         [tf.Tensor]: with shape [B, Tquery, Tkey]
     """
     batch_size, padded_time, _, _ = shape_list(features)
-    reduced_padded_time = tf.math.ceil(padded_time / time_reduction_factor)
+    reduced_padded_time = get_reduced_length(padded_time, time_reduction_factor)
 
     def create_mask(length):
-        reduced_length = tf.math.ceil(length / time_reduction_factor)
+        reduced_length = get_reduced_length(length, time_reduction_factor)
         mask = tf.ones([reduced_length, reduced_length], dtype=tf.float32)
         return tf.pad(
             mask,
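
Both replaced call sites above compute the same thing: the feature length after the encoder's time reduction. Below is a minimal sketch of what the shared `get_reduced_length` helper presumably does, assuming it simply folds the ceil-division and the `tf.int32` cast (done by hand in trainer.py before this commit) into one utility; the actual implementation in `tensorflow_asr/utils/utils.py` may differ.

```python
import tensorflow as tf

# Hypothetical sketch of get_reduced_length, not the repo's exact code.
def get_reduced_length(length, reduction_factor):
    # Ceil-divide the frame count by the time reduction factor and
    # return an int32 tensor so callers no longer need an explicit cast.
    length = tf.cast(length, tf.float32)
    factor = tf.cast(reduction_factor, tf.float32)
    return tf.cast(tf.math.ceil(length / factor), tf.int32)
```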

examples/conformer/masking/trainer.py

Lines changed: 3 additions & 2 deletions
@@ -3,6 +3,7 @@
 from masking import create_padding_mask
 from tensorflow_asr.runners.transducer_runners import TransducerTrainer, TransducerTrainerGA
 from tensorflow_asr.losses.rnnt_losses import rnnt_loss
+from tensorflow_asr.utils.utils import get_reduced_length
 
 
 class TrainerWithMasking(TransducerTrainer):
@@ -17,7 +18,7 @@ def _train_step(self, batch):
             tape.watch(logits)
             per_train_loss = rnnt_loss(
                 logits=logits, labels=labels, label_length=label_length,
-                logit_length=tf.cast(tf.math.ceil(input_length / self.model.time_reduction_factor), dtype=tf.int32),
+                logit_length=get_reduced_length(input_length, self.model.time_reduction_factor),
                 blank=self.text_featurizer.blank
             )
             train_loss = tf.nn.compute_average_loss(per_train_loss,
@@ -41,7 +42,7 @@ def _train_step(self, batch):
             tape.watch(logits)
             per_train_loss = rnnt_loss(
                 logits=logits, labels=labels, label_length=label_length,
-                logit_length=tf.cast(tf.math.ceil(input_length / self.model.time_reduction_factor), dtype=tf.int32),
+                logit_length=get_reduced_length(input_length, self.model.time_reduction_factor),
                 blank=self.text_featurizer.blank
             )
             train_loss = tf.nn.compute_average_loss(

setup.py

Lines changed: 1 addition & 1 deletion
@@ -37,7 +37,7 @@
 
 setuptools.setup(
     name="TensorFlowASR",
-    version="0.5.0",
+    version="0.5.1",
     author="Huy Le Nguyen",
     author_email="[email protected]",
     description="Almost State-of-the-art Automatic Speech Recognition using Tensorflow 2",

tensorflow_asr/datasets/README.md

Lines changed: 1 addition & 1 deletion
@@ -53,5 +53,5 @@ Where `prediction` and `prediction_length` are the label prepanded by blank and
 **Outputs when iterating in test step**
 
 ```python
-(path, signals, labels)
+(path, features, input_lengths, labels)
 ```
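
So the test pipeline now yields pre-extracted, batch-padded features together with their true frame counts rather than raw signals. A hedged sketch of what one batch looks like when iterating; `test_dataset` is a placeholder for a dataset built from `ASRTFRecordTestDataset` or `ASRSliceTestDataset`:

```python
# test_dataset is assumed to be created via the repo's test dataset classes.
for path, features, input_lengths, labels in test_dataset:
    # path:          [B] string, source audio file per utterance
    # features:      [B, T, n_mels, channels] float32, padded to the longest T in the batch
    # input_lengths: [B] int32, true (unpadded) frame count per utterance
    # labels:        [B, U] int32, padded with the blank index
    pass
```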

tensorflow_asr/datasets/asr_dataset.py

Lines changed: 24 additions & 8 deletions
@@ -241,8 +241,15 @@ class ASRTFRecordTestDataset(ASRTFRecordDataset):
     def preprocess(self, path, transcript):
         with tf.device("/CPU:0"):
             signal = read_raw_audio(path.decode("utf-8"), self.speech_featurizer.sample_rate)
+
+            features = self.speech_featurizer.extract(signal)
+            features = tf.convert_to_tensor(features, tf.float32)
+            input_length = tf.cast(tf.shape(features)[0], tf.int32)
+
             label = self.text_featurizer.extract(transcript.decode("utf-8"))
-            return path, signal, tf.convert_to_tensor(label, dtype=tf.int32)
+            label = tf.convert_to_tensor(label, dtype=tf.int32)
+
+            return path, features, input_length, label
 
     @tf.function
     def parse(self, record):
@@ -256,7 +263,7 @@ def parse(self, record):
         return tf.numpy_function(
             self.preprocess,
             inp=[example["audio"], example["transcript"]],
-            Tout=(tf.string, tf.float32, tf.int32)
+            Tout=(tf.string, tf.float32, tf.int32, tf.int32)
         )
 
     def process(self, dataset, batch_size):
@@ -273,10 +280,11 @@ def process(self, dataset, batch_size):
             batch_size=batch_size,
             padded_shapes=(
                 tf.TensorShape([]),
-                tf.TensorShape([None]),
+                tf.TensorShape(self.speech_featurizer.shape),
+                tf.TensorShape([]),
                 tf.TensorShape([None]),
             ),
-            padding_values=("", 0.0, self.text_featurizer.blank),
+            padding_values=("", 0.0, 0, self.text_featurizer.blank),
             drop_remainder=True
         )
 
@@ -304,15 +312,22 @@ class ASRSliceTestDataset(ASRDataset):
     def preprocess(self, path, transcript):
         with tf.device("/CPU:0"):
             signal = read_raw_audio(path.decode("utf-8"), self.speech_featurizer.sample_rate)
+
+            features = self.speech_featurizer.extract(signal)
+            features = tf.convert_to_tensor(features, tf.float32)
+            input_length = tf.cast(tf.shape(features)[0], tf.int32)
+
             label = self.text_featurizer.extract(transcript.decode("utf-8"))
-            return path, signal, tf.convert_to_tensor(label, dtype=tf.int32)
+            label = tf.convert_to_tensor(label, dtype=tf.int32)
+
+            return path, features, input_length, label
 
     @tf.function
     def parse(self, record):
         return tf.numpy_function(
             self.preprocess,
             inp=[record[0], record[1]],
-            Tout=[tf.string, tf.float32, tf.int32]
+            Tout=[tf.string, tf.float32, tf.int32, tf.int32]
         )
 
     def process(self, dataset, batch_size):
@@ -329,10 +344,11 @@ def process(self, dataset, batch_size):
             batch_size=batch_size,
             padded_shapes=(
                 tf.TensorShape([]),
-                tf.TensorShape([None]),
+                tf.TensorShape(self.speech_featurizer.shape),
+                tf.TensorShape([]),
                 tf.TensorShape([None]),
             ),
-            padding_values=("", 0.0, self.text_featurizer.blank),
+            padding_values=("", 0.0, 0, self.text_featurizer.blank),
             drop_remainder=True
         )
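
Note that `padding_values` grows to four entries: `padded_batch` needs one padding value per tuple element, so the new scalar `input_length` gets an int `0` slotted between the feature padding (`0.0`) and the label padding (the blank index). A self-contained toy illustration of that per-element contract (the data below is invented purely for demonstration):

```python
import tensorflow as tf

# Toy dataset of (variable-length vector, its length) pairs.
ds = tf.data.Dataset.from_generator(
    lambda: iter([([1, 2, 3], 3), ([4], 1)]),
    output_signature=(
        tf.TensorSpec([None], tf.int32),
        tf.TensorSpec([], tf.int32),
    ),
)

# One padded_shape and one padding_value per element, in tuple order;
# the scalar length needs no padding but still needs its own entry.
batched = ds.padded_batch(
    batch_size=2,
    padded_shapes=(tf.TensorShape([None]), tf.TensorShape([])),
    padding_values=(0, 0),
)

for vectors, lengths in batched:
    print(vectors.numpy(), lengths.numpy())  # [[1 2 3] [4 0 0]] [3 1]
```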

tensorflow_asr/models/__init__.py

Lines changed: 8 additions & 0 deletions
@@ -27,3 +27,11 @@ def _build(self, *args, **kwargs):
     @abc.abstractmethod
     def call(self, inputs, training=False, **kwargs):
         raise NotImplementedError()
+
+    @abc.abstractmethod
+    def recognize(self, features, input_lengths, **kwargs):
+        pass
+
+    @abc.abstractmethod
+    def recognize_beam(self, features, input_lengths, **kwargs):
+        pass
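
With these additions, every concrete model is expected to expose batched `recognize` and `recognize_beam` methods over pre-extracted features. A hypothetical toy subclass, invented here only to show the shape of that contract:

```python
import tensorflow as tf

class ToyModel(tf.keras.Model):  # stand-in for the repo's abstract Model base
    def call(self, inputs, training=False, **kwargs):
        return inputs

    def recognize(self, features, input_lengths, **kwargs):
        # Return one placeholder transcript per utterance in the batch.
        return tf.fill([tf.shape(features)[0]], "")

    def recognize_beam(self, features, input_lengths, **kwargs):
        # A real model would run beam search; reuse the greedy path here.
        return self.recognize(features, input_lengths, **kwargs)
```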

tensorflow_asr/models/contextnet.py

Lines changed: 7 additions & 4 deletions
@@ -13,7 +13,7 @@
 # limitations under the License.
 """ Ref: https://github.com/iankur/ContextNet """
 
-from typing import List
+from typing import List, Optional
 import tensorflow as tf
 from .transducer import Transducer
 from ..utils.utils import merge_two_last_dims, get_reduced_length
@@ -234,8 +234,7 @@ def __init__(self,
         )
         self.dmodel = self.encoder.blocks[-1].dmodel
         self.time_reduction_factor = 1
-        for block in self.encoder.blocks:
-            self.time_reduction_factor *= block.time_reduction_factor
+        for block in self.encoder.blocks: self.time_reduction_factor *= block.time_reduction_factor
 
     def call(self, inputs, training=False, **kwargs):
         features, input_length, prediction, prediction_length = inputs
@@ -244,8 +243,12 @@ def call(self, inputs, training=False, **kwargs):
         outputs = self.joint_net([enc, pred], training=training, **kwargs)
         return outputs
 
-    def encoder_inference(self, features):
+    def encoder_inference(self,
+                          features: tf.Tensor,
+                          input_length: Optional[tf.Tensor] = None,
+                          with_batch: bool = False):
         with tf.name_scope(f"{self.name}_encoder"):
+            if with_batch: return self.encoder([features, input_length], training=False)
             input_length = tf.expand_dims(tf.shape(features)[0], axis=0)
             outputs = tf.expand_dims(features, axis=0)
             outputs = self.encoder([outputs, input_length], training=False)
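
`encoder_inference` keeps its original single-utterance behaviour by default and only takes the batched shortcut when `with_batch=True`. A hedged usage sketch; `contextnet`, `features_single`, `features_batch`, and `lengths` are placeholders for objects built elsewhere:

```python
# Single utterance (original path): features_single has shape [T, n_mels, channels].
enc_single = contextnet.encoder_inference(features_single)

# Whole padded batch (new path): pass [B, T, n_mels, channels] plus true lengths.
enc_batch = contextnet.encoder_inference(features_batch, input_length=lengths, with_batch=True)
```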

tensorflow_asr/models/ctc.py

Lines changed: 12 additions & 21 deletions
@@ -12,13 +12,14 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from typing import Optional
 import numpy as np
 import tensorflow as tf
 
 from . import Model
 from ..featurizers.speech_featurizers import TFSpeechFeaturizer
 from ..featurizers.text_featurizers import TextFeaturizer
-from ..utils.utils import shape_list
+from ..utils.utils import shape_list, get_reduced_length
 
 
 class CtcModel(Model):
@@ -41,20 +42,15 @@ def call(self, inputs, training=False, **kwargs):
     # -------------------------------- GREEDY -------------------------------------
 
     @tf.function
-    def recognize(self, signals):
-
-        def extract_fn(signal): return self.speech_featurizer.tf_extract(signal)
-
-        features = tf.map_fn(extract_fn, signals,
-                             fn_output_signature=tf.TensorSpec(self.speech_featurizer.shape, dtype=tf.float32))
+    def recognize(self, features: tf.Tensor, input_length: Optional[tf.Tensor]):
         logits = self(features, training=False)
         probs = tf.nn.softmax(logits)
 
-        def map_fn(prob): return tf.numpy_function(self.perform_greedy, inp=[prob], Tout=tf.string)
+        def map_fn(prob): return tf.numpy_function(self.__perform_greedy, inp=[prob], Tout=tf.string)
 
         return tf.map_fn(map_fn, probs, fn_output_signature=tf.TensorSpec([], dtype=tf.string))
 
-    def perform_greedy(self, probs: np.ndarray):
+    def __perform_greedy(self, probs: np.ndarray):
         from ctc_decoders import ctc_greedy_decoder
         decoded = ctc_greedy_decoder(probs, vocabulary=self.text_featurizer.vocab_array)
         return tf.convert_to_tensor(decoded, dtype=tf.string)
@@ -71,7 +67,7 @@ def recognize_tflite(self, signal):
         features = self.speech_featurizer.tf_extract(signal)
         features = tf.expand_dims(features, axis=0)
         input_length = shape_list(features)[1]
-        input_length = input_length // self.base_model.time_reduction_factor
+        input_length = get_reduced_length(input_length, self.base_model.time_reduction_factor)
         input_length = tf.expand_dims(input_length, axis=0)
         logits = self(features, training=False)
         probs = tf.nn.softmax(logits)
@@ -85,25 +81,20 @@ def recognize_tflite(self, signal):
     # -------------------------------- BEAM SEARCH -------------------------------------
 
     @tf.function
-    def recognize_beam(self, signals, lm=False):
-
-        def extract_fn(signal): return self.speech_featurizer.tf_extract(signal)
-
-        features = tf.map_fn(extract_fn, signals,
-                             fn_output_signature=tf.TensorSpec(self.speech_featurizer.shape, dtype=tf.float32))
+    def recognize_beam(self, features: tf.Tensor, input_length: Optional[tf.Tensor], lm: bool = False):
         logits = self(features, training=False)
         probs = tf.nn.softmax(logits)
 
-        def map_fn(prob): return tf.numpy_function(self.perform_beam_search, inp=[prob, lm], Tout=tf.string)
+        def map_fn(prob): return tf.numpy_function(self.__perform_beam_search, inp=[prob, lm], Tout=tf.string)
 
         return tf.map_fn(map_fn, probs, dtype=tf.string)
 
-    def perform_beam_search(self, probs: np.ndarray, lm: bool = False):
+    def __perform_beam_search(self, probs: np.ndarray, lm: bool = False):
         from ctc_decoders import ctc_beam_search_decoder
         decoded = ctc_beam_search_decoder(
             probs_seq=probs,
             vocabulary=self.text_featurizer.vocab_array,
-            beam_size=self.text_featurizer.decoder_config["beam_width"],
+            beam_size=self.text_featurizer.decoder_config.beam_width,
             ext_scoring_func=self.text_featurizer.scorer if lm else None
         )
         decoded = decoded[0][-1]
@@ -122,13 +113,13 @@ def recognize_beam_tflite(self, signal):
         features = self.speech_featurizer.tf_extract(signal)
         features = tf.expand_dims(features, axis=0)
         input_length = shape_list(features)[1]
-        input_length = input_length // self.base_model.time_reduction_factor
+        input_length = get_reduced_length(input_length, self.base_model.time_reduction_factor)
         input_length = tf.expand_dims(input_length, axis=0)
         logits = self(features, training=False)
         probs = tf.nn.softmax(logits)
         decoded = tf.keras.backend.ctc_decode(
             y_pred=probs, input_length=input_length, greedy=False,
-            beam_width=self.text_featurizer.decoder_config["beam_width"]
+            beam_width=self.text_featurizer.decoder_config.beam_width
         )
         decoded = tf.cast(decoded[0][0][0], dtype=tf.int32)
         transcript = self.text_featurizer.indices2upoints(decoded)
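
Together with the dataset changes, `recognize` and `recognize_beam` now take a batch of features and their lengths straight from the test pipeline instead of extracting features from raw signals inside the graph. A hedged end-to-end sketch; `test_dataset` and `ctc_model` are placeholders for objects constructed elsewhere in the repo:

```python
for path, features, input_lengths, labels in test_dataset:
    greedy_hyps = ctc_model.recognize(features, input_lengths)                # greedy CTC decoding
    beam_hyps = ctc_model.recognize_beam(features, input_lengths, lm=False)   # beam search decoding
```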
