Skip to content

Commit 157576c

Browse files
asr: add model name parameter (#85)
* asr: add model name parameter
* minor fixes to asr and tts clients
* asr: add file path check
1 parent 330aa60 commit 157576c

File tree

5 files changed: +40 -14 lines changed

riva/client/argparse_utils.py

Lines changed: 3 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -38,9 +38,10 @@ def add_asr_config_argparse_parameters(
3838
help="If specified, text inverse normalization will be applied",
3939
)
4040
parser.add_argument("--language-code", default="en-US", help="Language code of the model to be used.")
41-
parser.add_argument("--boosted-lm-words", action='append', help="Words to boost when decoding.")
41+
parser.add_argument("--model-name", default="", help="Model name to be used.")
42+
parser.add_argument("--boosted-lm-words", action='append', help="Words to boost when decoding. Can be used multiple times to boost multiple words.")
4243
parser.add_argument(
43-
"--boosted-lm-score", type=float, default=4.0, help="Value by which to boost words when decoding."
44+
"--boosted-lm-score", type=float, default=4.0, help="Recommended range for the boost score is 20 to 100. The higher the boost score, the more biased the ASR engine is towards this word."
4445
)
4546
parser.add_argument(
4647
"--speaker-diarization",

scripts/asr/riva_streaming_asr_client.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -55,6 +55,7 @@ def streaming_transcription_worker(
5555
config = riva.client.StreamingRecognitionConfig(
5656
config=riva.client.RecognitionConfig(
5757
language_code=args.language_code,
58+
model=args.model_name,
5859
max_alternatives=args.max_alternatives,
5960
profanity_filter=args.profanity_filter,
6061
enable_automatic_punctuation=args.automatic_punctuation,

scripts/asr/transcribe_file.py

Lines changed: 28 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -3,6 +3,7 @@
33

44
import argparse
55

6+
import os
67
import riva.client
78
from riva.client.argparse_utils import add_asr_config_argparse_parameters, add_connection_argparse_parameters
89

@@ -15,8 +16,11 @@ def parse_args() -> argparse.Namespace:
1516
"`--play-audio` or `--output-device`.",
1617
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
1718
)
18-
parser.add_argument("--input-file", help="A path to a local file to stream.")
19-
parser.add_argument("--list-devices", action="store_true", help="List output devices indices")
19+
group = parser.add_mutually_exclusive_group(required=True)
20+
group.add_argument("--input-file", help="A path to a local file to stream.")
21+
group.add_argument("--list-models", action="store_true", help="List available models.")
22+
group.add_argument("--list-devices", action="store_true", help="List output devices indices")
23+
2024
parser.add_argument(
2125
"--show-intermediate", action="store_true", help="Show intermediate transcripts as they are available."
2226
)
@@ -51,11 +55,6 @@ def parse_args() -> argparse.Namespace:
5155
parser = add_connection_argparse_parameters(parser)
5256
parser = add_asr_config_argparse_parameters(parser, max_alternatives=True, profanity_filter=True, word_time_offsets=True)
5357
args = parser.parse_args()
54-
if not args.list_devices and args.input_file is None:
55-
parser.error(
56-
"You have to provide at least one of parameters `--input-file` and `--list-devices` whereas both "
57-
"parameters are missing."
58-
)
5958
if args.play_audio or args.output_device is not None or args.list_devices:
6059
import riva.client.audio_io
6160
return args
@@ -68,9 +67,31 @@ def main() -> None:
6867
return
6968
auth = riva.client.Auth(args.ssl_cert, args.use_ssl, args.server, args.metadata)
7069
asr_service = riva.client.ASRService(auth)
70+
71+
if args.list_models:
72+
asr_models = dict()
73+
config_response = asr_service.stub.GetRivaSpeechRecognitionConfig(riva.client.proto.riva_asr_pb2.RivaSpeechRecognitionConfigRequest())
74+
for model_config in config_response.model_config:
75+
if model_config.parameters["streaming"] and model_config.parameters["type"]:
76+
language_code = model_config.parameters['language_code']
77+
if language_code in asr_models:
78+
asr_models[language_code]["models"].append(model_config.model_name)
79+
else:
80+
asr_models[language_code] = {"models": [model_config.model_name]}
81+
82+
print("Available ASR models")
83+
asr_models = dict(sorted(asr_models.items()))
84+
print(asr_models)
85+
return
86+
87+
if not os.path.isfile(args.input_file):
88+
print(f"Invalid input file path: {args.input_file}")
89+
return
90+
7191
config = riva.client.StreamingRecognitionConfig(
7292
config=riva.client.RecognitionConfig(
7393
language_code=args.language_code,
94+
model=args.model_name,
7495
max_alternatives=1,
7596
profanity_filter=args.profanity_filter,
7697
enable_automatic_punctuation=args.automatic_punctuation,

scripts/asr/transcribe_mic.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -47,6 +47,7 @@ def main() -> None:
4747
config=riva.client.RecognitionConfig(
4848
encoding=riva.client.AudioEncoding.LINEAR_PCM,
4949
language_code=args.language_code,
50+
model=args.model_name,
5051
max_alternatives=1,
5152
profanity_filter=args.profanity_filter,
5253
enable_automatic_punctuation=args.automatic_punctuation,

scripts/tts/talk.py

Lines changed: 7 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -13,16 +13,18 @@
1313

1414
def parse_args() -> argparse.Namespace:
1515
parser = argparse.ArgumentParser(
16-
description="A speech synthesis via Riva AI Services. You HAVE TO provide at least one of arguments "
17-
"`--output`, `--play-audio`, `--list-devices`, `--output-device`.",
16+
description="Speech synthesis via Riva AI Services",
1817
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
1918
)
19+
group = parser.add_mutually_exclusive_group(required=True)
20+
group.add_argument("--text", type=str, help="Text input to synthesize.")
21+
group.add_argument("--list-devices", action="store_true", help="List output audio devices indices.")
22+
group.add_argument("--list-voices", action="store_true", help="List available voices.")
2023
parser.add_argument(
2124
"--voice",
2225
help="A voice name to use. If this parameter is missing, then the server will try a first available model "
2326
"based on parameter `--language-code`.",
2427
)
25-
parser.add_argument("--text", type=str, required=False, help="Text input to synthesize.")
2628
parser.add_argument(
2729
"--audio_prompt_file",
2830
type=Path,
@@ -35,8 +37,6 @@ def parse_args() -> argparse.Namespace:
3537
help="Whether to play input audio simultaneously with transcribing. If `--output-device` is not provided, "
3638
"then the default output audio device will be used.",
3739
)
38-
parser.add_argument("--list-devices", action="store_true", help="List output audio devices indices.")
39-
parser.add_argument("--list-voices", action="store_true", help="List available voices.")
4040
parser.add_argument("--output-device", type=int, help="Output device to use.")
4141
parser.add_argument("--language-code", default='en-US', help="A language of input text.")
4242
parser.add_argument(
@@ -62,6 +62,7 @@ def main() -> None:
6262
args = parse_args()
6363
if args.list_devices:
6464
riva.client.audio_io.list_output_devices()
65+
return
6566

6667
auth = riva.client.Auth(args.ssl_cert, args.use_ssl, args.server, args.metadata)
6768
service = riva.client.SpeechSynthesisService(auth)
@@ -87,6 +88,7 @@ def main() -> None:
8788

8889
tts_models = dict(sorted(tts_models.items()))
8990
print(json.dumps(tts_models, indent=4))
91+
return
9092

9193
if not args.text:
9294
print("No input text provided")

0 commit comments

Comments
 (0)