Skip to content

Commit 1bad037

Browse files
authored
Merge pull request #83 from TensorSpeech/fix/tflite
Fix TFLite Conversion for Transducer Greedy
2 parents 6d70eab + 6df0393 commit 1bad037

File tree

3 files changed

+19
-41
lines changed

3 files changed

+19
-41
lines changed

examples/conformer/tflite_subword_conformer.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -64,9 +64,9 @@
6464

6565
concrete_func = conformer.make_tflite_function(greedy=True).get_concrete_function()
6666
converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func])
67+
converter.experimental_new_converter = True
6768
converter.optimizations = [tf.lite.Optimize.DEFAULT]
68-
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS,
69-
tf.lite.OpsSet.SELECT_TF_OPS]
69+
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS, tf.lite.OpsSet.SELECT_TF_OPS]
7070
tflite_model = converter.convert()
7171

7272
if not os.path.exists(os.path.dirname(args.output)):

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@
3737

3838
setuptools.setup(
3939
name="TensorFlowASR",
40-
version="0.5.1",
40+
version="0.5.2",
4141
author="Huy Le Nguyen",
4242
author_email="[email protected]",
4343
description="Almost State-of-the-art Automatic Speech Recognition using Tensorflow 2",

tensorflow_asr/models/transducer.py

Lines changed: 16 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -417,73 +417,51 @@ def __perform_greedy(self,
417417
with tf.name_scope(f"{self.name}_greedy"):
418418
time = tf.constant(0, dtype=tf.int32)
419419
total = encoded_length
420-
# Initialize prediction with a blank
421-
# Prediction can not be longer than the encoded of audio plus blank
422-
prediction = tf.TensorArray(
423-
dtype=tf.int32,
424-
size=(total + 1),
425-
dynamic_size=False,
426-
element_shape=tf.TensorShape([]),
427-
clear_after_read=False
428-
)
429420

430421
hypothesis = Hypothesis(
431422
index=tf.constant(0, dtype=tf.int32),
432-
prediction=prediction.write(0, predicted),
423+
prediction=tf.ones([total + 1], dtype=tf.int32) * self.text_featurizer.blank,
433424
states=states
434425
)
435426

436427
def condition(time, total, encoded, hypothesis): return tf.less(time, total)
437428

438429
def body(time, total, encoded, hypothesis):
430+
predicted = tf.gather_nd(hypothesis.prediction, tf.expand_dims(hypothesis.index, axis=-1))
431+
439432
ytu, new_states = self.decoder_inference(
440433
# avoid using [index] in tflite
441434
encoded=tf.gather_nd(encoded, tf.expand_dims(time, axis=-1)),
442-
predicted=hypothesis.prediction.read(hypothesis.index),
435+
predicted=predicted,
443436
states=hypothesis.states
444437
)
445-
char = tf.argmax(ytu, axis=-1, output_type=tf.int32) # => argmax []
446-
447-
index, char, new_states = tf.cond(
448-
tf.equal(char, self.text_featurizer.blank),
449-
true_fn=lambda: (
450-
hypothesis.index,
451-
hypothesis.prediction.read(hypothesis.index),
452-
hypothesis.states
453-
),
454-
false_fn=lambda: (
455-
hypothesis.index + 1,
456-
char,
457-
new_states
458-
)
438+
new_predicted = tf.argmax(ytu, axis=-1, output_type=tf.int32) # => argmax []
439+
440+
index, new_predicted, new_states = tf.cond(
441+
tf.equal(new_predicted, self.text_featurizer.blank),
442+
true_fn=lambda: (hypothesis.index, predicted, hypothesis.states),
443+
false_fn=lambda: (hypothesis.index + 1, new_predicted, new_states)
459444
)
460445

461446
hypothesis = Hypothesis(
462447
index=index,
463-
prediction=hypothesis.prediction.write(index, char),
448+
prediction=tf.tensor_scatter_nd_update(
449+
hypothesis.prediction,
450+
indices=tf.reshape(index, [1, 1]),
451+
updates=tf.expand_dims(new_predicted, axis=-1)
452+
),
464453
states=new_states
465454
)
466455

467456
return time + 1, total, encoded, hypothesis
468457

469458
time, total, encoded, hypothesis = tf.while_loop(
470-
condition,
471-
body,
459+
condition, body,
472460
loop_vars=(time, total, encoded, hypothesis),
473461
parallel_iterations=parallel_iterations,
474462
swap_memory=swap_memory
475463
)
476464

477-
# Gather predicted sequence
478-
hypothesis = Hypothesis(
479-
index=hypothesis.index,
480-
prediction=tf.gather_nd(
481-
params=hypothesis.prediction.stack(),
482-
indices=tf.expand_dims(tf.range(hypothesis.index + 1), axis=-1)
483-
),
484-
states=hypothesis.states
485-
)
486-
487465
return hypothesis
488466

489467
# -------------------------------- BEAM SEARCH -------------------------------------

0 commit comments

Comments
 (0)