Skip to content

Commit f89906f

Browse files
updates: add options to ASR and TTS clients (#119)
- Add/update list model option to ASR clients - Add encoding option to TTS client
1 parent 340e1e3 commit f89906f

File tree

3 files changed

+38
-5
lines changed

3 files changed

+38
-5
lines changed

scripts/asr/transcribe_file.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -72,12 +72,13 @@ def main() -> None:
7272
asr_models = dict()
7373
config_response = asr_service.stub.GetRivaSpeechRecognitionConfig(riva.client.proto.riva_asr_pb2.RivaSpeechRecognitionConfigRequest())
7474
for model_config in config_response.model_config:
75-
if model_config.parameters["streaming"] and model_config.parameters["type"]:
75+
if model_config.parameters["type"] == "online":
7676
language_code = model_config.parameters['language_code']
77+
model = {"model": [model_config.model_name]}
7778
if language_code in asr_models:
78-
asr_models[language_code]["models"].append(model_config.model_name)
79+
asr_models[language_code].append(model)
7980
else:
80-
asr_models[language_code] = {"models": [model_config.model_name]}
81+
asr_models[language_code] = [model]
8182

8283
print("Available ASR models")
8384
asr_models = dict(sorted(asr_models.items()))

scripts/asr/transcribe_file_offline.py

Lines changed: 30 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
# SPDX-FileCopyrightText: Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
22
# SPDX-License-Identifier: MIT
33

4+
import os
45
import argparse
56
from pathlib import Path
67

@@ -16,18 +17,45 @@ def parse_args() -> argparse.Namespace:
1617
"one response.",
1718
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
1819
)
19-
parser.add_argument("--input-file", required=True, type=Path, help="A path to a local file to transcribe.")
20+
group = parser.add_mutually_exclusive_group(required=True)
21+
group.add_argument("--input-file", type=Path, help="A path to a local file to transcribe.")
22+
group.add_argument("--list-models", action="store_true", help="List available models.")
23+
2024
parser = add_connection_argparse_parameters(parser)
2125
parser = add_asr_config_argparse_parameters(parser, max_alternatives=True, profanity_filter=True, word_time_offsets=True)
2226
args = parser.parse_args()
23-
args.input_file = args.input_file.expanduser()
27+
if args.input_file:
28+
args.input_file = args.input_file.expanduser()
2429
return args
2530

2631

2732
def main() -> None:
2833
args = parse_args()
34+
2935
auth = riva.client.Auth(args.ssl_cert, args.use_ssl, args.server, args.metadata)
3036
asr_service = riva.client.ASRService(auth)
37+
38+
if args.list_models:
39+
asr_models = dict()
40+
config_response = asr_service.stub.GetRivaSpeechRecognitionConfig(riva.client.proto.riva_asr_pb2.RivaSpeechRecognitionConfigRequest())
41+
for model_config in config_response.model_config:
42+
if model_config.parameters["type"] == "offline":
43+
language_code = model_config.parameters['language_code']
44+
model = {"model": [model_config.model_name]}
45+
if language_code in asr_models:
46+
asr_models[language_code].append(model)
47+
else:
48+
asr_models[language_code] = [model]
49+
50+
print("Available ASR models")
51+
asr_models = dict(sorted(asr_models.items()))
52+
print(asr_models)
53+
return
54+
55+
if not os.path.isfile(args.input_file):
56+
print(f"Invalid input file path: {args.input_file}")
57+
return
58+
3159
config = riva.client.RecognitionConfig(
3260
language_code=args.language_code,
3361
max_alternatives=args.max_alternatives,

scripts/tts/talk.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
import riva.client
1111
from riva.client.argparse_utils import add_connection_argparse_parameters
12+
from riva.client.proto.riva_audio_pb2 import AudioEncoding
1213

1314
def read_file_to_dict(file_path):
1415
result_dict = {}
@@ -56,6 +57,7 @@ def parse_args() -> argparse.Namespace:
5657
parser.add_argument(
5758
"--sample-rate-hz", type=int, default=44100, help="Number of audio frames per second in synthesized audio."
5859
)
60+
parser.add_argument("--encoding", default="LINEAR_PCM", choices={"LINEAR_PCM", "OGGOPUS"}, help="Output audio encoding.")
5961
parser.add_argument("--custom-dictionary", type=str, help="A file path to a user dictionary with key-value pairs separated by double spaces.")
6062
parser.add_argument(
6163
"--stream",
@@ -132,6 +134,7 @@ def main() -> None:
132134
if args.stream:
133135
responses = service.synthesize_online(
134136
args.text, args.voice, args.language_code, sample_rate_hz=args.sample_rate_hz,
137+
encoding=AudioEncoding.OGGOPUS if args.encoding == "OGGOPUS" else AudioEncoding.LINEAR_PCM,
135138
audio_prompt_file=args.audio_prompt_file, quality=20 if args.quality is None else args.quality,
136139
custom_dictionary=custom_dictionary_input
137140
)
@@ -148,6 +151,7 @@ def main() -> None:
148151
else:
149152
resp = service.synthesize(
150153
args.text, args.voice, args.language_code, sample_rate_hz=args.sample_rate_hz,
154+
encoding=AudioEncoding.OGGOPUS if args.encoding == "OGGOPUS" else AudioEncoding.LINEAR_PCM,
151155
audio_prompt_file=args.audio_prompt_file, quality=20 if args.quality is None else args.quality,
152156
custom_dictionary=custom_dictionary_input
153157
)

0 commit comments

Comments
 (0)