# This script evaluates NeMo ASR models on multilingual datasets.

import argparse
import io
import os
import re
import time

import evaluate
import numpy as np
import soundfile
import torch
from datasets import load_dataset
from nemo.collections.asr.models import ASRModel
from normalizer import data_utils
from tqdm import tqdm

17+
wer_metric = evaluate.load("wer")


def normalize_text(text):
    """Simple text normalization for non-English languages."""
    if text is None:
        return ""
    # Lowercase
    text = text.lower()

    # Remove punctuation
    text = re.sub(r'[^\w\s]', '', text)

    # Collapse repeated whitespace
    text = re.sub(r'\s+', ' ', text).strip()
    return text
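
# Illustrative usage of normalize_text:
#   normalize_text("Hello,   World!")  ->  "hello world"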
33+
def main(args):
    DATA_CACHE_DIR = os.path.join(os.getcwd(), "audio_cache")
    CONFIG_NAME = args.config_name
    SPLIT_NAME = args.split

    # Extract the language from config_name if it was not provided explicitly
    if args.language:
        LANGUAGE = args.language
    else:
        # Config names follow the pattern <dataset>_<lang> (e.g., "fleurs_en" -> "en")
        try:
            LANGUAGE = CONFIG_NAME.split('_', 1)[1]
        except IndexError:
            LANGUAGE = "en"  # Default fallback

    print(f"Detected language: {LANGUAGE}")

    CACHE_DIR = os.path.join(DATA_CACHE_DIR, CONFIG_NAME, SPLIT_NAME)
    os.makedirs(CACHE_DIR, exist_ok=True)
54+
    if args.device >= 0:
        device = torch.device(f"cuda:{args.device}")
        compute_dtype = torch.bfloat16
    else:
        device = torch.device("cpu")
        compute_dtype = torch.float32

    # Load the ASR model from a local .nemo checkpoint or a pretrained model name
    if args.model_id.endswith(".nemo"):
        asr_model = ASRModel.restore_from(args.model_id, map_location=device)
    else:
        asr_model = ASRModel.from_pretrained(args.model_id, map_location=device)

    asr_model.to(compute_dtype)
    asr_model.eval()
70+
    # Load the dataset from the HuggingFace dataset repository
    print(f"Loading dataset: {args.dataset} with config: {CONFIG_NAME}")

    dataset = load_dataset(args.dataset, CONFIG_NAME, split=SPLIT_NAME, streaming=args.streaming)

    if args.max_eval_samples is not None and args.max_eval_samples > 0:
        print(f"Subsampling dataset to first {args.max_eval_samples} samples!")
        if args.streaming:
            # Streaming datasets support neither len() nor select(); take() keeps the first N samples
            dataset = dataset.take(args.max_eval_samples)
        else:
            dataset = dataset.select(range(min(args.max_eval_samples, len(dataset))))

    # Use greedy batched decoding unless a beam strategy is already configured
    if asr_model.cfg.decoding.strategy != "beam":
        asr_model.cfg.decoding.strategy = "greedy_batch"
        asr_model.change_decoding_strategy(asr_model.cfg.decoding)
84+
    def download_audio_files(batch):
        """Decode audio samples and cache them as wav files for transcription."""
        audio_paths = []
        durations = []

        for i, (file_name, sample, duration) in enumerate(zip(
            batch["file_name"], batch["audio"], batch["duration"]
        )):
            # Create a unique filename using the index to avoid collisions
            unique_id = f"{CONFIG_NAME}_{i}_{os.path.basename(file_name).replace('.wav', '')}"
            audio_path = os.path.join(CACHE_DIR, f"{unique_id}.wav")

            # Samples may arrive as a decoded array or as raw encoded bytes
            if "array" in sample:
                audio_array = np.float32(sample["array"])
                sample_rate = sample.get("sampling_rate", 16000)
            elif "bytes" in sample:
                with io.BytesIO(sample["bytes"]) as audio_file:
                    audio_array, sample_rate = soundfile.read(audio_file, dtype="float32")
            else:
                raise ValueError("Sample must have either 'array' or 'bytes' key")

            if not os.path.exists(audio_path):
                os.makedirs(os.path.dirname(audio_path), exist_ok=True)
                soundfile.write(audio_path, audio_array, sample_rate)

            audio_paths.append(audio_path)
            # Use the duration from the dataset if available, otherwise compute it
            if duration is not None:
                durations.append(duration)
            else:
                durations.append(len(audio_array) / sample_rate)

        batch["references"] = list(batch["text"])
        batch["audio_filepaths"] = audio_paths
        batch["durations"] = durations

        return batch
122+
    # Decode and cache the audio files
    print("Processing audio files...")
    dataset = dataset.map(
        download_audio_files,
        batch_size=args.batch_size,
        batched=True,
        remove_columns=["audio"],
    )

    # Collect all samples into flat lists
    all_data = {
        "audio_filepaths": [],
        "durations": [],
        "references": [],
    }

    print("Collecting data...")
    for data in tqdm(dataset, desc="Collecting samples"):
        all_data["audio_filepaths"].append(data["audio_filepaths"])
        all_data["durations"].append(data["durations"])
        all_data["references"].append(data["references"])

    # Sort longest-first so each batch contains similarly sized samples
    print("Sorting by duration...")
    sorted_indices = sorted(range(len(all_data["durations"])), key=lambda k: all_data["durations"][k], reverse=True)
    all_data["audio_filepaths"] = [all_data["audio_filepaths"][i] for i in sorted_indices]
    all_data["references"] = [all_data["references"][i] for i in sorted_indices]
    all_data["durations"] = [all_data["durations"][i] for i in sorted_indices]
151+
    # Run two rounds: round 0 is a short warmup, round 1 is the timed full pass
    total_time = 0
    for warmup_round in range(2):
        if warmup_round == 0:
            audio_files = all_data["audio_filepaths"][:args.batch_size * 4]  # warmup with 4 batches
            print("Running warmup...")
        else:
            audio_files = all_data["audio_filepaths"]
            print("Running full evaluation...")

        start_time = time.time()
        # Autocast is only enabled on GPU; on CPU the model already runs in float32
        with torch.autocast(device_type=device.type, dtype=compute_dtype, enabled=device.type == "cuda"), torch.inference_mode():
            # canary-1b and canary-1b-flash expect pnc='nopnc' for English and
            # pnc='pnc' for other languages; canary-1b-v2 (and all other models)
            # keep punctuation and capitalization enabled for every language.
            if 'canary' in args.model_id and 'v2' not in args.model_id:
                pnc = 'nopnc' if LANGUAGE == "en" else 'pnc'
            else:
                pnc = 'pnc'

            if 'canary' in args.model_id:
                transcriptions = asr_model.transcribe(audio_files, batch_size=args.batch_size, verbose=False, pnc=pnc, num_workers=1, source_lang=LANGUAGE, target_lang=LANGUAGE)
            else:
                transcriptions = asr_model.transcribe(audio_files, batch_size=args.batch_size, verbose=False, num_workers=1)
        end_time = time.time()

        if warmup_round == 1:
            total_time = end_time - start_time

    # Some models return a (best_hypotheses, all_hypotheses) tuple; keep the best ones
    if isinstance(transcriptions, tuple) and len(transcriptions) == 2:
        transcriptions = transcriptions[0]

    references = [normalize_text(ref) for ref in all_data["references"]]
    predictions = [normalize_text(pred.text) for pred in transcriptions]
186+
    avg_time = total_time / len(all_data["audio_filepaths"])

    # Write results using data_utils.write_manifest
    manifest_path = data_utils.write_manifest(
        references,
        predictions,
        args.model_id,
        args.dataset,  # dataset_path, used for the manifest filename
        CONFIG_NAME,  # dataset_name
        SPLIT_NAME,
        audio_length=all_data["durations"],
        transcription_time=[avg_time] * len(all_data["audio_filepaths"]),
    )

    print("Results saved at path:", os.path.abspath(manifest_path))

    # Compute WER (%) and the inverse real-time factor (RTFx = audio duration / processing time)
    wer = wer_metric.compute(references=references, predictions=predictions)
    wer = round(100 * wer, 2)

    audio_length = sum(all_data["durations"])
    rtfx = round(audio_length / total_time, 2)

    print(f"Dataset: {args.dataset}")
    print(f"Language: {LANGUAGE}")
    print(f"Config: {CONFIG_NAME}")
    print(f"Model: {args.model_id}")
    print(f"RTFX: {rtfx}")
    print(f"WER: {wer}%")
217+
218+
if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--model_id", type=str, required=True, help="Model identifier. Should be loadable with NVIDIA NeMo.",
    )
    parser.add_argument(
        "--dataset",
        type=str,
        default="nithinraok/asr-leaderboard-datasets",
        help="Dataset name. Default is 'nithinraok/asr-leaderboard-datasets'.",
    )
    parser.add_argument(
        "--config_name",
        type=str,
        required=True,
        help="Config name in the format <dataset>_<lang> (e.g., fleurs_en, mcv_de, mls_es).",
    )
    parser.add_argument(
        "--language",
        type=str,
        default=None,
        help="Language code (e.g., en, de, es). If not provided, it is extracted from config_name.",
    )
    parser.add_argument(
        "--split",
        type=str,
        default="test",
        help="Split of the dataset. Default is 'test'.",
    )
    parser.add_argument(
        "--device",
        type=int,
        default=-1,
        help="The device to run the pipeline on. -1 for CPU (default), 0 for the first GPU, and so on.",
    )
    parser.add_argument(
        "--batch_size", type=int, default=32, help="Number of samples to go through each streamed batch.",
    )
    parser.add_argument(
        "--max_eval_samples",
        type=int,
        default=None,
        help="Number of samples to evaluate. Use a small number (e.g. 64) when testing this script.",
    )
    parser.add_argument(
        "--no-streaming",
        dest="streaming",
        action="store_false",
        help="Download the entire dataset instead of streaming it during evaluation.",
    )
    # Defaults must be set before parse_args() for them to take effect
    parser.set_defaults(streaming=True)
    args = parser.parse_args()

    main(args)
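
# Example invocation (illustrative; the script filename and model id are placeholders,
# any model loadable with NVIDIA NeMo works):
#   python evaluate_nemo_asr.py --model_id nvidia/canary-1b \
#       --config_name fleurs_de --split test --device 0 --batch_size 32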