Skip to content

Commit 98e13e2

Browse files
committed
✍️ fix transducer batch recognition
1 parent a228193 commit 98e13e2

File tree

2 files changed

+112
-80
lines changed

2 files changed

+112
-80
lines changed

tensorflow_asr/models/transducer.py

Lines changed: 73 additions & 80 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
import tensorflow as tf
1818

1919
from . import Model
20-
from ..utils.utils import get_rnn, shape_list, count_non_blank
20+
from ..utils.utils import get_rnn, shape_list, count_non_blank, pad_prediction_tfarray
2121
from ..featurizers.speech_featurizers import SpeechFeaturizer
2222
from ..featurizers.text_featurizers import TextFeaturizer
2323
from .layers.embedding import Embedding
@@ -400,42 +400,40 @@ def _perform_greedy_batch(self,
400400
encoded: tf.Tensor,
401401
encoded_length: tf.Tensor,
402402
parallel_iterations: int = 10,
403-
swap_memory: bool = False):
404-
total_batch, total_time, _ = shape_list(encoded)
405-
batch = tf.constant(0, dtype=tf.int32)
403+
swap_memory: bool = False,
404+
version: str = 'v1'):
405+
with tf.name_scope(f"{self.name}_perform_greedy_batch"):
406+
total_batch = tf.shape(encoded)[0]
407+
batch = tf.constant(0, dtype=tf.int32)
406408

407-
decoded = tf.TensorArray(
408-
dtype=tf.int32, size=total_batch, dynamic_size=False,
409-
clear_after_read=False, element_shape=None
410-
)
411-
412-
def condition(batch, _): return tf.less(batch, total_batch)
409+
greedy_fn = self._perform_greedy if version == 'v1' else self._perform_greedy_v2
413410

