From d51a0a8334720b2deb011b6af92defa46fdea38b Mon Sep 17 00:00:00 2001 From: Viraj Karandikar Date: Thu, 22 Feb 2024 19:43:05 +0530 Subject: [PATCH 1/4] asr: minor updates to argument names and logging --- riva/client/argparse_utils.py | 16 ++++++---- riva/client/asr.py | 37 ++++++++++++++++-------- riva/client/audio_io.py | 10 +++---- scripts/asr/riva_streaming_asr_client.py | 29 ++++++++++--------- scripts/asr/transcribe_file.py | 29 +++++++++---------- scripts/asr/transcribe_mic.py | 20 ++++--------- 6 files changed, 74 insertions(+), 67 deletions(-) diff --git a/riva/client/argparse_utils.py b/riva/client/argparse_utils.py index 4e0dea87..2a620ced 100644 --- a/riva/client/argparse_utils.py +++ b/riva/client/argparse_utils.py @@ -5,7 +5,10 @@ def add_asr_config_argparse_parameters( - parser: argparse.ArgumentParser, max_alternatives: bool = False, profanity_filter: bool = False, word_time_offsets: bool = False + parser: argparse.ArgumentParser, + max_alternatives: bool = False, + profanity_filter: bool = False, + word_time_offsets: bool = False, ) -> argparse.ArgumentParser: if word_time_offsets: parser.add_argument( @@ -20,11 +23,11 @@ def add_asr_config_argparse_parameters( ) if profanity_filter: parser.add_argument( - "--profanity-filter", - default=False, - action='store_true', - help="Flag that controls the profanity filtering in the generated transcripts", - ) + "--profanity-filter", + default=False, + action='store_true', + help="Flag that controls the profanity filtering in the generated transcripts", + ) parser.add_argument( "--automatic-punctuation", default=False, @@ -38,6 +41,7 @@ def add_asr_config_argparse_parameters( help="If specified, text inverse normalization will be applied", ) parser.add_argument("--language-code", default="en-US", help="Language code of the model to be used.") + parser.add_argument("--model-name", default="", help="Name of the model to be used to be used.") parser.add_argument("--boosted-lm-words", action='append', help="Words to boost when decoding.") parser.add_argument( "--boosted-lm-score", type=float, default=4.0, help="Value by which to boost words when decoding." diff --git a/riva/client/asr.py b/riva/client/asr.py index 3a0da418..5f699838 100644 --- a/riva/client/asr.py +++ b/riva/client/asr.py @@ -30,7 +30,7 @@ def get_wav_file_parameters(input_file: Union[str, os.PathLike]) -> Dict[str, Un 'duration': nframes / rate, 'nchannels': wf.getnchannels(), 'sampwidth': wf.getsampwidth(), - 'data_offset': wf.getfp().size_read + wf.getfp().offset + 'data_offset': wf.getfp().size_read + wf.getfp().offset, } except: # Not a WAV file @@ -46,11 +46,11 @@ class AudioChunkFileIterator: def __init__( self, input_file: Union[str, os.PathLike], - chunk_n_frames: int, + chunk_duration_ms: int, delay_callback: Optional[Callable[[bytes, float], None]] = None, ) -> None: self.input_file: Path = Path(input_file).expanduser() - self.chunk_n_frames = chunk_n_frames + self.chunk_duration_ms = chunk_duration_ms self.delay_callback = delay_callback self.file_parameters = get_wav_file_parameters(self.input_file) self.file_object: Optional[typing.BinaryIO] = open(str(self.input_file), 'rb') @@ -75,16 +75,21 @@ def __iter__(self): def __next__(self) -> bytes: if self.file_parameters: - data = self.file_object.read(self.chunk_n_frames * self.file_parameters['sampwidth'] * self.file_parameters['nchannels']) + num_frames = int(self.chunk_duration_ms * self.file_parameters['framerate'] / 1000) + data = self.file_object.read( + num_frames * self.file_parameters['sampwidth'] * self.file_parameters['nchannels'] + ) else: - data = self.file_object.read(self.chunk_n_frames) + # Fixed chunk size when file_parameters is not available + data = self.file_object.read(8192) if not data: self.close() raise StopIteration if self.delay_callback is not None: offset = self.file_parameters['data_offset'] if self.first_buffer else 0 self.delay_callback( - data[offset:], (len(data) - offset) / self.file_parameters['sampwidth'] / self.file_parameters['framerate'] + data[offset:], + (len(data) - offset) / self.file_parameters['sampwidth'] / self.file_parameters['framerate'], ) self.first_buffer = False return data @@ -104,8 +109,7 @@ def add_word_boosting_to_config( def add_audio_file_specs_to_config( - config: Union[rasr.StreamingRecognitionConfig, rasr.RecognitionConfig], - audio_file: Union[str, os.PathLike], + config: Union[rasr.StreamingRecognitionConfig, rasr.RecognitionConfig], audio_file: Union[str, os.PathLike], ) -> None: inner_config: rasr.RecognitionConfig = config if isinstance(config, rasr.RecognitionConfig) else config.config wav_parameters = get_wav_file_parameters(audio_file) @@ -114,10 +118,7 @@ def add_audio_file_specs_to_config( inner_config.audio_channel_count = wav_parameters['nchannels'] -def add_speaker_diarization_to_config( - config: Union[rasr.RecognitionConfig], - diarization_enable: bool, -) -> None: +def add_speaker_diarization_to_config(config: Union[rasr.RecognitionConfig], diarization_enable: bool,) -> None: inner_config: rasr.RecognitionConfig = config if isinstance(config, rasr.RecognitionConfig) else config.config if diarization_enable: diarization_config = rasr.SpeakerDiarizationConfig(enable_speaker_diarization=True) @@ -129,6 +130,7 @@ def add_speaker_diarization_to_config( def print_streaming( responses: Iterable[rasr.StreamingRecognizeResponse], + input_file: str = None, output_file: Optional[Union[Union[os.PathLike, str, TextIO], List[Union[os.PathLike, str, TextIO]]]] = None, additional_info: str = 'no', word_time_offsets: bool = False, @@ -194,6 +196,10 @@ def print_streaming( output_file[i] = Path(elem).expanduser().open(file_mode) start_time = time.time() # used in 'time` additional_info num_chars_printed = 0 # used in 'no' additional_info + final_transcript = "" # for printing best final transcript + if input_file: + for f in output_file: + f.write(f"File: {input_file}\n") for response in responses: if not response.results: continue @@ -204,6 +210,7 @@ def print_streaming( transcript = result.alternatives[0].transcript if additional_info == 'no': if result.is_final: + final_transcript += transcript if show_intermediate: overwrite_chars = ' ' * (num_chars_printed - len(transcript)) for i, f in enumerate(output_file): @@ -221,6 +228,7 @@ def print_streaming( partial_transcript += transcript elif additional_info == 'time': if result.is_final: + final_transcript += transcript for i, alternative in enumerate(result.alternatives): for f in output_file: f.write( @@ -239,6 +247,7 @@ def print_streaming( partial_transcript += transcript else: # additional_info == 'confidence' if result.is_final: + final_transcript += transcript for f in output_file: f.write(f'## {transcript}\n') f.write(f'Confidence: {result.alternatives[0].confidence:9.4f}\n') @@ -259,6 +268,9 @@ def print_streaming( else: for f in output_file: f.write('----\n') + for f in output_file: + f.write(f"Final transcripts:\n") + f.write(f"0 : {final_transcript}\n") finally: for fo, elem in zip(file_opened, output_file): if fo: @@ -284,6 +296,7 @@ def streaming_request_generator( class ASRService: """Provides streaming and offline recognition services. Calls gRPC stubs with authentication metadata.""" + def __init__(self, auth: Auth) -> None: """ Initializes an instance of the class. diff --git a/riva/client/audio_io.py b/riva/client/audio_io.py index ea432793..2bdff93a 100644 --- a/riva/client/audio_io.py +++ b/riva/client/audio_io.py @@ -2,7 +2,7 @@ # SPDX-License-Identifier: MIT import queue -from typing import Dict, Union, Optional +from typing import Dict, Optional, Union import pyaudio @@ -10,9 +10,9 @@ class MicrophoneStream: """Opens a recording stream as responses yielding the audio chunks.""" - def __init__(self, rate: int, chunk: int, device: int = None) -> None: + def __init__(self, rate: int, chunk_duration_ms: int, device: int = None) -> None: self._rate = rate - self._chunk = chunk + self._chunk = int(chunk_duration_ms * rate / 1000) self._device = device # Create a thread-safe buffer of audio data @@ -115,9 +115,7 @@ def list_input_devices() -> None: class SoundCallBack: - def __init__( - self, output_device_index: Optional[int], sampwidth: int, nchannels: int, framerate: int, - ) -> None: + def __init__(self, output_device_index: Optional[int], sampwidth: int, nchannels: int, framerate: int,) -> None: self.pa = pyaudio.PyAudio() self.stream = self.pa.open( output_device_index=output_device_index, diff --git a/scripts/asr/riva_streaming_asr_client.py b/scripts/asr/riva_streaming_asr_client.py index 6c7785ec..83249652 100644 --- a/scripts/asr/riva_streaming_asr_client.py +++ b/scripts/asr/riva_streaming_asr_client.py @@ -10,8 +10,8 @@ from typing import Union import riva.client -from riva.client.asr import get_wav_file_parameters from riva.client.argparse_utils import add_asr_config_argparse_parameters, add_connection_argparse_parameters +from riva.client.asr import get_wav_file_parameters def parse_args() -> argparse.Namespace: @@ -23,7 +23,7 @@ def parse_args() -> argparse.Namespace: "which names follow a format `output_.txt`.", formatter_class=argparse.ArgumentDefaultsHelpFormatter, ) - parser.add_argument("--num-clients", default=1, type=int, help="Number of client threads.") + parser.add_argument("--num-parallel-requests", default=1, type=int, help="Number of client threads.") parser.add_argument("--num-iterations", default=1, type=int, help="Number of iterations over the file.") parser.add_argument( "--input-file", required=True, type=str, help="Name of the WAV file with LINEAR_PCM encoding to transcribe." @@ -34,11 +34,14 @@ def parse_args() -> argparse.Namespace: help="Option to simulate realtime transcription. Audio fragments are sent to a server at a pace that mimics " "normal speech.", ) + parser.add_argument("--chunk-duration-ms", type=int, default=100, help="Chunk duration in milliseconds.") parser.add_argument( - "--file-streaming-chunk", type=int, default=1600, help="Number of frames in one chunk sent to server." + "--interim-results", default=False, action='store_true', help="Print intermediate transcripts", ) parser = add_connection_argparse_parameters(parser) - parser = add_asr_config_argparse_parameters(parser, max_alternatives=True, profanity_filter=True, word_time_offsets=True) + parser = add_asr_config_argparse_parameters( + parser, max_alternatives=True, profanity_filter=True, word_time_offsets=True + ) args = parser.parse_args() if args.max_alternatives < 1: parser.error("`--max-alternatives` must be greater than or equal to 1") @@ -60,24 +63,23 @@ def streaming_transcription_worker( enable_automatic_punctuation=args.automatic_punctuation, verbatim_transcripts=not args.no_verbatim_transcripts, enable_word_time_offsets=args.word_time_offsets, + model=args.model_name, ), - interim_results=True, + interim_results=args.interim_results, ) riva.client.add_word_boosting_to_config(config, args.boosted_lm_words, args.boosted_lm_score) for _ in range(args.num_iterations): with riva.client.AudioChunkFileIterator( args.input_file, - args.file_streaming_chunk, + args.chunk_duration_ms, delay_callback=riva.client.sleep_audio_length if args.simulate_realtime else None, ) as audio_chunk_iterator: riva.client.print_streaming( responses=asr_service.streaming_response_generator( - audio_chunks=audio_chunk_iterator, - streaming_config=config, + audio_chunks=audio_chunk_iterator, streaming_config=config, ), + input_file=args.input_file, output_file=output_file, - additional_info='time', - file_mode='a', word_time_offsets=args.word_time_offsets, ) except BaseException as e: @@ -87,12 +89,12 @@ def streaming_transcription_worker( def main() -> None: args = parse_args() - print("Number of clients:", args.num_clients) + print("Number of clients:", args.num_parallel_requests) print("Number of iteration:", args.num_iterations) print("Input file:", args.input_file) threads = [] exception_queue = queue.Queue() - for i in range(args.num_clients): + for i in range(args.num_parallel_requests): t = Thread(target=streaming_transcription_worker, args=[args, f"output_{i:d}.txt", i, exception_queue]) t.start() threads.append(t) @@ -112,7 +114,8 @@ def main() -> None: if all_dead: break time.sleep(0.05) - print(str(args.num_clients), "threads done, output written to output_.txt") + for i in range(args.num_parallel_requests): + print(f"Thread {i} done, output written to output_{i}.txt") if __name__ == "__main__": diff --git a/scripts/asr/transcribe_file.py b/scripts/asr/transcribe_file.py index 8c6a6238..b7053ed5 100644 --- a/scripts/asr/transcribe_file.py +++ b/scripts/asr/transcribe_file.py @@ -18,14 +18,14 @@ def parse_args() -> argparse.Namespace: parser.add_argument("--input-file", help="A path to a local file to stream.") parser.add_argument("--list-devices", action="store_true", help="List output devices indices") parser.add_argument( - "--show-intermediate", action="store_true", help="Show intermediate transcripts as they are available." + "--interim-results", default=False, action='store_true', help="Print intermediate transcripts", ) parser.add_argument( "--output-device", type=int, default=None, help="Output audio device to use for playing audio simultaneously with transcribing. If this parameter is " - "provided, then you do not have to `--play-audio` option." + "provided, then you do not have to `--play-audio` option.", ) parser.add_argument( "--play-audio", @@ -33,12 +33,7 @@ def parse_args() -> argparse.Namespace: help="Whether to play input audio simultaneously with transcribing. If `--output-device` is not provided, " "then the default output audio device will be used.", ) - parser.add_argument( - "--file-streaming-chunk", - type=int, - default=1600, - help="A maximum number of frames in one chunk sent to server.", - ) + parser.add_argument("--chunk-duration-ms", type=int, default=100, help="Chunk duration in milliseconds.") parser.add_argument( "--simulate-realtime", action='store_true', @@ -49,7 +44,9 @@ def parse_args() -> argparse.Namespace: "--print-confidence", action="store_true", help="Whether to print stability and confidence of transcript." ) parser = add_connection_argparse_parameters(parser) - parser = add_asr_config_argparse_parameters(parser, max_alternatives=True, profanity_filter=True, word_time_offsets=True) + parser = add_asr_config_argparse_parameters( + parser, max_alternatives=True, profanity_filter=True, word_time_offsets=True + ) args = parser.parse_args() if not args.list_devices and args.input_file is None: parser.error( @@ -71,12 +68,14 @@ def main() -> None: config = riva.client.StreamingRecognitionConfig( config=riva.client.RecognitionConfig( language_code=args.language_code, - max_alternatives=1, + max_alternatives=args.max_alternatives, profanity_filter=args.profanity_filter, enable_automatic_punctuation=args.automatic_punctuation, verbatim_transcripts=not args.no_verbatim_transcripts, + enable_word_time_offsets=args.word_time_offsets, + model=args.model_name, ), - interim_results=True, + interim_results=args.interim_results, ) riva.client.add_word_boosting_to_config(config, args.boosted_lm_words, args.boosted_lm_score) sound_callback = None @@ -90,14 +89,14 @@ def main() -> None: else: delay_callback = riva.client.sleep_audio_length if args.simulate_realtime else None with riva.client.AudioChunkFileIterator( - args.input_file, args.file_streaming_chunk, delay_callback, + args.input_file, args.chunk_duration_ms, delay_callback, ) as audio_chunk_iterator: riva.client.print_streaming( responses=asr_service.streaming_response_generator( - audio_chunks=audio_chunk_iterator, - streaming_config=config, + audio_chunks=audio_chunk_iterator, streaming_config=config, ), - show_intermediate=args.show_intermediate, + input_file=args.input_file, + show_intermediate=args.interim_results, additional_info="confidence" if args.print_confidence else "no", ) finally: diff --git a/scripts/asr/transcribe_mic.py b/scripts/asr/transcribe_mic.py index b1e231b9..cecc8979 100644 --- a/scripts/asr/transcribe_mic.py +++ b/scripts/asr/transcribe_mic.py @@ -4,9 +4,8 @@ import argparse import riva.client -from riva.client.argparse_utils import add_asr_config_argparse_parameters, add_connection_argparse_parameters - import riva.client.audio_io +from riva.client.argparse_utils import add_asr_config_argparse_parameters, add_connection_argparse_parameters def parse_args() -> argparse.Namespace: @@ -26,12 +25,7 @@ def parse_args() -> argparse.Namespace: help="A number of frames per second in audio streamed from a microphone.", default=16000, ) - parser.add_argument( - "--file-streaming-chunk", - type=int, - default=1600, - help="A maximum number of frames in a audio chunk sent to server.", - ) + parser.add_argument("--chunk-duration-ms", type=int, default=100, help="Chunk duration in milliseconds.") args = parser.parse_args() return args @@ -51,21 +45,17 @@ def main() -> None: profanity_filter=args.profanity_filter, enable_automatic_punctuation=args.automatic_punctuation, verbatim_transcripts=not args.no_verbatim_transcripts, - sample_rate_hertz=args.sample_rate_hz, - audio_channel_count=1, + model=args.model_name, ), interim_results=True, ) riva.client.add_word_boosting_to_config(config, args.boosted_lm_words, args.boosted_lm_score) with riva.client.audio_io.MicrophoneStream( - args.sample_rate_hz, - args.file_streaming_chunk, - device=args.input_device, + args.sample_rate_hz, args.chunk_duration_ms, device=args.input_device, ) as audio_chunk_iterator: riva.client.print_streaming( responses=asr_service.streaming_response_generator( - audio_chunks=audio_chunk_iterator, - streaming_config=config, + audio_chunks=audio_chunk_iterator, streaming_config=config, ), show_intermediate=True, ) From de755af4510cfd32da777080f784acdf3d3394c0 Mon Sep 17 00:00:00 2001 From: Viraj Karandikar Date: Thu, 22 Feb 2024 22:30:03 +0530 Subject: [PATCH 2/4] support word boosting from file --- riva/client/argparse_utils.py | 15 +++-- riva/client/asr.py | 15 +++-- scripts/asr/riva_streaming_asr_client.py | 2 +- scripts/asr/transcribe_file.py | 73 +++++++++++++++--------- 4 files changed, 67 insertions(+), 38 deletions(-) diff --git a/riva/client/argparse_utils.py b/riva/client/argparse_utils.py index 2a620ced..3618fca8 100644 --- a/riva/client/argparse_utils.py +++ b/riva/client/argparse_utils.py @@ -2,6 +2,7 @@ # SPDX-License-Identifier: MIT import argparse +from pathlib import Path def add_asr_config_argparse_parameters( @@ -35,16 +36,18 @@ def add_asr_config_argparse_parameters( help="Flag that controls if transcript should be automatically punctuated", ) parser.add_argument( - "--no-verbatim-transcripts", - default=False, - action='store_true', - help="If specified, text inverse normalization will be applied", + "--verbatim-transcripts", + default=True, + action='store_false', + help="True returns text exactly as it was said. False applies Inverse text normalization", ) parser.add_argument("--language-code", default="en-US", help="Language code of the model to be used.") parser.add_argument("--model-name", default="", help="Name of the model to be used to be used.") - parser.add_argument("--boosted-lm-words", action='append', help="Words to boost when decoding.") parser.add_argument( - "--boosted-lm-score", type=float, default=4.0, help="Value by which to boost words when decoding." + "--boosted-words-file", default=None, type=Path, help="File with a list of words to boost. One line per word." + ) + parser.add_argument( + "--boosted-words-score", type=float, default=4.0, help="Score by which to boost the boosted words." ) parser.add_argument( "--speaker-diarization", diff --git a/riva/client/asr.py b/riva/client/asr.py index 5f699838..77ce524e 100644 --- a/riva/client/asr.py +++ b/riva/client/asr.py @@ -97,14 +97,19 @@ def __next__(self) -> bytes: def add_word_boosting_to_config( config: Union[rasr.StreamingRecognitionConfig, rasr.RecognitionConfig], - boosted_lm_words: Optional[List[str]], - boosted_lm_score: float, + boosted_words_file: Union[str, os.PathLike], + boosted_words_score: float, ) -> None: inner_config: rasr.RecognitionConfig = config if isinstance(config, rasr.RecognitionConfig) else config.config - if boosted_lm_words is not None: + boosted_words = [] + if boosted_words_file: + with open(boosted_words_file) as f: + boosted_words = f.read().splitlines() + + if boosted_words is not None: speech_context = rasr.SpeechContext() - speech_context.phrases.extend(boosted_lm_words) - speech_context.boost = boosted_lm_score + speech_context.phrases.extend(boosted_words) + speech_context.boost = boosted_words_score inner_config.speech_contexts.append(speech_context) diff --git a/scripts/asr/riva_streaming_asr_client.py b/scripts/asr/riva_streaming_asr_client.py index 83249652..65fcb6e4 100644 --- a/scripts/asr/riva_streaming_asr_client.py +++ b/scripts/asr/riva_streaming_asr_client.py @@ -26,7 +26,7 @@ def parse_args() -> argparse.Namespace: parser.add_argument("--num-parallel-requests", default=1, type=int, help="Number of client threads.") parser.add_argument("--num-iterations", default=1, type=int, help="Number of iterations over the file.") parser.add_argument( - "--input-file", required=True, type=str, help="Name of the WAV file with LINEAR_PCM encoding to transcribe." + "--input-file", required=True, type=Path, help="Name of the WAV file with LINEAR_PCM encoding to transcribe." ) parser.add_argument( "--simulate-realtime", diff --git a/scripts/asr/transcribe_file.py b/scripts/asr/transcribe_file.py index b7053ed5..157e2846 100644 --- a/scripts/asr/transcribe_file.py +++ b/scripts/asr/transcribe_file.py @@ -2,6 +2,8 @@ # SPDX-License-Identifier: MIT import argparse +import json +from pathlib import Path import riva.client from riva.client.argparse_utils import add_asr_config_argparse_parameters, add_connection_argparse_parameters @@ -15,7 +17,12 @@ def parse_args() -> argparse.Namespace: "`--play-audio` or `--output-device`.", formatter_class=argparse.ArgumentDefaultsHelpFormatter, ) - parser.add_argument("--input-file", help="A path to a local file to stream.") + parser.add_argument( + "--input-file", + required=True, + type=Path, + help="A path to a local file to stream or a JSONL file containing list of files. JSONL file should contain JSON entry on each line, for example: {'audio_filepath': 'audio.wav'} ", + ) parser.add_argument("--list-devices", action="store_true", help="List output devices indices") parser.add_argument( "--interim-results", default=False, action='store_true', help="Print intermediate transcripts", @@ -63,6 +70,17 @@ def main() -> None: if args.list_devices: riva.client.audio_io.list_output_devices() return + input_files = [] + if args.input_file.suffix == ".json": + with open(args.input_file) as f: + lines = f.read().splitlines() + for line in lines: + data = json.loads(line) + if "audio_filepath" in data: + input_files.append(data["audio_filepath"]) + else: + input_files = [args.input_file] + auth = riva.client.Auth(args.ssl_cert, args.use_ssl, args.server, args.metadata) asr_service = riva.client.ASRService(auth) config = riva.client.StreamingRecognitionConfig( @@ -71,37 +89,40 @@ def main() -> None: max_alternatives=args.max_alternatives, profanity_filter=args.profanity_filter, enable_automatic_punctuation=args.automatic_punctuation, - verbatim_transcripts=not args.no_verbatim_transcripts, + verbatim_transcripts=args.verbatim_transcripts, enable_word_time_offsets=args.word_time_offsets, model=args.model_name, ), interim_results=args.interim_results, ) - riva.client.add_word_boosting_to_config(config, args.boosted_lm_words, args.boosted_lm_score) + riva.client.add_word_boosting_to_config(config, args.boosted_words_file, args.boosted_words_score) sound_callback = None - try: - if args.play_audio or args.output_device is not None: - wp = riva.client.get_wav_file_parameters(args.input_file) - sound_callback = riva.client.audio_io.SoundCallBack( - args.output_device, wp['sampwidth'], wp['nchannels'], wp['framerate'], - ) - delay_callback = sound_callback - else: - delay_callback = riva.client.sleep_audio_length if args.simulate_realtime else None - with riva.client.AudioChunkFileIterator( - args.input_file, args.chunk_duration_ms, delay_callback, - ) as audio_chunk_iterator: - riva.client.print_streaming( - responses=asr_service.streaming_response_generator( - audio_chunks=audio_chunk_iterator, streaming_config=config, - ), - input_file=args.input_file, - show_intermediate=args.interim_results, - additional_info="confidence" if args.print_confidence else "no", - ) - finally: - if sound_callback is not None and sound_callback.opened: - sound_callback.close() + + for file in input_files: + try: + if args.play_audio or args.output_device is not None: + wp = riva.client.get_wav_file_parameters(file) + sound_callback = riva.client.audio_io.SoundCallBack( + args.output_device, wp['sampwidth'], wp['nchannels'], wp['framerate'], + ) + delay_callback = sound_callback + else: + delay_callback = riva.client.sleep_audio_length if args.simulate_realtime else None + + with riva.client.AudioChunkFileIterator( + file, args.chunk_duration_ms, delay_callback, + ) as audio_chunk_iterator: + riva.client.print_streaming( + responses=asr_service.streaming_response_generator( + audio_chunks=audio_chunk_iterator, streaming_config=config, + ), + input_file=file, + show_intermediate=args.interim_results, + additional_info="confidence" if args.print_confidence else "no", + ) + finally: + if sound_callback is not None and sound_callback.opened: + sound_callback.close() if __name__ == "__main__": From 1fed1e5c856940eccfee23570ea9f6776ae3bfb5 Mon Sep 17 00:00:00 2001 From: Viraj Karandikar Date: Fri, 23 Feb 2024 15:19:34 +0530 Subject: [PATCH 3/4] Set verbatim-transcripts to False by default --- riva/client/argparse_utils.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/riva/client/argparse_utils.py b/riva/client/argparse_utils.py index 3618fca8..5941aade 100644 --- a/riva/client/argparse_utils.py +++ b/riva/client/argparse_utils.py @@ -37,9 +37,9 @@ def add_asr_config_argparse_parameters( ) parser.add_argument( "--verbatim-transcripts", - default=True, - action='store_false', - help="True returns text exactly as it was said. False applies Inverse text normalization", + default=False, + action='store_true', + help="Flag to disable Inverse text normalization and return the text exactly as it was said", ) parser.add_argument("--language-code", default="en-US", help="Language code of the model to be used.") parser.add_argument("--model-name", default="", help="Name of the model to be used to be used.") From e0699835cd389403c90cf57b727af195ae17b916 Mon Sep 17 00:00:00 2001 From: Viraj Karandikar Date: Tue, 16 Apr 2024 13:03:04 +0530 Subject: [PATCH 4/4] minor updates --- riva/client/asr.py | 4 ++-- scripts/asr/riva_streaming_asr_client.py | 5 +++-- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/riva/client/asr.py b/riva/client/asr.py index 77ce524e..ecbbe626 100644 --- a/riva/client/asr.py +++ b/riva/client/asr.py @@ -237,7 +237,7 @@ def print_streaming( for i, alternative in enumerate(result.alternatives): for f in output_file: f.write( - f"Time {time.time() - start_time:.2f}s: Transcript {i}: {alternative.transcript}\n" + f"Time {time.time() - start_time:.2f}s: Final Transcript {i}: Audio Processed {result.audio_processed}: {alternative.transcript}\n" ) if word_time_offsets: for f in output_file: @@ -269,7 +269,7 @@ def print_streaming( elif additional_info == 'time': for f in output_file: if partial_transcript: - f.write(f">>>Time {time.time():.2f}s: {partial_transcript}\n") + f.write(f">>>Time {time.time() - start_time:.2f}s: {partial_transcript}\n") else: for f in output_file: f.write('----\n') diff --git a/scripts/asr/riva_streaming_asr_client.py b/scripts/asr/riva_streaming_asr_client.py index 65fcb6e4..4be3ecd0 100644 --- a/scripts/asr/riva_streaming_asr_client.py +++ b/scripts/asr/riva_streaming_asr_client.py @@ -61,13 +61,13 @@ def streaming_transcription_worker( max_alternatives=args.max_alternatives, profanity_filter=args.profanity_filter, enable_automatic_punctuation=args.automatic_punctuation, - verbatim_transcripts=not args.no_verbatim_transcripts, + verbatim_transcripts=not args.verbatim_transcripts, enable_word_time_offsets=args.word_time_offsets, model=args.model_name, ), interim_results=args.interim_results, ) - riva.client.add_word_boosting_to_config(config, args.boosted_lm_words, args.boosted_lm_score) + riva.client.add_word_boosting_to_config(config, args.boosted_words_file, args.boosted_words_score) for _ in range(args.num_iterations): with riva.client.AudioChunkFileIterator( args.input_file, @@ -80,6 +80,7 @@ def streaming_transcription_worker( ), input_file=args.input_file, output_file=output_file, + additional_info='time', word_time_offsets=args.word_time_offsets, ) except BaseException as e: