"""Run evaluation for ctranslate2 whisper models."""
import argparse
import os
import time
from collections.abc import Iterator

import evaluate
from faster_whisper import WhisperModel
1112wer_metric = evaluate .load ("wer" )
1213
1314
def dataset_iterator(dataset) -> Iterator[dict]:
    """
    Iterate over the dataset, yielding each sample's audio data merged with
    its reference text.

    Args:
        dataset: iterable of items, each a mapping with an "audio" mapping
            (e.g. containing "array" and "sampling_rate") and a "norm_text"
            string.

    Yields:
        dict: all keys of ``item["audio"]`` (spread into the top level),
        plus ``"reference"`` mapped to ``item["norm_text"]``.
    """
    # NOTE: the audio dict is flattened into the yielded dict — there is no
    # nested "audio" key in the output.
    for item in dataset:
        yield {**item["audio"], "reference": item["norm_text"]}
2815def main (args ) -> None :
2916 """Main function to run evaluation on a dataset."""
3017 asr_model = WhisperModel (
@@ -34,38 +21,69 @@ def main(args) -> None:
3421 device_index = args .device
3522 )
3623
37- dataset = data_utils .load_data (args )
24+ def benchmark (batch ):
25+ start_time = time .time ()
26+ segments , _ = asr_model .transcribe (batch ["audio" ]["array" ], language = "en" )
27+ outputs = [segment ._asdict () for segment in segments ]
28+ batch ["transcription_time_s" ] = time .time () - start_time
29+ batch ["predictions" ] = data_utils .normalizer ("" .join ([segment ["text" ] for segment in outputs ])).strip ()
30+ batch ["references" ] = batch ["norm_text" ]
31+ return batch
3832
39- if args .max_eval_samples is not None and args . max_eval_samples > 0 :
40- print ( f"Subsampling dataset to first { args . max_eval_samples } samples !" )
41- dataset = dataset . take ( args . max_eval_samples )
33+ if args .warmup_steps is not None :
34+ dataset = data_utils . load_data ( args )
35+ dataset = data_utils . prepare_data ( dataset )
4236
43- dataset = data_utils .prepare_data (dataset )
37+ if args .streaming :
38+ warmup_dataset = dataset .take (args .warmup_steps )
39+ else :
40+ warmup_dataset = dataset .select (range (min (args .warmup_steps , len (dataset ))))
41+ warmup_dataset = iter (warmup_dataset .map (benchmark , remove_columns = ["audio" ]))
4442
45- predictions = []
46- references = []
43+ for _ in tqdm ( warmup_dataset , desc = "Warming up..." ):
44+ continue
4745
48- # Run inference
49- for batch in tqdm (dataset_iterator (dataset ), desc = f"Evaluating { args .model_id } " ):
50- segments , _ = asr_model .transcribe (batch ["array" ], language = "en" )
51- outputs = [segment ._asdict () for segment in segments ]
52- transcription = data_utils .normalizer (
53- "" .join ([segment ["text" ] for segment in outputs ])
54- ).strip ()
46+ dataset = data_utils .load_data (args )
47+ if args .max_eval_samples is not None and args .max_eval_samples > 0 :
48+ print (f"Subsampling dataset to first { args .max_eval_samples } samples!" )
49+ if args .streaming :
50+ dataset = dataset .take (args .max_eval_samples )
51+ else :
52+ dataset = dataset .select (range (min (args .max_eval_samples , len (dataset ))))
53+ dataset = data_utils .prepare_data (dataset )
54+
55+ dataset = dataset .map (benchmark , remove_columns = ["audio" ])
5556
56- predictions .append (transcription )
57- references .append (batch ["reference" ])
57+ all_results = {
58+ "audio_length_s" : [],
59+ "transcription_time_s" : [],
60+ "predictions" : [],
61+ "references" : [],
62+ }
63+ result_iter = iter (dataset )
64+ for result in tqdm (result_iter , desc = "Samples..." ):
65+ for key in all_results :
66+ all_results [key ].append (result [key ])
5867
59- # Write manifest results
68+ # Write manifest results (WER and RTFX)
6069 manifest_path = data_utils .write_manifest (
61- references , predictions , args .model_id , args .dataset_path , args .dataset , args .split
70+ all_results ["references" ],
71+ all_results ["predictions" ],
72+ args .model_id ,
73+ args .dataset_path ,
74+ args .dataset ,
75+ args .split ,
76+ audio_length = all_results ["audio_length_s" ],
77+ transcription_time = all_results ["transcription_time_s" ],
6278 )
6379 print ("Results saved at path:" , os .path .abspath (manifest_path ))
6480
65- wer = wer_metric .compute (references = references , predictions = predictions )
81+ wer = wer_metric .compute (
82+ references = all_results ["references" ], predictions = all_results ["predictions" ]
83+ )
6684 wer = round (100 * wer , 2 )
67-
68- print ("WER:" , wer , "%" )
85+ rtfx = round ( sum ( all_results [ "audio_length_s" ]) / sum ( all_results [ "transcription_time_s" ]), 2 )
86+ print ("WER:" , wer , "%" , "RTFx:" , rtfx )
6987
7088
7189if __name__ == "__main__" :
@@ -75,7 +93,7 @@ def main(args) -> None:
7593 "--model_id" ,
7694 type = str ,
7795 required = True ,
78- help = "Model identifier. Should be loadable with 🤗 Transformers " ,
96+ help = "Model identifier. Should be loadable with faster-whisper " ,
7997 )
8098 parser .add_argument (
8199 '--dataset_path' , type = str , default = 'esb/datasets' , help = 'Dataset path. By default, it is `esb/datasets`'
@@ -99,12 +117,6 @@ def main(args) -> None:
99117 default = - 1 ,
100118 help = "The device to run the pipeline on. -1 for CPU (default), 0 for the first GPU and so on." ,
101119 )
102- parser .add_argument (
103- "--batch_size" ,
104- type = int ,
105- default = 16 ,
106- help = "Number of samples to go through each streamed batch." ,
107- )
108120 parser .add_argument (
109121 "--max_eval_samples" ,
110122 type = int ,
@@ -117,6 +129,12 @@ def main(args) -> None:
117129 action = "store_false" ,
118130 help = "Choose whether you'd like to download the entire dataset or stream it during the evaluation." ,
119131 )
132+ parser .add_argument (
133+ "--warmup_steps" ,
134+ type = int ,
135+ default = 5 ,
136+ help = "Number of warm-up steps to run before launching the timed runs." ,
137+ )
120138 args = parser .parse_args ()
121139 parser .set_defaults (streaming = False )
122140
0 commit comments