
Authors
* Adel Moumen 2023 <[email protected]>
* Sanchit Gandhi 2024 <[email protected]>
"""
import argparse
import os
import time

import evaluate
import torch
from tqdm import tqdm

import speechbrain.inference.ASR as ASR
from normalizer import data_utils
from speechbrain.utils.data_utils import batch_pad_right
1717
1818def get_model (
1919 speechbrain_repository : str ,
@@ -61,7 +61,7 @@ def get_model(
6161 }
6262
6363 try :
64- model_class = getattr (pretrained , speechbrain_pretrained_class_name )
64+ model_class = getattr (ASR , speechbrain_pretrained_class_name )
6565 except AttributeError :
6666 raise AttributeError (
6767 f"SpeechBrain Pretrained class: { speechbrain_pretrained_class_name } not found in pretrained.py"
@@ -70,137 +70,100 @@ def get_model(
7070 return model_class .from_hparams (** kwargs )
7171
7272
def main(args):
    """Run the evaluation script.

    Loads the SpeechBrain pretrained model onto the requested device,
    optionally runs warm-up steps, transcribes the evaluation dataset in
    batches while timing each forward pass, writes a results manifest,
    and prints WER and RTFx.

    Arguments
    ---------
    args : argparse.Namespace
        Parsed command-line arguments (source, device, batch_size,
        streaming, warmup_steps, max_eval_samples, dataset fields, ...).
    """
    # Map the integer device flag to a torch device string; -1 means CPU.
    if args.device == -1:
        device = "cpu"
    else:
        device = f"cuda:{args.device}"

    model = get_model(
        args.source, args.speechbrain_pretrained_class_name, device=device
    )

    def benchmark(batch):
        """Transcribe one batch of audio, recording per-sample transcription time.

        Intended for `datasets.Dataset.map(..., batched=True)`: adds the
        columns `transcription_time_s`, `predictions` and `references`.
        """
        # Load audio inputs as torch tensors.
        audios = [torch.from_numpy(sample["array"]) for sample in batch["audio"]]
        minibatch_size = len(audios)

        # START TIMING — includes padding and host-to-device transfer,
        # so the measured time covers the full inference path.
        start_time = time.time()

        # Pad to a rectangular batch and move everything to the target device.
        audios, audio_lens = batch_pad_right(audios)
        audios = audios.to(device)
        audio_lens = audio_lens.to(device)
        predictions, _ = model.transcribe_batch(audios, audio_lens)

        # END TIMING
        runtime = time.time() - start_time

        # Normalize by minibatch size since we want the per-sample time.
        batch["transcription_time_s"] = minibatch_size * [runtime / minibatch_size]

        # Normalize transcriptions with the English text normalizer so WER
        # is computed on comparable text.
        batch["predictions"] = [data_utils.normalizer(pred) for pred in predictions]
        batch["references"] = batch["norm_text"]
        return batch

    if args.warmup_steps is not None:
        # Run a few untimed-in-aggregate batches first so CUDA kernels,
        # cuDNN autotuning, etc. are warm before the measured pass.
        dataset = data_utils.load_data(args)
        dataset = data_utils.prepare_data(dataset)

        num_warmup_samples = args.warmup_steps * args.batch_size
        # Streaming datasets only support take(); map-style support select().
        if args.streaming:
            warmup_dataset = dataset.take(num_warmup_samples)
        else:
            warmup_dataset = dataset.select(
                range(min(num_warmup_samples, len(dataset)))
            )
        warmup_dataset = iter(
            warmup_dataset.map(benchmark, batch_size=args.batch_size, batched=True)
        )

        for _ in tqdm(warmup_dataset, desc="Warming up..."):
            continue

    # Reload the dataset fresh for the timed evaluation run.
    dataset = data_utils.load_data(args)
    if args.max_eval_samples is not None and args.max_eval_samples > 0:
        print(f"Subsampling dataset to first {args.max_eval_samples} samples!")
        if args.streaming:
            dataset = dataset.take(args.max_eval_samples)
        else:
            dataset = dataset.select(range(min(args.max_eval_samples, len(dataset))))
    dataset = data_utils.prepare_data(dataset)

    # Drop the raw audio column after transcription to keep iteration light.
    dataset = dataset.map(
        benchmark, batch_size=args.batch_size, batched=True, remove_columns=["audio"],
    )

    # NOTE(review): "audio_length_s" is not set by benchmark() — it is
    # presumably added by data_utils.prepare_data; confirm against that helper.
    all_results = {
        "audio_length_s": [],
        "transcription_time_s": [],
        "predictions": [],
        "references": [],
    }
    result_iter = iter(dataset)
    for result in tqdm(result_iter, desc="Samples..."):
        for key in all_results:
            all_results[key].append(result[key])

    # Write manifest results (WER and RTFX).
    # NOTE(review): the model was loaded from args.source but the manifest
    # uses args.model_id — confirm both flags are defined by the parser.
    manifest_path = data_utils.write_manifest(
        all_results["references"],
        all_results["predictions"],
        args.model_id,
        args.dataset_path,
        args.dataset,
        args.split,
        audio_length=all_results["audio_length_s"],
        transcription_time=all_results["transcription_time_s"],
    )
    print("Results saved at path:", os.path.abspath(manifest_path))

    wer_metric = evaluate.load("wer")
    wer = wer_metric.compute(
        references=all_results["references"], predictions=all_results["predictions"]
    )
    wer = round(100 * wer, 2)
    # RTFx = total audio duration / total transcription time (higher is faster).
    rtfx = round(
        sum(all_results["audio_length_s"]) / sum(all_results["transcription_time_s"]),
        2,
    )
    print("WER:", wer, "%", "RTFx:", rtfx)
205168
206169if __name__ == "__main__" :
@@ -263,6 +226,12 @@ def main(args):
263226 action = "store_false" ,
264227 help = "Choose whether you'd like to download the entire dataset or stream it during the evaluation." ,
265228 )
229+ parser .add_argument (
230+ "--warmup_steps" ,
231+ type = int ,
232+ default = 5 ,
233+ help = "Number of warm-up steps to run before launching the timed runs." ,
234+ )
266235 args = parser .parse_args ()
267236 parser .set_defaults (streaming = True )
268237
0 commit comments