Skip to content

asr: add all config params #73

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 24 additions & 7 deletions riva/client/argparse_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,14 @@
# SPDX-License-Identifier: MIT

import argparse
import riva.client


def add_asr_config_argparse_parameters(
parser: argparse.ArgumentParser, max_alternatives: bool = False, profanity_filter: bool = False, word_time_offsets: bool = False
parser: argparse.ArgumentParser,
max_alternatives: bool = False,
profanity_filter: bool = False,
word_time_offsets: bool = False,
) -> argparse.ArgumentParser:
if word_time_offsets:
parser.add_argument(
Expand All @@ -20,10 +24,22 @@ def add_asr_config_argparse_parameters(
)
if profanity_filter:
parser.add_argument(
"--profanity-filter",
default=False,
action='store_true',
help="Flag that controls the profanity filtering in the generated transcripts",
"--profanity-filter",
default=False,
action='store_true',
help="Flag that controls the profanity filtering in the generated transcripts",
)
parser.add_argument(
"--encoding",
type=int,
default=riva.client.AudioEncoding.Value('ENCODING_UNSPECIFIED'),
help=f"Encoding of the audio data. Supported values are {riva.client.AudioEncoding.items()}",
)
parser.add_argument(
"--sample-rate-hertz", type=int, default=None, help=f"Sample rate in Hz of the audio data",
)
parser.add_argument(
"--audio-channel-count", type=int, default=None, help=f"Channel count of the audio data",
)
parser.add_argument(
"--automatic-punctuation",
Expand All @@ -32,11 +48,12 @@ def add_asr_config_argparse_parameters(
help="Flag that controls if transcript should be automatically punctuated",
)
parser.add_argument(
"--no-verbatim-transcripts",
"--verbatim-transcripts",
default=False,
action='store_true',
help="If specified, text inverse normalization will be applied",
help="Flag to disable Inverse Text Normalization (ITN) and return the text exactly as it was said",
)
parser.add_argument("--model-name", default="", help="Name of the model to be used to be used.")
parser.add_argument("--language-code", default="en-US", help="Language code of the model to be used.")
parser.add_argument("--boosted-lm-words", action='append', help="Words to boost when decoding.")
parser.add_argument(
Expand Down
17 changes: 11 additions & 6 deletions scripts/asr/riva_streaming_asr_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@
from typing import Union

import riva.client
from riva.client.asr import get_wav_file_parameters
from riva.client.argparse_utils import add_asr_config_argparse_parameters, add_connection_argparse_parameters
from riva.client.asr import get_wav_file_parameters


def parse_args() -> argparse.Namespace:
Expand All @@ -38,7 +38,9 @@ def parse_args() -> argparse.Namespace:
"--file-streaming-chunk", type=int, default=1600, help="Number of frames in one chunk sent to server."
)
parser = add_connection_argparse_parameters(parser)
parser = add_asr_config_argparse_parameters(parser, max_alternatives=True, profanity_filter=True, word_time_offsets=True)
parser = add_asr_config_argparse_parameters(
parser, max_alternatives=True, profanity_filter=True, word_time_offsets=True
)
args = parser.parse_args()
if args.max_alternatives < 1:
parser.error("`--max-alternatives` must be greater than or equal to 1")
Expand All @@ -54,12 +56,16 @@ def streaming_transcription_worker(
asr_service = riva.client.ASRService(auth)
config = riva.client.StreamingRecognitionConfig(
config=riva.client.RecognitionConfig(
encoding=args.encoding,
sample_rate_hertz=args.sample_rate_hertz,
language_code=args.language_code,
max_alternatives=args.max_alternatives,
profanity_filter=args.profanity_filter,
enable_automatic_punctuation=args.automatic_punctuation,
verbatim_transcripts=not args.no_verbatim_transcripts,
audio_channel_count=args.audio_channel_count,
enable_word_time_offsets=args.word_time_offsets,
enable_automatic_punctuation=args.automatic_punctuation,
model=args.model_name,
verbatim_transcripts=args.verbatim_transcripts,
),
interim_results=True,
)
Expand All @@ -72,8 +78,7 @@ def streaming_transcription_worker(
) as audio_chunk_iterator:
riva.client.print_streaming(
responses=asr_service.streaming_response_generator(
audio_chunks=audio_chunk_iterator,
streaming_config=config,
audio_chunks=audio_chunk_iterator, streaming_config=config,
),
output_file=output_file,
additional_info='time',
Expand Down
23 changes: 15 additions & 8 deletions scripts/asr/transcribe_file.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def parse_args() -> argparse.Namespace:
type=int,
default=None,
help="Output audio device to use for playing audio simultaneously with transcribing. If this parameter is "
"provided, then you do not have to `--play-audio` option."
"provided, then you do not have to `--play-audio` option.",
)
parser.add_argument(
"--play-audio",
Expand All @@ -49,7 +49,9 @@ def parse_args() -> argparse.Namespace:
"--print-confidence", action="store_true", help="Whether to print stability and confidence of transcript."
)
parser = add_connection_argparse_parameters(parser)
parser = add_asr_config_argparse_parameters(parser, max_alternatives=True, profanity_filter=True, word_time_offsets=True)
parser = add_asr_config_argparse_parameters(
parser, max_alternatives=True, profanity_filter=True, word_time_offsets=True,
)
args = parser.parse_args()
if not args.list_devices and args.input_file is None:
parser.error(
Expand All @@ -70,13 +72,18 @@ def main() -> None:
asr_service = riva.client.ASRService(auth)
config = riva.client.StreamingRecognitionConfig(
config=riva.client.RecognitionConfig(
encoding=args.encoding,
sample_rate_hertz=args.sample_rate_hertz,
language_code=args.language_code,
max_alternatives=1,
max_alternatives=args.max_alternatives,
profanity_filter=args.profanity_filter,
audio_channel_count=args.audio_channel_count,
enable_word_time_offsets=args.word_time_offsets,
enable_automatic_punctuation=args.automatic_punctuation,
verbatim_transcripts=not args.no_verbatim_transcripts,
model=args.model_name,
verbatim_transcripts=args.verbatim_transcripts,
),
interim_results=True,
interim_results=args.show_intermediate,
)
riva.client.add_word_boosting_to_config(config, args.boosted_lm_words, args.boosted_lm_score)
sound_callback = None
Expand All @@ -94,11 +101,11 @@ def main() -> None:
) as audio_chunk_iterator:
riva.client.print_streaming(
responses=asr_service.streaming_response_generator(
audio_chunks=audio_chunk_iterator,
streaming_config=config,
audio_chunks=audio_chunk_iterator, streaming_config=config,
),
show_intermediate=args.show_intermediate,
additional_info="confidence" if args.print_confidence else "no",
additional_info="confidence" if args.print_confidence else "time" if args.word_time_offsets else "no",
word_time_offsets=args.word_time_offsets,
)
finally:
if sound_callback is not None and sound_callback.opened:
Expand Down
12 changes: 9 additions & 3 deletions scripts/asr/transcribe_file_offline.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,9 @@ def parse_args() -> argparse.Namespace:
)
parser.add_argument("--input-file", required=True, type=Path, help="A path to a local file to transcribe.")
parser = add_connection_argparse_parameters(parser)
parser = add_asr_config_argparse_parameters(parser, max_alternatives=True, profanity_filter=True, word_time_offsets=True)
parser = add_asr_config_argparse_parameters(
parser, max_alternatives=True, profanity_filter=True, word_time_offsets=True
)
args = parser.parse_args()
args.input_file = args.input_file.expanduser()
return args
Expand All @@ -29,12 +31,16 @@ def main() -> None:
auth = riva.client.Auth(args.ssl_cert, args.use_ssl, args.server, args.metadata)
asr_service = riva.client.ASRService(auth)
config = riva.client.RecognitionConfig(
encoding=args.encoding,
sample_rate_hertz=args.sample_rate_hertz,
language_code=args.language_code,
max_alternatives=args.max_alternatives,
profanity_filter=args.profanity_filter,
enable_automatic_punctuation=args.automatic_punctuation,
verbatim_transcripts=not args.no_verbatim_transcripts,
audio_channel_count=args.audio_channel_count,
enable_word_time_offsets=args.word_time_offsets or args.speaker_diarization,
enable_automatic_punctuation=args.automatic_punctuation,
model=args.model_name,
verbatim_transcripts=args.verbatim_transcripts,
)
riva.client.add_word_boosting_to_config(config, args.boosted_lm_words, args.boosted_lm_score)
riva.client.add_speaker_diarization_to_config(config, args.speaker_diarization)
Expand Down