@@ -27,17 +27,19 @@ def dataset_iterator(dataset):
2727 }
2828
2929
30- def write_audio (buffer ) -> list :
31- if os .path .exists (DATA_CACHE_DIR ):
32- shutil .rmtree (DATA_CACHE_DIR , ignore_errors = True )
30+ def write_audio (buffer , cache_prefix ) -> list :
31+ cache_dir = os .path .join (DATA_CACHE_DIR , cache_prefix )
32+
33+ if os .path .exists (cache_dir ):
34+ shutil .rmtree (cache_dir , ignore_errors = True )
3335
34- os .makedirs (DATA_CACHE_DIR )
36+ os .makedirs (cache_dir )
3537
3638 data_paths = []
3739 for idx , data in enumerate (buffer ):
3840 fn = os .path .basename (data ['audio_filename' ])
3941 fn = os .path .splitext (fn )[0 ]
40- path = os .path .join (DATA_CACHE_DIR , f"{ idx } _{ fn } .wav" )
42+ path = os .path .join (cache_dir , f"{ idx } _{ fn } .wav" )
4143 data_paths .append (path )
4244
4345 soundfile .write (path , data ["array" ], samplerate = data ['sample_rate' ])
@@ -52,14 +54,14 @@ def pack_results(results: list, buffer, transcriptions):
5254 return results
5355
5456
55- def buffer_audio_and_transcribe (model : ASRModel , dataset , batch_size : int , verbose : bool = True ):
57+ def buffer_audio_and_transcribe (model : ASRModel , dataset , batch_size : int , cache_prefix : str , verbose : bool = True ):
5658 buffer = []
5759 results = []
5860 for sample in tqdm (dataset_iterator (dataset ), desc = 'Evaluating: Sample id' , unit = '' , disable = not verbose ):
5961 buffer .append (sample )
6062
6163 if len (buffer ) == batch_size :
62- filepaths = write_audio (buffer )
64+ filepaths = write_audio (buffer , cache_prefix )
6365 transcriptions = model .transcribe (filepaths , batch_size = batch_size , verbose = False )
6466 # if transcriptions form a tuple (from RNNT), extract just "best" hypothesis
6567 if type (transcriptions ) == tuple and len (transcriptions ) == 2 :
@@ -68,7 +70,7 @@ def buffer_audio_and_transcribe(model: ASRModel, dataset, batch_size: int, verbo
6870 buffer .clear ()
6971
7072 if len (buffer ) > 0 :
71- filepaths = write_audio (buffer )
73+ filepaths = write_audio (buffer , cache_prefix )
7274 transcriptions = model .transcribe (filepaths , batch_size = batch_size , verbose = False )
7375 # if transcriptions form a tuple (from RNNT), extract just "best" hypothesis
7476 if type (transcriptions ) == tuple and len (transcriptions ) == 2 :
@@ -105,7 +107,9 @@ def main(args):
105107 references = []
106108
107109 # run streamed inference
108- results = buffer_audio_and_transcribe (asr_model , dataset , args .batch_size , verbose = True )
110+ cache_prefix = (f"{ args .model_id .replace ('/' , '-' )} -{ args .dataset_path .replace ('/' , '' )} -"
111+ f"{ args .dataset .replace ('/' , '-' )} -{ args .split } " )
112+ results = buffer_audio_and_transcribe (asr_model , dataset , args .batch_size , cache_prefix , verbose = True )
109113 for sample in results :
110114 predictions .append (data_utils .normalizer (sample ["pred_text" ]))
111115 references .append (sample ["reference" ])
0 commit comments