|
48 | 48 | from tensorflow_asr.featurizers.speech_featurizers import TFSpeechFeaturizer |
49 | 49 | from tensorflow_asr.featurizers.text_featurizers import CharFeaturizer, SubwordFeaturizer, SentencePieceFeaturizer |
50 | 50 | from tensorflow_asr.models.transducer.conformer import Conformer |
| 51 | +from tensorflow_asr.utils.data_util import create_inputs |
51 | 52 |
|
52 | 53 | config = Config(args.config) |
53 | 54 | speech_featurizer = TFSpeechFeaturizer(config.speech_config) |
|
70 | 71 |
|
71 | 72 | signal = read_raw_audio(args.filename) |
72 | 73 | features = speech_featurizer.tf_extract(signal) |
73 | | -input_length = math_util.get_reduced_length(tf.shape(features)[0], conformer.time_reduction_factor) |
| 74 | +input_length = tf.shape(features)[0] |
74 | 75 |
|
75 | 76 | if args.beam_width: |
76 | | - transcript = conformer.recognize_beam(features[None, ...], input_length[None, ...]) |
77 | | - logger.info("Transcript:", transcript[0].numpy().decode("UTF-8")) |
| 77 | + inputs = create_inputs(features[None, ...], input_length[None, ...]) |
| 78 | + transcript = conformer.recognize_beam(inputs) |
| 79 | + logger.info(f"Transcript: {transcript[0].numpy().decode('UTF-8')}") |
78 | 80 | elif args.timestamp: |
79 | 81 | transcript, stime, etime, _, _ = conformer.recognize_tflite_with_timestamp( |
80 | 82 | signal, tf.constant(text_featurizer.blank, dtype=tf.int32), conformer.predict_net.get_initial_state()) |
81 | | - logger.info("Transcript:", transcript) |
82 | | - logger.info("Start time:", stime) |
83 | | - logger.info("End time:", etime) |
| 83 | + logger.info(f"Transcript: {transcript}") |
| 84 | + logger.info(f"Start time: {stime}") |
| 85 | + logger.info(f"End time: {etime}") |
84 | 86 | else: |
85 | | - transcript, _, _ = conformer.recognize_tflite( |
| 87 | + code_points, _, _ = conformer.recognize_tflite( |
86 | 88 | signal, tf.constant(text_featurizer.blank, dtype=tf.int32), conformer.predict_net.get_initial_state()) |
87 | | - logger.info("Transcript:", tf.strings.unicode_encode(transcript, "UTF-8").numpy().decode("UTF-8")) |
| 89 | + transcript = tf.strings.unicode_encode(code_points, 'UTF-8').numpy().decode('UTF-8') |
| 90 | + logger.info(f"Transcript: {transcript}") |
0 commit comments