Skip to content

Commit 3004f0e

Browse files
authored
Merge pull request #86 from TensorSpeech/dev/timestamp
Support naive token level timestamp
2 parents 3eac6c0 + 1640892 commit 3004f0e

File tree

7 files changed: +84 −56 lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ TensorFlowASR implements some automatic speech recognition architectures such as
2121

2222
## What's New?
2323

24+
- (12/27/2020) Supported _naive_ token level timestamp, see [demo](./examples/demonstration/conformer.py) with flag `--timestamp`
2425
- (12/17/2020) Supported ContextNet [http://arxiv.org/abs/2005.03191](http://arxiv.org/abs/2005.03191)
2526
- (12/12/2020) Add support for using masking
2627
- (11/14/2020) Supported Gradient Accumulation for Training in Larger Batch Size
@@ -219,4 +220,3 @@ For pretrained models, go to [drive](https://drive.google.com/drive/folders/1BD0
219220
Huy Le Nguyen
220221
221222
222-

examples/demonstration/conformer.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,8 @@
3030

3131
parser.add_argument("--beam_width", type=int, default=0, help="Beam width")
3232

33+
parser.add_argument("--timestamp", default=False, action="store_true", help="Return with timestamp")
34+
3335
parser.add_argument("--device", type=int, default=0, help="Device's id to run test on")
3436

3537
parser.add_argument("--cpu", default=False, action="store_true", help="Whether to only use cpu")
@@ -66,9 +68,16 @@
6668
features = speech_featurizer.tf_extract(signal)
6769
input_length = get_reduced_length(tf.shape(features)[0], conformer.time_reduction_factor)
6870

69-
if (args.beam_width):
71+
if args.beam_width:
7072
transcript = conformer.recognize_beam(features[None, ...], input_length[None, ...])
73+
print("Transcript:", transcript[0].numpy().decode("UTF-8"))
74+
elif args.timestamp:
75+
transcript, stime, etime, _, _ = conformer.recognize_tflite_with_timestamp(
76+
signal, tf.constant(text_featurizer.blank, dtype=tf.int32), conformer.predict_net.get_initial_state())
77+
print("Transcript:", transcript)
78+
print("Start time:", stime)
79+
print("End time:", etime)
7180
else:
72-
transcript = conformer.recognize(features[None, ...], input_length[None, ...])
73-
74-
tf.print("Transcript:", transcript[0])
81+
transcript, _, _ = conformer.recognize_tflite(
82+
signal, tf.constant(text_featurizer.blank, dtype=tf.int32), conformer.predict_net.get_initial_state())
83+
print("Transcript:", tf.strings.unicode_encode(transcript, "UTF-8").numpy().decode("UTF-8"))

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333

3434
setuptools.setup(
3535
name="TensorFlowASR",
36-
version="0.5.5",
36+
version="0.6.0",
3737
author="Huy Le Nguyen",
3838
author_email="[email protected]",
3939
description="Almost State-of-the-art Automatic Speech Recognition using Tensorflow 2",

tensorflow_asr/featurizers/speech_featurizers.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -377,7 +377,7 @@ def shape(self) -> list:
377377
def stft(self, signal):
378378
return tf.square(
379379
tf.abs(tf.signal.stft(signal, frame_length=self.frame_length,
380-
frame_step=self.frame_step, fft_length=self.nfft)))
380+
frame_step=self.frame_step, fft_length=self.nfft, pad_end=True)))
381381

382382
def power_to_db(self, S, ref=1.0, amin=1e-10, top_db=80.0):
383383
if amin <= 0:

tensorflow_asr/featurizers/text_featurizers.py

Lines changed: 20 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -119,10 +119,7 @@ def __init_vocabulary(self):
119119
self.tokens.insert(self.blank, "") # add blank token to tokens
120120
self.num_classes = len(self.tokens)
121121
self.tokens = tf.convert_to_tensor(self.tokens, dtype=tf.string)
122-
self.upoints = tf.squeeze(
123-
tf.strings.unicode_decode(
124-
self.tokens, "UTF-8").to_tensor(shape=[None, 1])
125-
)
122+
self.upoints = tf.strings.unicode_decode(self.tokens, "UTF-8").to_tensor(shape=[None, 1])
126123

