11# SPDX-FileCopyrightText: Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
22# SPDX-License-Identifier: MIT
33
4+ import os
45import argparse
56from pathlib import Path
67
@@ -16,18 +17,45 @@ def parse_args() -> argparse.Namespace:
1617 "one response." ,
1718 formatter_class = argparse .ArgumentDefaultsHelpFormatter ,
1819 )
19- parser .add_argument ("--input-file" , required = True , type = Path , help = "A path to a local file to transcribe." )
20+ group = parser .add_mutually_exclusive_group (required = True )
21+ group .add_argument ("--input-file" , type = Path , help = "A path to a local file to transcribe." )
22+ group .add_argument ("--list-models" , action = "store_true" , help = "List available models." )
23+
2024 parser = add_connection_argparse_parameters (parser )
2125 parser = add_asr_config_argparse_parameters (parser , max_alternatives = True , profanity_filter = True , word_time_offsets = True )
2226 args = parser .parse_args ()
23- args .input_file = args .input_file .expanduser ()
27+ if args .input_file :
28+ args .input_file = args .input_file .expanduser ()
2429 return args
2530
2631
2732def main () -> None :
2833 args = parse_args ()
34+
2935 auth = riva .client .Auth (args .ssl_cert , args .use_ssl , args .server , args .metadata )
3036 asr_service = riva .client .ASRService (auth )
37+
38+ if args .list_models :
39+ asr_models = dict ()
40+ config_response = asr_service .stub .GetRivaSpeechRecognitionConfig (riva .client .proto .riva_asr_pb2 .RivaSpeechRecognitionConfigRequest ())
41+ for model_config in config_response .model_config :
42+ if model_config .parameters ["type" ] == "offline" :
43+ language_code = model_config .parameters ['language_code' ]
44+ model = {"model" : [model_config .model_name ]}
45+ if language_code in asr_models :
46+ asr_models [language_code ].append (model )
47+ else :
48+ asr_models [language_code ] = [model ]
49+
50+ print ("Available ASR models" )
51+ asr_models = dict (sorted (asr_models .items ()))
52+ print (asr_models )
53+ return
54+
55+ if not os .path .isfile (args .input_file ):
56+ print (f"Invalid input file path: { args .input_file } " )
57+ return
58+
3159 config = riva .client .RecognitionConfig (
3260 language_code = args .language_code ,
3361 max_alternatives = args .max_alternatives ,
0 commit comments