# Evaluate NeMo ASR models on multilingual datasets.
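#
# Example invocation (the script filename and model name here are illustrative):
#   python eval_nemo_asr.py --model_id nvidia/canary-1b --config_name fleurs_en \
#       --split test --device 0 --batch_size 32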

import argparse
import io
import os
import time

import evaluate
import numpy as np
import soundfile
import torch
from datasets import load_dataset
from tqdm import tqdm

from nemo.collections.asr.models import ASRModel
from normalizer import data_utils


wer_metric = evaluate.load("wer")

def main(args):
    DATA_CACHE_DIR = os.path.join(os.getcwd(), "audio_cache")
    CONFIG_NAME = args.config_name
    SPLIT_NAME = args.split

    # Extract language from config_name if not provided
    if args.language:
        LANGUAGE = args.language
    else:
        # Extract language from config_name (e.g., "fleurs_en" -> "en")
        try:
            LANGUAGE = CONFIG_NAME.split('_', 1)[1]
        except IndexError:
            LANGUAGE = "en"  # Default fallback

    print(f"Detected language: {LANGUAGE}")

    CACHE_DIR = os.path.join(DATA_CACHE_DIR, CONFIG_NAME, SPLIT_NAME)
    if not os.path.exists(CACHE_DIR):
        os.makedirs(CACHE_DIR)

    if args.device >= 0:
        device = torch.device(f"cuda:{args.device}")
        compute_dtype = torch.bfloat16
    else:
        device = torch.device("cpu")
        compute_dtype = torch.float32
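    # bfloat16 halves activation memory relative to float32 on supported GPUs;
    # CPU runs keep float32, where low-precision kernels are less widely available.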

    # Load ASR model
    if args.model_id.endswith(".nemo"):
        asr_model = ASRModel.restore_from(args.model_id, map_location=device)
    else:
        asr_model = ASRModel.from_pretrained(args.model_id, map_location=device)

    asr_model.to(compute_dtype)
    asr_model.eval()

    # Load dataset using the HuggingFace dataset repository
    print(f"Loading dataset: {args.dataset} with config: {CONFIG_NAME}")

    dataset = load_dataset(args.dataset, CONFIG_NAME, split=SPLIT_NAME, streaming=args.streaming)

    if args.max_eval_samples is not None and args.max_eval_samples > 0:
        print(f"Subsampling dataset to first {args.max_eval_samples} samples!")
        # Streaming datasets are iterable-only (no len() or select()), so use take()
        if args.streaming:
            dataset = dataset.take(args.max_eval_samples)
        else:
            dataset = dataset.select(range(min(args.max_eval_samples, len(dataset))))

    # Configure decoding strategy
    if asr_model.cfg.decoding.strategy != "beam":
        asr_model.cfg.decoding.strategy = "greedy_batch"
        asr_model.change_decoding_strategy(asr_model.cfg.decoding)
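    # greedy_batch decodes a whole batch in lockstep and is much faster than
    # per-sample greedy; an explicitly configured beam strategy is left untouched.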

    def download_audio_files(batch):
        """Process audio files and prepare them for evaluation."""
        audio_paths = []
        durations = []

        for i, (file_name, sample, duration, text) in enumerate(
            zip(batch["file_name"], batch["audio"], batch["duration"], batch["text"])
        ):
            # Create unique filename using index to avoid conflicts
            unique_id = f"{CONFIG_NAME}_{i}_{os.path.basename(file_name).replace('.wav', '')}"
            audio_path = os.path.join(CACHE_DIR, f"{unique_id}.wav")

            if "array" in sample:
                audio_array = np.float32(sample["array"])
                sample_rate = sample.get("sampling_rate", 16000)
            elif "bytes" in sample:
                with io.BytesIO(sample["bytes"]) as audio_file:
                    audio_array, sample_rate = soundfile.read(audio_file, dtype="float32")
            else:
                raise ValueError("Sample must have either 'array' or 'bytes' key")

            if not os.path.exists(audio_path):
                os.makedirs(os.path.dirname(audio_path), exist_ok=True)
                soundfile.write(audio_path, audio_array, sample_rate)

            audio_paths.append(audio_path)
            # Use duration from dataset if available, otherwise calculate it
            if duration is not None:
                durations.append(duration)
            else:
                durations.append(len(audio_array) / sample_rate)

        batch["references"] = list(batch["text"])
        batch["audio_filepaths"] = audio_paths
        batch["durations"] = durations

        return batch

    # Process the dataset
    print("Processing audio files...")
    dataset = dataset.map(
        download_audio_files,
        batch_size=args.batch_size,
        batched=True,
        remove_columns=["audio"],
    )
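    # The map above materializes each clip as a wav file under CACHE_DIR;
    # transcribe() below is fed these filepaths rather than in-memory arrays.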

    # Collect all data
    all_data = {
        "audio_filepaths": [],
        "durations": [],
        "references": [],
    }

    print("Collecting data...")
    for data in tqdm(dataset, desc="Collecting samples"):
        all_data["audio_filepaths"].append(data["audio_filepaths"])
        all_data["durations"].append(data["durations"])
        all_data["references"].append(data["references"])

    # Sort by duration for efficient batch processing
    print("Sorting by duration...")
    sorted_indices = sorted(range(len(all_data["durations"])), key=lambda k: all_data["durations"][k], reverse=True)
    all_data["audio_filepaths"] = [all_data["audio_filepaths"][i] for i in sorted_indices]
    all_data["references"] = [all_data["references"][i] for i in sorted_indices]
    all_data["durations"] = [all_data["durations"][i] for i in sorted_indices]
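    # Longest-first ordering groups similar-length clips into the same batch,
    # minimizing padding waste, and surfaces any out-of-memory failure immediately.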

    # Run evaluation with warmup
    total_time = 0
    for warmup_round in range(2):  # warmup once and calculate rtf
        if warmup_round == 0:
            audio_files = all_data["audio_filepaths"][:args.batch_size * 4]  # warmup with 4 batches
            print("Running warmup...")
        else:
            audio_files = all_data["audio_filepaths"]
            print("Running full evaluation...")

        start_time = time.time()
        with torch.inference_mode():
            # canary-1b and canary-1b-flash need pnc='nopnc' for English and
            # pnc='pnc' for other languages; canary-1b-v2 uses pnc='pnc' for all languages
            if 'canary' in args.model_id and 'v2' not in args.model_id:
                pnc = 'nopnc' if LANGUAGE == "en" else 'pnc'
            else:
                pnc = 'pnc'

            if 'canary' in args.model_id:
                transcriptions = asr_model.transcribe(
                    audio_files, batch_size=args.batch_size, verbose=False, pnc=pnc,
                    num_workers=1, source_lang=LANGUAGE, target_lang=LANGUAGE,
                )
            else:
                transcriptions = asr_model.transcribe(audio_files, batch_size=args.batch_size, verbose=False, num_workers=1)
        end_time = time.time()

        if warmup_round == 1:
            total_time = end_time - start_time
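            # Only the second (full) pass is timed, keeping one-time CUDA and
            # cache warmup costs out of the RTFX measurement.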

    # Process transcriptions
    if isinstance(transcriptions, tuple) and len(transcriptions) == 2:
        transcriptions = transcriptions[0]

    references = all_data["references"]
    if LANGUAGE == "en":  # English is handled by the English normalizer
        references = [data_utils.normalizer(ref) for ref in references]
        predictions = [data_utils.normalizer(pred.text) for pred in transcriptions]
    else:
        references = [data_utils.ml_normalizer(ref) for ref in references]
        predictions = [data_utils.ml_normalizer(pred.text) for pred in transcriptions]

    avg_time = total_time / len(all_data["audio_filepaths"])

    # Write results using data_utils.write_manifest
    manifest_path = data_utils.write_manifest(
        references,
        predictions,
        args.model_id,
        args.dataset,  # dataset_path for filename
        CONFIG_NAME,  # dataset_name
        SPLIT_NAME,
        audio_length=all_data["durations"],
        transcription_time=[avg_time] * len(all_data["audio_filepaths"]),
    )

    print("Results saved at path:", os.path.abspath(manifest_path))

    # Calculate metrics
    wer = wer_metric.compute(references=references, predictions=predictions)
    wer = round(100 * wer, 2)

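    # RTFX = seconds of audio transcribed per second of wall-clock time
    # (e.g. RTFX 100 means 100 s of audio per compute second); higher is better.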
    audio_length = sum(all_data["durations"])
    rtfx = audio_length / total_time
    rtfx = round(rtfx, 2)

    print(f"Dataset: {args.dataset}")
    print(f"Language: {LANGUAGE}")
    print(f"Config: {CONFIG_NAME}")
    print(f"Model: {args.model_id}")
    print(f"RTFX: {rtfx}")
    print(f"WER: {wer}%")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--model_id", type=str, required=True, help="Model identifier. Should be loadable with NVIDIA NeMo.",
    )
    parser.add_argument(
        "--dataset",
        type=str,
        default="nithinraok/asr-leaderboard-datasets",
        help="Dataset name. Default is 'nithinraok/asr-leaderboard-datasets'.",
    )
    parser.add_argument(
        "--config_name",
        type=str,
        required=True,
        help="Config name in the format <dataset>_<lang> (e.g., fleurs_en, mcv_de, mls_es).",
    )
    parser.add_argument(
        "--language",
        type=str,
        default=None,
        help="Language code (e.g., en, de, es). If not provided, it is extracted from config_name.",
    )
    parser.add_argument(
        "--split",
        type=str,
        default="test",
        help="Split of the dataset. Default is 'test'.",
    )
    parser.add_argument(
        "--device",
        type=int,
        default=-1,
        help="The device to run the pipeline on. -1 for CPU (default), 0 for the first GPU, and so on.",
    )
    parser.add_argument(
        "--batch_size", type=int, default=32, help="Number of samples per transcription batch.",
    )
    parser.add_argument(
        "--max_eval_samples",
        type=int,
        default=None,
        help="Number of samples to evaluate. Use a small number, e.g. 64, to test this script.",
    )

    parser.add_argument(
        "--no-streaming",
        dest="streaming",
        action="store_false",
        help="Download the entire dataset instead of streaming it during evaluation.",
    )
    parser.set_defaults(streaming=True)
    args = parser.parse_args()

    main(args)