127124
def extract(self, text: str) -> tf.Tensor:
128125
"""
@@ -170,7 +167,7 @@ def indices2upoints(self, indices: tf.Tensor) -> tf.Tensor:
170167
with tf.name_scope("indices2upoints"):
171168
indices = self.normalize_indices(indices)
172169
upoints = tf.gather_nd(self.upoints, tf.expand_dims(indices, axis=-1))
173-
return upoints
170+
return tf.gather_nd(upoints, tf.where(tf.not_equal(upoints, 0)))
174171

175172

176173
class SubwordFeaturizer(TextFeaturizer):
@@ -265,18 +262,25 @@ def iextract(self, indices: tf.Tensor) -> tf.Tensor:
265262
Returns:
266263
transcripts: tf.Tensor of dtype tf.string with dim [B]
267264
"""
268-
indices = self.normalize_indices(indices)
269265
with tf.device("/CPU:0"): # string data is not supported on GPU
270-
def decode(x):
271-
if x[0] == self.blank: x = x[1:]
272-
return self.subwords.decode(x)
273-
274-
text = tf.map_fn(
275-
lambda x: tf.numpy_function(decode, inp=[x], Tout=tf.string),
276-
indices,
277-
fn_output_signature=tf.TensorSpec([], dtype=tf.string)
266+
total = tf.shape(indices)[0]
267+
batch = tf.constant(0, dtype=tf.int32)
268+
transcripts = tf.TensorArray(
269+
dtype=tf.string, size=total, dynamic_size=False, infer_shape=False,
270+
clear_after_read=False, element_shape=tf.TensorShape([])
278271
)
279-
return text
272+
273+
def cond(batch, total, transcripts): return tf.less(batch, total)
274+
275+
def body(batch, total, transcripts):
276+
upoints = self.indices2upoints(indices[batch])
277+
_transcript = tf.strings.unicode_encode(upoints, "UTF-8")
278+
transcripts = transcripts.write(batch, _transcript)
279+
return batch + 1, total, transcripts
280+
281+
_, _, transcripts = tf.while_loop(cond, body, loop_vars=[batch, total, transcripts])
282+
283+
return transcripts.stack()
280284

281285
@tf.function(
282286
input_signature=[
@@ -295,6 +299,4 @@ def indices2upoints(self, indices: tf.Tensor) -> tf.Tensor:
295299
with tf.name_scope("indices2upoints"):
296300
indices = self.normalize_indices(indices)
297301
upoints = tf.gather_nd(self.upoints, tf.expand_dims(indices, axis=-1))
298-
# upoints now has shape [None, max_subword_length]
299-
shape = tf.shape(upoints)
300-
return tf.reshape(upoints, [shape[0] * shape[1]]) # flatten
302+
return tf.gather_nd(upoints, tf.where(tf.not_equal(upoints, 0)))

tensorflow_asr/models/transducer.py

Lines changed: 47 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -23,15 +23,9 @@
2323
from ..featurizers.text_featurizers import TextFeaturizer
2424
from .layers.embedding import Embedding
2525

26-
Hypothesis = collections.namedtuple(
27-
"Hypothesis",
28-
("index", "prediction", "states")
29-
)
26+
Hypothesis = collections.namedtuple("Hypothesis", ("index", "prediction", "states"))
3027

31-
BeamHypothesis = collections.namedtuple(
32-
"BeamHypothesis",
33-
("score", "indices", "prediction", "states")
34-
)
28+
BeamHypothesis = collections.namedtuple("BeamHypothesis", ("score", "indices", "prediction", "states"))
3529

3630

3731
class TransducerPrediction(tf.keras.Model):
@@ -233,6 +227,7 @@ def __init__(self,
233227
bias_regularizer=bias_regularizer,
234228
name=f"{name}_joint"
235229
)
230+
self.time_reduction_factor = 1
236231

237232
def _build(self, input_shape):
238233
inputs = tf.keras.Input(shape=input_shape, dtype=tf.float32)
@@ -369,6 +364,29 @@ def recognize_tflite(self, signal, predicted, states):
369364
hypothesis.states
370365
)
371366

367+
def recognize_tflite_with_timestamp(self, signal, predicted, states):
368+
features = self.speech_featurizer.tf_extract(signal)
369+
encoded = self.encoder_inference(features)
370+
hypothesis = self._perform_greedy(encoded, tf.shape(encoded)[0], predicted, states)
371+
indices = self.text_featurizer.normalize_indices(hypothesis.prediction)
372+
upoints = tf.gather_nd(self.text_featurizer.upoints, tf.expand_dims(indices, axis=-1)) # [None, max_subword_length]
373+
374+
num_samples = tf.cast(tf.shape(signal)[0], dtype=tf.float32)
375+
total_time_reduction_factor = self.time_reduction_factor * self.speech_featurizer.frame_step
376+
377+
stime = tf.range(0, num_samples, delta=total_time_reduction_factor, dtype=tf.float32)
378+
stime /= tf.cast(self.speech_featurizer.sample_rate, dtype=tf.float32)
379+
380+
etime = tf.range(total_time_reduction_factor, num_samples, delta=total_time_reduction_factor, dtype=tf.float32)
381+
etime /= tf.cast(self.speech_featurizer.sample_rate, dtype=tf.float32)
382+
383+
non_blank = tf.where(tf.not_equal(upoints, 0))
384+
non_blank_transcript = tf.gather_nd(upoints, non_blank)
385+
non_blank_stime = tf.gather_nd(tf.repeat(tf.expand_dims(stime, axis=-1), tf.shape(upoints)[-1], axis=-1), non_blank)
386+
non_blank_etime = tf.gather_nd(tf.repeat(tf.expand_dims(etime, axis=-1), tf.shape(upoints)[-1], axis=-1), non_blank)
387+
388+
return non_blank_transcript, non_blank_stime, non_blank_etime, hypothesis.prediction, hypothesis.states
389+
372390
def _perform_greedy_batch(self,
373391
encoded: tf.Tensor,
374392
encoded_length: tf.Tensor,
@@ -400,7 +418,7 @@ def body(batch, total, encoded, encoded_length, decoded):
400418

401419
batch, total, _, _, decoded = tf.while_loop(
402420
condition, body,
403-
loop_vars=(batch, total, encoded, encoded_length, decoded),
421+
loop_vars=[batch, total, encoded, encoded_length, decoded],
404422
parallel_iterations=parallel_iterations,
405423
swap_memory=True,
406424
)
@@ -419,45 +437,43 @@ def _perform_greedy(self,
419437
total = encoded_length
420438

421439
hypothesis = Hypothesis(
422-
index=tf.constant(0, dtype=tf.int32),
423-
prediction=tf.ones([total + 1], dtype=tf.int32) * self.text_featurizer.blank,
440+
index=tf.constant(self.text_featurizer.blank, dtype=tf.int32),
441+
prediction=tf.ones([total], dtype=tf.int32) * self.text_featurizer.blank,
424442
states=states
425443
)
426444

427445
def condition(time, total, encoded, hypothesis): return tf.less(time, total)
428446

429447
def body(time, total, encoded, hypothesis):
430-
predicted = tf.gather_nd(hypothesis.prediction, tf.expand_dims(hypothesis.index, axis=-1))
431-
432-
ytu, new_states = self.decoder_inference(
448+
ytu, states = self.decoder_inference(
433449
# avoid using [index] in tflite
434450
encoded=tf.gather_nd(encoded, tf.expand_dims(time, axis=-1)),
435-
predicted=predicted,
451+
predicted=hypothesis.index,
436452
states=hypothesis.states
437453
)
438-
new_predicted = tf.argmax(ytu, axis=-1, output_type=tf.int32) # => argmax []
454+
predict = tf.argmax(ytu, axis=-1, output_type=tf.int32) # => argmax []
439455

440-
index, new_predicted, new_states = tf.cond(
441-
tf.equal(new_predicted, self.text_featurizer.blank),
442-
true_fn=lambda: (hypothesis.index, predicted, hypothesis.states),
443-
false_fn=lambda: (hypothesis.index + 1, new_predicted, new_states)
456+
index, predict, states = tf.cond(
457+
tf.equal(predict, self.text_featurizer.blank),
458+
true_fn=lambda: (hypothesis.index, predict, hypothesis.states),
459+
false_fn=lambda: (predict, predict, states) # update if the new prediction is a non-blank
444460
)
445461

446462
hypothesis = Hypothesis(
447463
index=index,
448464
prediction=tf.tensor_scatter_nd_update(
449465
hypothesis.prediction,
450-
indices=tf.reshape(index, [1, 1]),
451-
updates=tf.expand_dims(new_predicted, axis=-1)
466+
indices=tf.reshape(time, [1, 1]),
467+
updates=tf.expand_dims(predict, axis=-1)
452468
),
453-
states=new_states
469+
states=states
454470
)
455471

456472
return time + 1, total, encoded, hypothesis
457473

458474
time, total, encoded, hypothesis = tf.while_loop(
459475
condition, body,
460-
loop_vars=(time, total, encoded, hypothesis),
476+
loop_vars=[time, total, encoded, hypothesis],
461477
parallel_iterations=parallel_iterations,
462478
swap_memory=swap_memory
463479
)
@@ -512,7 +528,7 @@ def body(batch, total, encoded, encoded_length, decoded):
512528

513529
batch, total, _, _, decoded = tf.while_loop(
514530
condition, body,
515-
loop_vars=(batch, total, encoded, encoded_length, decoded),
531+
loop_vars=[batch, total, encoded, encoded_length, decoded],
516532
parallel_iterations=parallel_iterations,
517533
swap_memory=True,
518534
)
@@ -626,23 +642,23 @@ def predict_body(pred, A, A_i, B):
626642

627643
_, A, A_i, B = tf.while_loop(
628644
predict_condition, predict_body,
629-
loop_vars=(0, A, A_i, B),
645+
loop_vars=[0, A, A_i, B],
630646
parallel_iterations=parallel_iterations, swap_memory=swap_memory
631647
)
632648

633649
return beam + 1, beam_width, A, A_i, B
634650

635651
_, _, A, A_i, B = tf.while_loop(
636652
beam_condition, beam_body,
637-
loop_vars=(0, beam_width, A, A_i, B),
653+
loop_vars=[0, beam_width, A, A_i, B],
638654
parallel_iterations=parallel_iterations, swap_memory=swap_memory
639655
)
640656

641657
return time + 1, total, B
642658

643659
_, _, B = tf.while_loop(
644660
condition, body,
645-
loop_vars=(0, total, B),
661+
loop_vars=[0, total, B],
646662
parallel_iterations=parallel_iterations, swap_memory=swap_memory
647663
)
648664

@@ -665,9 +681,10 @@ def predict_body(pred, A, A_i, B):
665681

666682
# -------------------------------- TFLITE -------------------------------------
667683

668-
def make_tflite_function(self, greedy: bool = True):
684+
def make_tflite_function(self, timestamp: bool = False):
685+
tflite_func = self.recognize_tflite_with_timestamp if timestamp else self.recognize_tflite
669686
return tf.function(
670-
self.recognize_tflite,
687+
tflite_func,
671688
input_signature=[
672689
tf.TensorSpec([None], dtype=tf.float32),
673690
tf.TensorSpec([], dtype=tf.int32),

tensorflow_asr/utils/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -133,7 +133,7 @@ def _body(i, result, yseqs, U):
133133
_, result, _, _ = tf.while_loop(
134134
_cond,
135135
_body,
136-
loop_vars=(i, result, yseqs, U),
136+
loop_vars=[i, result, yseqs, U],
137137
shape_invariants=(
138138
tf.TensorShape([]),
139139
tf.TensorShape([None]),

0 commit comments

Comments (0)