1
1
# SPDX-FileCopyrightText: Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
2
# SPDX-License-Identifier: MIT
3
3
4
+ import os
4
5
import argparse
5
6
from pathlib import Path
6
7
@@ -16,18 +17,45 @@ def parse_args() -> argparse.Namespace:
16
17
"one response." ,
17
18
formatter_class = argparse .ArgumentDefaultsHelpFormatter ,
18
19
)
19
- parser .add_argument ("--input-file" , required = True , type = Path , help = "A path to a local file to transcribe." )
20
+ group = parser .add_mutually_exclusive_group (required = True )
21
+ group .add_argument ("--input-file" , type = Path , help = "A path to a local file to transcribe." )
22
+ group .add_argument ("--list-models" , action = "store_true" , help = "List available models." )
23
+
20
24
parser = add_connection_argparse_parameters (parser )
21
25
parser = add_asr_config_argparse_parameters (parser , max_alternatives = True , profanity_filter = True , word_time_offsets = True )
22
26
args = parser .parse_args ()
23
- args .input_file = args .input_file .expanduser ()
27
+ if args .input_file :
28
+ args .input_file = args .input_file .expanduser ()
24
29
return args
25
30
26
31
27
32
def main () -> None :
28
33
args = parse_args ()
34
+
29
35
auth = riva .client .Auth (args .ssl_cert , args .use_ssl , args .server , args .metadata )
30
36
asr_service = riva .client .ASRService (auth )
37
+
38
+ if args .list_models :
39
+ asr_models = dict ()
40
+ config_response = asr_service .stub .GetRivaSpeechRecognitionConfig (riva .client .proto .riva_asr_pb2 .RivaSpeechRecognitionConfigRequest ())
41
+ for model_config in config_response .model_config :
42
+ if model_config .parameters ["type" ] == "offline" :
43
+ language_code = model_config .parameters ['language_code' ]
44
+ model = {"model" : [model_config .model_name ]}
45
+ if language_code in asr_models :
46
+ asr_models [language_code ].append (model )
47
+ else :
48
+ asr_models [language_code ] = [model ]
49
+
50
+ print ("Available ASR models" )
51
+ asr_models = dict (sorted (asr_models .items ()))
52
+ print (asr_models )
53
+ return
54
+
55
+ if not os .path .isfile (args .input_file ):
56
+ print (f"Invalid input file path: { args .input_file } " )
57
+ return
58
+
31
59
config = riva .client .RecognitionConfig (
32
60
language_code = args .language_code ,
33
61
max_alternatives = args .max_alternatives ,
0 commit comments