diff --git a/riva/client/argparse_utils.py b/riva/client/argparse_utils.py index 4e0dea87..8b85aabf 100644 --- a/riva/client/argparse_utils.py +++ b/riva/client/argparse_utils.py @@ -2,10 +2,14 @@ # SPDX-License-Identifier: MIT import argparse +import riva.client def add_asr_config_argparse_parameters( - parser: argparse.ArgumentParser, max_alternatives: bool = False, profanity_filter: bool = False, word_time_offsets: bool = False + parser: argparse.ArgumentParser, + max_alternatives: bool = False, + profanity_filter: bool = False, + word_time_offsets: bool = False, ) -> argparse.ArgumentParser: if word_time_offsets: parser.add_argument( @@ -20,10 +24,22 @@ def add_asr_config_argparse_parameters( ) if profanity_filter: parser.add_argument( - "--profanity-filter", - default=False, - action='store_true', - help="Flag that controls the profanity filtering in the generated transcripts", + "--profanity-filter", + default=False, + action='store_true', + help="Flag that controls the profanity filtering in the generated transcripts", + ) + parser.add_argument( + "--encoding", + type=int, + default=riva.client.AudioEncoding.Value('ENCODING_UNSPECIFIED'), + help=f"Encoding of the audio data. Supported values are {riva.client.AudioEncoding.items()}", + ) + parser.add_argument( + "--sample-rate-hertz", type=int, default=None, help=f"Sample rate in Hz of the audio data", + ) + parser.add_argument( + "--audio-channel-count", type=int, default=None, help=f"Channel count of the audio data", ) parser.add_argument( "--automatic-punctuation", @@ -32,11 +48,12 @@ def add_asr_config_argparse_parameters( help="Flag that controls if transcript should be automatically punctuated", ) parser.add_argument( - "--no-verbatim-transcripts", + "--verbatim-transcripts", default=False, action='store_true', - help="If specified, text inverse normalization will be applied", + help="Flag to disable Inverse Text Normalization (ITN) and return the text exactly as it was said", ) + parser.add_argument("--model-name", default="", help="Name of the model to be used to be used.") parser.add_argument("--language-code", default="en-US", help="Language code of the model to be used.") parser.add_argument("--boosted-lm-words", action='append', help="Words to boost when decoding.") parser.add_argument( diff --git a/scripts/asr/riva_streaming_asr_client.py b/scripts/asr/riva_streaming_asr_client.py index 6c7785ec..85c97e05 100644 --- a/scripts/asr/riva_streaming_asr_client.py +++ b/scripts/asr/riva_streaming_asr_client.py @@ -10,8 +10,8 @@ from typing import Union import riva.client -from riva.client.asr import get_wav_file_parameters from riva.client.argparse_utils import add_asr_config_argparse_parameters, add_connection_argparse_parameters +from riva.client.asr import get_wav_file_parameters def parse_args() -> argparse.Namespace: @@ -38,7 +38,9 @@ def parse_args() -> argparse.Namespace: "--file-streaming-chunk", type=int, default=1600, help="Number of frames in one chunk sent to server." ) parser = add_connection_argparse_parameters(parser) - parser = add_asr_config_argparse_parameters(parser, max_alternatives=True, profanity_filter=True, word_time_offsets=True) + parser = add_asr_config_argparse_parameters( + parser, max_alternatives=True, profanity_filter=True, word_time_offsets=True + ) args = parser.parse_args() if args.max_alternatives < 1: parser.error("`--max-alternatives` must be greater than or equal to 1") @@ -54,12 +56,16 @@ def streaming_transcription_worker( asr_service = riva.client.ASRService(auth) config = riva.client.StreamingRecognitionConfig( config=riva.client.RecognitionConfig( + encoding=args.encoding, + sample_rate_hertz=args.sample_rate_hertz, language_code=args.language_code, max_alternatives=args.max_alternatives, profanity_filter=args.profanity_filter, - enable_automatic_punctuation=args.automatic_punctuation, - verbatim_transcripts=not args.no_verbatim_transcripts, + audio_channel_count=args.audio_channel_count, enable_word_time_offsets=args.word_time_offsets, + enable_automatic_punctuation=args.automatic_punctuation, + model=args.model_name, + verbatim_transcripts=args.verbatim_transcripts, ), interim_results=True, ) @@ -72,8 +78,7 @@ def streaming_transcription_worker( ) as audio_chunk_iterator: riva.client.print_streaming( responses=asr_service.streaming_response_generator( - audio_chunks=audio_chunk_iterator, - streaming_config=config, + audio_chunks=audio_chunk_iterator, streaming_config=config, ), output_file=output_file, additional_info='time', diff --git a/scripts/asr/transcribe_file.py b/scripts/asr/transcribe_file.py index 8c6a6238..61422b66 100644 --- a/scripts/asr/transcribe_file.py +++ b/scripts/asr/transcribe_file.py @@ -25,7 +25,7 @@ def parse_args() -> argparse.Namespace: type=int, default=None, help="Output audio device to use for playing audio simultaneously with transcribing. If this parameter is " - "provided, then you do not have to `--play-audio` option." + "provided, then you do not have to `--play-audio` option.", ) parser.add_argument( "--play-audio", @@ -49,7 +49,9 @@ def parse_args() -> argparse.Namespace: "--print-confidence", action="store_true", help="Whether to print stability and confidence of transcript." ) parser = add_connection_argparse_parameters(parser) - parser = add_asr_config_argparse_parameters(parser, max_alternatives=True, profanity_filter=True, word_time_offsets=True) + parser = add_asr_config_argparse_parameters( + parser, max_alternatives=True, profanity_filter=True, word_time_offsets=True, + ) args = parser.parse_args() if not args.list_devices and args.input_file is None: parser.error( @@ -70,13 +72,18 @@ def main() -> None: asr_service = riva.client.ASRService(auth) config = riva.client.StreamingRecognitionConfig( config=riva.client.RecognitionConfig( + encoding=args.encoding, + sample_rate_hertz=args.sample_rate_hertz, language_code=args.language_code, - max_alternatives=1, + max_alternatives=args.max_alternatives, profanity_filter=args.profanity_filter, + audio_channel_count=args.audio_channel_count, + enable_word_time_offsets=args.word_time_offsets, enable_automatic_punctuation=args.automatic_punctuation, - verbatim_transcripts=not args.no_verbatim_transcripts, + model=args.model_name, + verbatim_transcripts=args.verbatim_transcripts, ), - interim_results=True, + interim_results=args.show_intermediate, ) riva.client.add_word_boosting_to_config(config, args.boosted_lm_words, args.boosted_lm_score) sound_callback = None @@ -94,11 +101,11 @@ def main() -> None: ) as audio_chunk_iterator: riva.client.print_streaming( responses=asr_service.streaming_response_generator( - audio_chunks=audio_chunk_iterator, - streaming_config=config, + audio_chunks=audio_chunk_iterator, streaming_config=config, ), show_intermediate=args.show_intermediate, - additional_info="confidence" if args.print_confidence else "no", + additional_info="confidence" if args.print_confidence else "time" if args.word_time_offsets else "no", + word_time_offsets=args.word_time_offsets, ) finally: if sound_callback is not None and sound_callback.opened: diff --git a/scripts/asr/transcribe_file_offline.py b/scripts/asr/transcribe_file_offline.py index 5586a39a..65d415f5 100644 --- a/scripts/asr/transcribe_file_offline.py +++ b/scripts/asr/transcribe_file_offline.py @@ -18,7 +18,9 @@ def parse_args() -> argparse.Namespace: ) parser.add_argument("--input-file", required=True, type=Path, help="A path to a local file to transcribe.") parser = add_connection_argparse_parameters(parser) - parser = add_asr_config_argparse_parameters(parser, max_alternatives=True, profanity_filter=True, word_time_offsets=True) + parser = add_asr_config_argparse_parameters( + parser, max_alternatives=True, profanity_filter=True, word_time_offsets=True + ) args = parser.parse_args() args.input_file = args.input_file.expanduser() return args @@ -29,12 +31,16 @@ def main() -> None: auth = riva.client.Auth(args.ssl_cert, args.use_ssl, args.server, args.metadata) asr_service = riva.client.ASRService(auth) config = riva.client.RecognitionConfig( + encoding=args.encoding, + sample_rate_hertz=args.sample_rate_hertz, language_code=args.language_code, max_alternatives=args.max_alternatives, profanity_filter=args.profanity_filter, - enable_automatic_punctuation=args.automatic_punctuation, - verbatim_transcripts=not args.no_verbatim_transcripts, + audio_channel_count=args.audio_channel_count, enable_word_time_offsets=args.word_time_offsets or args.speaker_diarization, + enable_automatic_punctuation=args.automatic_punctuation, + model=args.model_name, + verbatim_transcripts=args.verbatim_transcripts, ) riva.client.add_word_boosting_to_config(config, args.boosted_lm_words, args.boosted_lm_score) riva.client.add_speaker_diarization_to_config(config, args.speaker_diarization)