Skip to content

Commit 83d1b70

Browse files
committed
🚀 Update transducer beam search and tester
1 parent f7a5ea7 commit 83d1b70

File tree

3 files changed

+17
-12
lines changed

3 files changed

+17
-12
lines changed

setup.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919

2020
requirements = [
2121
"tensorflow-datasets>=3.2.1,<4.0.0",
22+
"tensorflow-metadata>=0.26.0",
2223
"tensorflow-addons>=0.10.0",
2324
"setuptools>=47.1.1",
2425
"librosa>=0.7.2",
@@ -32,12 +33,11 @@
3233
"tqdm>=4.51.0",
3334
"colorama>=0.4.3",
3435
"nlpaug>=1.0.1",
35-
"absl-py>=0.9,<0.11"
3636
]
3737

3838
setuptools.setup(
3939
name="TensorFlowASR",
40-
version="0.4.3",
40+
version="0.4.4",
4141
author="Huy Le Nguyen",
4242
author_email="[email protected]",
4343
description="Almost State-of-the-art Automatic Speech Recognition using Tensorflow 2",

tensorflow_asr/models/transducer.py

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -462,7 +462,7 @@ def execute(signal: tf.Tensor):
462462
return tf.map_fn(execute, signals, fn_output_signature=tf.TensorSpec([], dtype=tf.string))
463463

464464
def perform_beam_search(self, encoded, lm=False):
465-
with tf.name_scope(f"{self.name}_beam_search"):
465+
with tf.device("/CPU:0"), tf.name_scope(f"{self.name}_beam_search"):
466466
beam_width = tf.cond(
467467
tf.less(self.text_featurizer.decoder_config.beam_width, self.text_featurizer.num_classes),
468468
true_fn=lambda: self.text_featurizer.decoder_config.beam_width,
@@ -520,9 +520,9 @@ def beam_condition(beam, beam_width, A, A_i, B): return tf.less(beam, beam_width
520520
def beam_body(beam, beam_width, A, A_i, B):
521521
y_hat_score, y_hat_score_index = tf.math.top_k(A.score.stack(), k=1)
522522
y_hat_score = y_hat_score[0]
523-
y_hat_index = tf.gather_nd(A.indices.stack(), tf.expand_dims(y_hat_score_index[0], axis=-1))
524-
y_hat_prediction = tf.gather_nd(A.prediction.stack(), tf.expand_dims(y_hat_score_index[0], axis=-1))
525-
y_hat_states = tf.gather_nd(A.states.stack(), tf.expand_dims(y_hat_score_index[0], axis=-1))
523+
y_hat_index = tf.gather_nd(A.indices.stack(), y_hat_score_index)
524+
y_hat_prediction = tf.gather_nd(A.prediction.stack(), y_hat_score_index)
525+
y_hat_states = tf.gather_nd(A.states.stack(), y_hat_score_index)
526526

527527
ytu, new_states = self.decoder_inference(encoded=encoded_t, predicted=y_hat_index, states=y_hat_states)
528528

@@ -571,11 +571,16 @@ def predict_body(pred, A, A_i, B):
571571

572572
_, _, B = tf.while_loop(condition, body, loop_vars=(0, total, B))
573573

574-
y_hat_score, y_hat_score_index = tf.math.top_k(B.score.stack(), k=1)
574+
scores = B.score.stack()
575+
if self.text_featurizer.decoder_config.norm_score:
576+
prediction_lengths = tf.strings.length(B.prediction.stack(), unit="UTF8_CHAR")
577+
scores /= tf.cast(prediction_lengths, dtype=scores.dtype)
578+
579+
y_hat_score, y_hat_score_index = tf.math.top_k(scores, k=1)
575580
y_hat_score = y_hat_score[0]
576-
y_hat_index = tf.gather_nd(B.indices.stack(), tf.expand_dims(y_hat_score_index[0], axis=-1))
577-
y_hat_prediction = tf.gather_nd(B.prediction.stack(), tf.expand_dims(y_hat_score_index[0], axis=-1))
578-
y_hat_states = tf.gather_nd(B.states.stack(), tf.expand_dims(y_hat_score_index[0], axis=-1))
581+
y_hat_index = tf.gather_nd(B.indices.stack(), y_hat_score_index)
582+
y_hat_prediction = tf.gather_nd(B.prediction.stack(), y_hat_score_index)
583+
y_hat_states = tf.gather_nd(B.states.stack(), y_hat_score_index)
579584

580585
return Hypothesis(
581586
index=y_hat_index,

tensorflow_asr/runners/base_runners.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -444,11 +444,11 @@ def _test_step(self, batch):
444444

445445
labels = self.model.text_featurizer.iextract(labels)
446446
greed_pred = self.model.recognize(signals)
447+
beam_pred = beam_lm_pred = tf.constant([""], dtype=tf.string)
447448
if self.model.text_featurizer.decoder_config.beam_width > 0:
448449
beam_pred = self.model.recognize_beam(signals, lm=False)
450+
if self.model.text_featurizer.decoder_config.lm_config:
449451
beam_lm_pred = self.model.recognize_beam(signals, lm=True)
450-
else:
451-
beam_pred = beam_lm_pred = tf.constant([""], dtype=tf.string)
452452

453453
return file_paths, labels, greed_pred, beam_pred, beam_lm_pred
454454

0 commit comments

Comments
 (0)