414-
def body(batch, decoded):
415-
hypothesis = self._perform_greedy(
416-
encoded=encoded[batch],
417-
encoded_length=encoded_length[batch],
418-
predicted=tf.constant(self.text_featurizer.blank, dtype=tf.int32),
419-
states=self.predict_net.get_initial_state(),
420-
parallel_iterations=parallel_iterations,
421-
swap_memory=swap_memory
411+
decoded = tf.TensorArray(
412+
dtype=tf.int32, size=total_batch, dynamic_size=False,
413+
clear_after_read=False, element_shape=tf.TensorShape([None])
422414
)
423-
prediction = tf.pad(
424-
hypothesis.prediction,
425-
paddings=[[0, 2 * (total_time - encoded_length[batch])]],
426-
mode="CONSTANT", constant_values=self.text_featurizer.blank
415+
416+
def condition(batch, _): return tf.less(batch, total_batch)
417+
418+
def body(batch, decoded):
419+
hypothesis = greedy_fn(
420+
encoded=encoded[batch],
421+
encoded_length=encoded_length[batch],
422+
predicted=tf.constant(self.text_featurizer.blank, dtype=tf.int32),
423+
states=self.predict_net.get_initial_state(),
424+
parallel_iterations=parallel_iterations,
425+
swap_memory=swap_memory
426+
)
427+
decoded = decoded.write(batch, hypothesis.prediction)
428+
return batch + 1, decoded
429+
430+
batch, decoded = tf.while_loop(
431+
condition, body, loop_vars=[batch, decoded],
432+
parallel_iterations=parallel_iterations, swap_memory=True,
427433
)
428-
decoded = decoded.write(batch, prediction)
429-
return batch + 1, decoded
430-
431-
batch, decoded = tf.while_loop(
432-
condition, body,
433-
loop_vars=[batch, decoded],
434-
parallel_iterations=parallel_iterations,
435-
swap_memory=True,
436-
)
437434

438-
return self.text_featurizer.iextract(decoded.stack())
435+
decoded = pad_prediction_tfarray(decoded, blank=self.text_featurizer.blank)
436+
return self.text_featurizer.iextract(decoded.stack())
439437

440438
def _perform_greedy(self,
441439
encoded: tf.Tensor,
@@ -457,12 +455,12 @@ def _perform_greedy(self,
457455
states=states
458456
)
459457

460-
def condition(_time, _total, _encoded, _hypothesis): return tf.less(_time, _total)
458+
def condition(_time, _hypothesis): return tf.less(_time, total)
461459

462-
def body(_time, _total, _encoded, _hypothesis):
460+
def body(_time, _hypothesis):
463461
ytu, _states = self.decoder_inference(
464462
# avoid using [index] in tflite
465-
encoded=tf.gather_nd(_encoded, tf.reshape(_time, shape=[1])),
463+
encoded=tf.gather_nd(encoded, tf.reshape(_time, shape=[1])),
466464
predicted=_hypothesis.index,
467465
states=_hypothesis.states
468466
)
@@ -480,13 +478,11 @@ def body(_time, _total, _encoded, _hypothesis):
480478
_prediction = _hypothesis.prediction.write(_time, _predict)
481479
_hypothesis = Hypothesis(index=_index, prediction=_prediction, states=_states)
482480

483-
return _time + 1, _total, _encoded, _hypothesis
481+
return _time + 1, _hypothesis
484482

485-
_, _, _, hypothesis = tf.while_loop(
483+
time, hypothesis = tf.while_loop(
486484
condition, body,
487-
loop_vars=[time, total, encoded, hypothesis],
488-
parallel_iterations=parallel_iterations,
489-
swap_memory=swap_memory
485+
loop_vars=[time, hypothesis], parallel_iterations=parallel_iterations, swap_memory=swap_memory
490486
)
491487

492488
return Hypothesis(index=hypothesis.index, prediction=hypothesis.prediction.stack(), states=hypothesis.states)
@@ -512,12 +508,12 @@ def _perform_greedy_v2(self,
512508
states=states
513509
)
514510

515-
def condition(_time, _total, _encoded, _hypothesis): return tf.less(_time, _total)
511+
def condition(_time, _hypothesis): return tf.less(_time, total)
516512

517-
def body(_time, _total, _encoded, _hypothesis):
513+
def body(_time, _hypothesis):
518514
ytu, _states = self.decoder_inference(
519515
# avoid using [index] in tflite
520-
encoded=tf.gather_nd(_encoded, tf.reshape(_time, shape=[1])),
516+
encoded=tf.gather_nd(encoded, tf.reshape(_time, shape=[1])),
521517
predicted=_hypothesis.index,
522518
states=_hypothesis.states
523519
)
@@ -531,13 +527,11 @@ def body(_time, _total, _encoded, _hypothesis):
531527
_prediction = _hypothesis.prediction.write(_time, _predict)
532528
_hypothesis = Hypothesis(index=_index, prediction=_prediction, states=_states)
533529

534-
return _time, _total, _encoded, _hypothesis
530+
return _time, _hypothesis
535531

536-
_, _, _, hypothesis = tf.while_loop(
532+
time, hypothesis = tf.while_loop(
537533
condition, body,
538-
loop_vars=[time, total, encoded, hypothesis],
539-
parallel_iterations=parallel_iterations,
540-
swap_memory=swap_memory
534+
loop_vars=[time, hypothesis], parallel_iterations=parallel_iterations, swap_memory=swap_memory
541535
)
542536

543537
return Hypothesis(index=hypothesis.index, prediction=hypothesis.prediction.stack(), states=hypothesis.states)
@@ -570,37 +564,32 @@ def _perform_beam_search_batch(self,
570564
lm: bool = False,
571565
parallel_iterations: int = 10,
572566
swap_memory: bool = False):
573-
total_batch, total_time, _ = shape_list(encoded)
574-
batch = tf.constant(0, dtype=tf.int32)
567+
with tf.name_scope(f"{self.name}_perform_beam_search_batch"):
568+
total_batch = tf.shape(encoded)[0]
569+
batch = tf.constant(0, dtype=tf.int32)
575570

576-
decoded = tf.TensorArray(
577-
dtype=tf.int32, size=total_batch, dynamic_size=False,
578-
clear_after_read=False, element_shape=None
579-
)
571+
decoded = tf.TensorArray(
572+
dtype=tf.int32, size=total_batch, dynamic_size=False,
573+
clear_after_read=False, element_shape=None
574+
)
580575

581-
def condition(batch, _): return tf.less(batch, total_batch)
576+
def condition(batch, _): return tf.less(batch, total_batch)
582577

583-
def body(batch, decoded):
584-
hypothesis = self._perform_beam_search(
585-
encoded[batch], encoded_length[batch], lm,
586-
parallel_iterations=parallel_iterations, swap_memory=swap_memory
587-
)
588-
prediction = tf.pad(
589-
hypothesis.prediction,
590-
paddings=[[0, 2 * (total_time - encoded_length[batch])]],
591-
mode="CONSTANT", constant_values=self.text_featurizer.blank
578+
def body(batch, decoded):
579+
hypothesis = self._perform_beam_search(
580+
encoded[batch], encoded_length[batch], lm,
581+
parallel_iterations=parallel_iterations, swap_memory=swap_memory
582+
)
583+
decoded = decoded.write(batch, hypothesis.prediction)
584+
return batch + 1, decoded
585+
586+
batch, decoded = tf.while_loop(
587+
condition, body, loop_vars=[batch, decoded],
588+
parallel_iterations=parallel_iterations, swap_memory=True,
592589
)
593-
decoded = decoded.write(batch, prediction)
594-
return batch + 1, decoded
595-
596-
batch, decoded = tf.while_loop(
597-
condition, body,
598-
loop_vars=[batch, decoded],
599-
parallel_iterations=parallel_iterations,
600-
swap_memory=True,
601-
)
602590

603-
return self.text_featurizer.iextract(decoded.stack())
591+
decoded = pad_prediction_tfarray(decoded, blank=self.text_featurizer.blank)
592+
return self.text_featurizer.iextract(decoded.stack())
604593

605594
def _perform_beam_search(self,
606595
encoded: tf.Tensor,
@@ -640,7 +629,7 @@ def initialize_beam(dynamic=False):
640629
B = BeamHypothesis(
641630
score=B.score.write(0, 0.0),
642631
indices=B.indices.write(0, self.text_featurizer.blank),
643-
prediction=B.prediction.write(0, tf.ones([total * 2], dtype=tf.int32) * self.text_featurizer.blank),
632+
prediction=B.prediction.write(0, tf.ones([total], dtype=tf.int32) * self.text_featurizer.blank),
644633
states=B.states.write(0, self.predict_net.get_initial_state())
645634
)
646635

@@ -651,7 +640,8 @@ def body(time, total, B):
651640
A = BeamHypothesis(
652641
score=A.score.unstack(B.score.stack()),
653642
indices=A.indices.unstack(B.indices.stack()),
654-
prediction=A.prediction.unstack(B.prediction.stack()),
643+
prediction=A.prediction.unstack(
644+
pad_prediction_tfarray(B.prediction, blank=self.text_featurizer.blank).stack()),
655645
states=A.states.unstack(B.states.stack()),
656646
)
657647
A_i = tf.constant(0, tf.int32)
@@ -666,7 +656,8 @@ def beam_body(beam, beam_width, A, A_i, B):
666656
y_hat_score, y_hat_score_index = tf.math.top_k(A.score.stack(), k=1, sorted=True)
667657
y_hat_score = y_hat_score[0]
668658
y_hat_index = tf.gather_nd(A.indices.stack(), y_hat_score_index)
669-
y_hat_prediction = tf.gather_nd(A.prediction.stack(), y_hat_score_index)
659+
y_hat_prediction = tf.gather_nd(
660+
pad_prediction_tfarray(A.prediction, blank=self.text_featurizer.blank).stack(), y_hat_score_index)
670661
y_hat_states = tf.gather_nd(A.states.stack(), y_hat_score_index)
671662

672663
# remove y_hat from A
@@ -676,7 +667,8 @@ def beam_body(beam, beam_width, A, A_i, B):
676667
A = BeamHypothesis(
677668
score=A.score.unstack(tf.gather_nd(A.score.stack(), remain_indices)),
678669
indices=A.indices.unstack(tf.gather_nd(A.indices.stack(), remain_indices)),
679-
prediction=A.prediction.unstack(tf.gather_nd(A.prediction.stack(), remain_indices)),
670+
prediction=A.prediction.unstack(tf.gather_nd(
671+
pad_prediction_tfarray(A.prediction, blank=self.text_featurizer.blank).stack(), remain_indices)),
680672
states=A.states.unstack(tf.gather_nd(A.states.stack(), remain_indices)),
681673
)
682674
A_i = tf.cond(tf.equal(A_i, 0), true_fn=lambda: A_i, false_fn=lambda: A_i - 1)
@@ -752,14 +744,15 @@ def false_fn():
752744
)
753745

754746
scores = B.score.stack()
747+
prediction = pad_prediction_tfarray(B.prediction, blank=self.text_featurizer.blank).stack()
755748
if self.text_featurizer.decoder_config.norm_score:
756-
prediction_lengths = count_non_blank(B.prediction.stack(), blank=self.text_featurizer.blank, axis=1)
749+
prediction_lengths = count_non_blank(prediction, blank=self.text_featurizer.blank, axis=1)
757750
scores /= tf.cast(prediction_lengths, dtype=scores.dtype)
758751

759752
y_hat_score, y_hat_score_index = tf.math.top_k(scores, k=1)
760753
y_hat_score = y_hat_score[0]
761754
y_hat_index = tf.gather_nd(B.indices.stack(), y_hat_score_index)
762-
y_hat_prediction = tf.gather_nd(B.prediction.stack(), y_hat_score_index)
755+
y_hat_prediction = tf.gather_nd(prediction, y_hat_score_index)
763756
y_hat_states = tf.gather_nd(B.states.stack(), y_hat_score_index)
764757

765758
return Hypothesis(index=y_hat_index, prediction=y_hat_prediction, states=y_hat_states)

tensorflow_asr/utils/utils.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -167,3 +167,42 @@ def has_gpu_or_tpu():
167167
tpus = tf.config.list_logical_devices("TPU")
168168
if len(gpus) == 0 and len(tpus) == 0: return False
169169
return True
170+
171+
172+
def find_max_length_prediction_tfarray(tfarray: tf.TensorArray) -> tf.Tensor:
    """Return the largest first-dimension size among the elements of ``tfarray``.

    Each slot from 0 to ``tfarray.size() - 1`` is read once inside a
    ``tf.while_loop`` and the running maximum of ``tf.shape(element)[0]`` is kept.

    Args:
        tfarray: TensorArray to scan. Elements are read with ``tfarray.read``,
            so the array must still be readable afterwards if the caller
            intends to use it again (presumably created with
            ``clear_after_read=False`` -- confirm at the call site).

    Returns:
        Scalar int32 tensor holding the maximum length; 0 when the array has
        no elements (the loop body never runs).
    """
    with tf.name_scope("find_max_length_prediction_tfarray"):
        position = tf.constant(0, dtype=tf.int32)
        num_items = tfarray.size()
        longest = tf.constant(0, dtype=tf.int32)

        def keep_going(position, _longest):
            return tf.less(position, num_items)

        def step(position, longest):
            item_length = tf.shape(tfarray.read(position))[0]
            # Same selection as where(length > longest, length, longest).
            return position + 1, tf.maximum(longest, item_length)

        _, longest = tf.while_loop(
            keep_going, step, loop_vars=[position, longest], swap_memory=False
        )
        return longest
188+
189+
190+
# NOTE: the annotation for ``blank`` is a string on purpose. The previous
# ``int or tf.Tensor`` was a boolean expression that evaluated to plain ``int``.
def pad_prediction_tfarray(tfarray: tf.TensorArray, blank: "int | tf.Tensor") -> tf.TensorArray:
    """Right-pad every element of ``tfarray`` to the length of its longest element.

    The target length comes from ``find_max_length_prediction_tfarray``. Each
    element is then read, padded along its first dimension with ``blank`` and
    written back to the same index, so that a later ``stack()`` sees
    equal-length elements.

    Args:
        tfarray: TensorArray of 1-D predictions (the padding spec has a single
            ``[before, after]`` pair, so elements must be rank 1).
        blank: Value used to fill the padded tail; passed straight to
            ``tf.pad`` as ``constant_values``.

    Returns:
        The TensorArray produced by the final write; use this return value
        rather than the argument.
    """
    with tf.name_scope("pad_prediction_tfarray"):
        index = tf.constant(0, dtype=tf.int32)
        total = tfarray.size()
        max_length = find_max_length_prediction_tfarray(tfarray)

        def condition(index, _):
            return tf.less(index, total)

        def body(index, tfarray):
            prediction = tfarray.read(index)
            # Pad only at the end; the longest element gets a pad amount of 0.
            missing = max_length - tf.shape(prediction)[0]
            prediction = tf.pad(
                prediction, paddings=[[0, missing]],
                mode="CONSTANT", constant_values=blank
            )
            tfarray = tfarray.write(index, prediction)
            return index + 1, tfarray

        index, tfarray = tf.while_loop(condition, body, loop_vars=[index, tfarray], swap_memory=False)
        return tfarray

0 commit comments

Comments
 (0)