@@ -54,15 +54,19 @@ def pack_results(results: list, buffer, transcriptions):
5454 return results
5555
5656
57- def buffer_audio_and_transcribe (model : ASRModel , dataset , batch_size : int , cache_prefix : str , verbose : bool = True ):
57+ def buffer_audio_and_transcribe (model : ASRModel , dataset , batch_size : int , pnc : bool , cache_prefix : str , verbose : bool = True ):
5858 buffer = []
5959 results = []
6060 for sample in tqdm (dataset_iterator (dataset ), desc = 'Evaluating: Sample id' , unit = '' , disable = not verbose ):
6161 buffer .append (sample )
6262
6363 if len (buffer ) == batch_size :
6464 filepaths = write_audio (buffer , cache_prefix )
65- transcriptions = model .transcribe (filepaths , batch_size = batch_size , verbose = False )
65+
66+ if pnc is not None :
67+ transcriptions = model .transcribe (filepaths , batch_size = batch_size , pnc = False , verbose = False )
68+ else :
69+ transcriptions = model .transcribe (filepaths , batch_size = batch_size , verbose = False )
6670 # if transcriptions form a tuple (from RNNT), extract just "best" hypothesis
6771 if type (transcriptions ) == tuple and len (transcriptions ) == 2 :
6872 transcriptions = transcriptions [0 ]
@@ -71,7 +75,10 @@ def buffer_audio_and_transcribe(model: ASRModel, dataset, batch_size: int, cache
7175
7276 if len (buffer ) > 0 :
7377 filepaths = write_audio (buffer , cache_prefix )
74- transcriptions = model .transcribe (filepaths , batch_size = batch_size , verbose = False )
78+ if pnc is not None :
79+ transcriptions = model .transcribe (filepaths , batch_size = batch_size , pnc = False , verbose = False )
80+ else :
81+ transcriptions = model .transcribe (filepaths , batch_size = batch_size , verbose = False )
7582 # if transcriptions form a tuple (from RNNT), extract just "best" hypothesis
7683 if type (transcriptions ) == tuple and len (transcriptions ) == 2 :
7784 transcriptions = transcriptions [0 ]
@@ -112,7 +119,7 @@ def main(args):
112119 # run streamed inference
113120 cache_prefix = (f"{ args .model_id .replace ('/' , '-' )} -{ args .dataset_path .replace ('/' , '' )} -"
114121 f"{ args .dataset .replace ('/' , '-' )} -{ args .split } " )
115- results = buffer_audio_and_transcribe (asr_model , dataset , args .batch_size , cache_prefix , verbose = True )
122+ results = buffer_audio_and_transcribe (asr_model , dataset , args .batch_size , args . pnc , cache_prefix , verbose = True )
116123 for sample in results :
117124 predictions .append (data_utils .normalizer (sample ["pred_text" ]))
118125 references .append (sample ["reference" ])
@@ -166,6 +173,12 @@ def main(args):
166173 default = None ,
167174 help = "Number of samples to be evaluated. Put a lower number e.g. 64 for testing this script." ,
168175 )
176+ parser.add_argument(
177+ "--pnc",
178+ type=lambda value: value.strip().lower() in ("1", "true", "yes"),  # type=bool would turn any non-empty string, even "False", into True
179+ default=None,  # None means the flag was not given; callers test `is not None`
180+ help="flag to indicate inference in pnc mode for models that support punctuation and capitalization",
181+ )
169182 parser .add_argument (
170183 "--no-streaming" ,
171184 dest = 'streaming' ,
0 commit comments