Skip to content

Commit c6eeb03

Browse files
authored
Merge pull request #204 from ebraraktas/fix/conformer-demo
Fix/conformer demo
2 parents 8ee4007 + 9134eb4 commit c6eeb03

File tree

2 files changed

+12
-9
lines changed

2 files changed

+12
-9
lines changed

examples/demonstration/conformer.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@
4848
from tensorflow_asr.featurizers.speech_featurizers import TFSpeechFeaturizer
4949
from tensorflow_asr.featurizers.text_featurizers import CharFeaturizer, SubwordFeaturizer, SentencePieceFeaturizer
5050
from tensorflow_asr.models.transducer.conformer import Conformer
51+
from tensorflow_asr.utils.data_util import create_inputs
5152

5253
config = Config(args.config)
5354
speech_featurizer = TFSpeechFeaturizer(config.speech_config)
@@ -70,18 +71,20 @@
7071

7172
signal = read_raw_audio(args.filename)
7273
features = speech_featurizer.tf_extract(signal)
73-
input_length = math_util.get_reduced_length(tf.shape(features)[0], conformer.time_reduction_factor)
74+
input_length = tf.shape(features)[0]
7475

7576
if args.beam_width:
76-
transcript = conformer.recognize_beam(features[None, ...], input_length[None, ...])
77-
logger.info("Transcript:", transcript[0].numpy().decode("UTF-8"))
77+
inputs = create_inputs(features[None, ...], input_length[None, ...])
78+
transcript = conformer.recognize_beam(inputs)
79+
logger.info(f"Transcript: {transcript[0].numpy().decode('UTF-8')}")
7880
elif args.timestamp:
7981
transcript, stime, etime, _, _ = conformer.recognize_tflite_with_timestamp(
8082
signal, tf.constant(text_featurizer.blank, dtype=tf.int32), conformer.predict_net.get_initial_state())
81-
logger.info("Transcript:", transcript)
82-
logger.info("Start time:", stime)
83-
logger.info("End time:", etime)
83+
logger.info(f"Transcript: {transcript}")
84+
logger.info(f"Start time: {stime}")
85+
logger.info(f"End time: {etime}")
8486
else:
85-
transcript, _, _ = conformer.recognize_tflite(
87+
code_points, _, _ = conformer.recognize_tflite(
8688
signal, tf.constant(text_featurizer.blank, dtype=tf.int32), conformer.predict_net.get_initial_state())
87-
logger.info("Transcript:", tf.strings.unicode_encode(transcript, "UTF-8").numpy().decode("UTF-8"))
89+
transcript = tf.strings.unicode_encode(code_points, 'UTF-8').numpy().decode('UTF-8')
90+
logger.info(f"Transcript: {transcript}")

tensorflow_asr/models/transducer/transducer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -604,7 +604,7 @@ def recognize_beam(self, inputs: Dict[str, tf.Tensor], lm: bool = False):
604604
"""
605605
RNN Transducer Beam Search
606606
Args:
607-
features (tf.Tensor): a batch of padded extracted features
607+
inputs (Dict[str, tf.Tensor]): Input dictionary containing "inputs" and "inputs_length"
608608
lm (bool, optional): whether to use language model. Defaults to False.
609609
610610
Returns:

0 commit comments

Comments
 (0)