From 11153fca80876994e3516d9f9c30e023e8cbb0fe Mon Sep 17 00:00:00 2001 From: Viraj Karandikar Date: Wed, 17 Apr 2024 10:28:22 +0530 Subject: [PATCH 1/2] asr: add all config params --- riva/client/argparse_utils.py | 31 ++++++++++++++++++------ scripts/asr/riva_streaming_asr_client.py | 17 ++++++++----- scripts/asr/transcribe_file.py | 18 +++++++++----- scripts/asr/transcribe_file_offline.py | 12 ++++++--- 4 files changed, 56 insertions(+), 22 deletions(-) diff --git a/riva/client/argparse_utils.py b/riva/client/argparse_utils.py index 4e0dea87..8b85aabf 100644 --- a/riva/client/argparse_utils.py +++ b/riva/client/argparse_utils.py @@ -2,10 +2,14 @@ # SPDX-License-Identifier: MIT import argparse +import riva.client def add_asr_config_argparse_parameters( - parser: argparse.ArgumentParser, max_alternatives: bool = False, profanity_filter: bool = False, word_time_offsets: bool = False + parser: argparse.ArgumentParser, + max_alternatives: bool = False, + profanity_filter: bool = False, + word_time_offsets: bool = False, ) -> argparse.ArgumentParser: if word_time_offsets: parser.add_argument( @@ -20,10 +24,22 @@ def add_asr_config_argparse_parameters( ) if profanity_filter: parser.add_argument( - "--profanity-filter", - default=False, - action='store_true', - help="Flag that controls the profanity filtering in the generated transcripts", + "--profanity-filter", + default=False, + action='store_true', + help="Flag that controls the profanity filtering in the generated transcripts", + ) + parser.add_argument( + "--encoding", + type=int, + default=riva.client.AudioEncoding.Value('ENCODING_UNSPECIFIED'), + help=f"Encoding of the audio data. Supported values are {riva.client.AudioEncoding.items()}", + ) + parser.add_argument( + "--sample-rate-hertz", type=int, default=None, help=f"Sample rate in Hz of the audio data", + ) + parser.add_argument( + "--audio-channel-count", type=int, default=None, help=f"Channel count of the audio data", ) parser.add_argument( "--automatic-punctuation", @@ -32,11 +48,12 @@ def add_asr_config_argparse_parameters( help="Flag that controls if transcript should be automatically punctuated", ) parser.add_argument( - "--no-verbatim-transcripts", + "--verbatim-transcripts", default=False, action='store_true', - help="If specified, text inverse normalization will be applied", + help="Flag to disable Inverse Text Normalization (ITN) and return the text exactly as it was said", ) + parser.add_argument("--model-name", default="", help="Name of the model to be used to be used.") parser.add_argument("--language-code", default="en-US", help="Language code of the model to be used.") parser.add_argument("--boosted-lm-words", action='append', help="Words to boost when decoding.") parser.add_argument( diff --git a/scripts/asr/riva_streaming_asr_client.py b/scripts/asr/riva_streaming_asr_client.py index 6c7785ec..85c97e05 100644 --- a/scripts/asr/riva_streaming_asr_client.py +++ b/scripts/asr/riva_streaming_asr_client.py @@ -10,8 +10,8 @@ from typing import Union import riva.client -from riva.client.asr import get_wav_file_parameters from riva.client.argparse_utils import add_asr_config_argparse_parameters, add_connection_argparse_parameters +from riva.client.asr import get_wav_file_parameters def parse_args() -> argparse.Namespace: @@ -38,7 +38,9 @@ def parse_args() -> argparse.Namespace: "--file-streaming-chunk", type=int, default=1600, help="Number of frames in one chunk sent to server." ) parser = add_connection_argparse_parameters(parser) - parser = add_asr_config_argparse_parameters(parser, max_alternatives=True, profanity_filter=True, word_time_offsets=True) + parser = add_asr_config_argparse_parameters( + parser, max_alternatives=True, profanity_filter=True, word_time_offsets=True + ) args = parser.parse_args() if args.max_alternatives < 1: parser.error("`--max-alternatives` must be greater than or equal to 1") @@ -54,12 +56,16 @@ def streaming_transcription_worker( asr_service = riva.client.ASRService(auth) config = riva.client.StreamingRecognitionConfig( config=riva.client.RecognitionConfig( + encoding=args.encoding, + sample_rate_hertz=args.sample_rate_hertz, language_code=args.language_code, max_alternatives=args.max_alternatives, profanity_filter=args.profanity_filter, - enable_automatic_punctuation=args.automatic_punctuation, - verbatim_transcripts=not args.no_verbatim_transcripts, + audio_channel_count=args.audio_channel_count, enable_word_time_offsets=args.word_time_offsets, + enable_automatic_punctuation=args.automatic_punctuation, + model=args.model_name, + verbatim_transcripts=args.verbatim_transcripts, ), interim_results=True, ) @@ -72,8 +78,7 @@ def streaming_transcription_worker( ) as audio_chunk_iterator: riva.client.print_streaming( responses=asr_service.streaming_response_generator( - audio_chunks=audio_chunk_iterator, - streaming_config=config, + audio_chunks=audio_chunk_iterator, streaming_config=config, ), output_file=output_file, additional_info='time', diff --git a/scripts/asr/transcribe_file.py b/scripts/asr/transcribe_file.py index 8c6a6238..2bcbe381 100644 --- a/scripts/asr/transcribe_file.py +++ b/scripts/asr/transcribe_file.py @@ -25,7 +25,7 @@ def parse_args() -> argparse.Namespace: type=int, default=None, help="Output audio device to use for playing audio simultaneously with transcribing. If this parameter is " - "provided, then you do not have to `--play-audio` option." + "provided, then you do not have to `--play-audio` option.", ) parser.add_argument( "--play-audio", @@ -49,7 +49,9 @@ def parse_args() -> argparse.Namespace: "--print-confidence", action="store_true", help="Whether to print stability and confidence of transcript." ) parser = add_connection_argparse_parameters(parser) - parser = add_asr_config_argparse_parameters(parser, max_alternatives=True, profanity_filter=True, word_time_offsets=True) + parser = add_asr_config_argparse_parameters( + parser, max_alternatives=True, profanity_filter=True, word_time_offsets=True + ) args = parser.parse_args() if not args.list_devices and args.input_file is None: parser.error( @@ -70,11 +72,16 @@ def main() -> None: asr_service = riva.client.ASRService(auth) config = riva.client.StreamingRecognitionConfig( config=riva.client.RecognitionConfig( + encoding=args.encoding, + sample_rate_hertz=args.sample_rate_hertz, language_code=args.language_code, - max_alternatives=1, + max_alternatives=args.max_alternatives, profanity_filter=args.profanity_filter, + audio_channel_count=args.audio_channel_count, + enable_word_time_offsets=args.word_time_offsets, enable_automatic_punctuation=args.automatic_punctuation, - verbatim_transcripts=not args.no_verbatim_transcripts, + model=args.model_name, + verbatim_transcripts=args.verbatim_transcripts, ), interim_results=True, ) @@ -94,8 +101,7 @@ def main() -> None: ) as audio_chunk_iterator: riva.client.print_streaming( responses=asr_service.streaming_response_generator( - audio_chunks=audio_chunk_iterator, - streaming_config=config, + audio_chunks=audio_chunk_iterator, streaming_config=config, ), show_intermediate=args.show_intermediate, additional_info="confidence" if args.print_confidence else "no", diff --git a/scripts/asr/transcribe_file_offline.py b/scripts/asr/transcribe_file_offline.py index 5586a39a..65d415f5 100644 --- a/scripts/asr/transcribe_file_offline.py +++ b/scripts/asr/transcribe_file_offline.py @@ -18,7 +18,9 @@ def parse_args() -> argparse.Namespace: ) parser.add_argument("--input-file", required=True, type=Path, help="A path to a local file to transcribe.") parser = add_connection_argparse_parameters(parser) - parser = add_asr_config_argparse_parameters(parser, max_alternatives=True, profanity_filter=True, word_time_offsets=True) + parser = add_asr_config_argparse_parameters( + parser, max_alternatives=True, profanity_filter=True, word_time_offsets=True + ) args = parser.parse_args() args.input_file = args.input_file.expanduser() return args @@ -29,12 +31,16 @@ def main() -> None: auth = riva.client.Auth(args.ssl_cert, args.use_ssl, args.server, args.metadata) asr_service = riva.client.ASRService(auth) config = riva.client.RecognitionConfig( + encoding=args.encoding, + sample_rate_hertz=args.sample_rate_hertz, language_code=args.language_code, max_alternatives=args.max_alternatives, profanity_filter=args.profanity_filter, - enable_automatic_punctuation=args.automatic_punctuation, - verbatim_transcripts=not args.no_verbatim_transcripts, + audio_channel_count=args.audio_channel_count, enable_word_time_offsets=args.word_time_offsets or args.speaker_diarization, + enable_automatic_punctuation=args.automatic_punctuation, + model=args.model_name, + verbatim_transcripts=args.verbatim_transcripts, ) riva.client.add_word_boosting_to_config(config, args.boosted_lm_words, args.boosted_lm_score) riva.client.add_speaker_diarization_to_config(config, args.speaker_diarization) From 89d338a9affc5eb5f74f0ade94496391cd142cd9 Mon Sep 17 00:00:00 2001 From: Viraj Karandikar Date: Wed, 17 Apr 2024 10:54:24 +0530 Subject: [PATCH 2/2] fix printing word timestamps --- scripts/asr/transcribe_file.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/scripts/asr/transcribe_file.py b/scripts/asr/transcribe_file.py index 2bcbe381..61422b66 100644 --- a/scripts/asr/transcribe_file.py +++ b/scripts/asr/transcribe_file.py @@ -50,7 +50,7 @@ def parse_args() -> argparse.Namespace: ) parser = add_connection_argparse_parameters(parser) parser = add_asr_config_argparse_parameters( - parser, max_alternatives=True, profanity_filter=True, word_time_offsets=True + parser, max_alternatives=True, profanity_filter=True, word_time_offsets=True, ) args = parser.parse_args() if not args.list_devices and args.input_file is None: @@ -83,7 +83,7 @@ def main() -> None: model=args.model_name, verbatim_transcripts=args.verbatim_transcripts, ), - interim_results=True, + interim_results=args.show_intermediate, ) riva.client.add_word_boosting_to_config(config, args.boosted_lm_words, args.boosted_lm_score) sound_callback = None @@ -104,7 +104,8 @@ def main() -> None: audio_chunks=audio_chunk_iterator, streaming_config=config, ), show_intermediate=args.show_intermediate, - additional_info="confidence" if args.print_confidence else "no", + additional_info="confidence" if args.print_confidence else "time" if args.word_time_offsets else "no", + word_time_offsets=args.word_time_offsets, ) finally: if sound_callback is not None and sound_callback.opened: