
Commit b5f3da7

add canary-v2
Signed-off-by: nithinraok <[email protected]>
1 parent 64a4f34 commit b5f3da7

File tree

4 files changed: +375 -96 lines changed


nemo_asr/run_eval_ml.py

Lines changed: 274 additions & 0 deletions
@@ -0,0 +1,274 @@
# This script evaluates NeMo ASR models on multilingual datasets.

import argparse
import io
import os
import re
import time

import evaluate
import numpy as np
import soundfile
import torch
from datasets import load_dataset
from tqdm import tqdm

from nemo.collections.asr.models import ASRModel
from normalizer import data_utils

wer_metric = evaluate.load("wer")


def normalize_text(text):
    """Simple text normalization for non-English languages."""
    if text is None:
        return ""
    # Lowercase
    text = text.lower()

    # Remove punctuation
    text = re.sub(r'[^\w\s]', '', text)

    # Collapse repeated whitespace
    text = re.sub(r'\s+', ' ', text).strip()
    return text
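
# Illustrative example (not part of the evaluation flow):
#   normalize_text("Hallo, Welt!") -> "hallo welt"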


def main(args):
    DATA_CACHE_DIR = os.path.join(os.getcwd(), "audio_cache")
    CONFIG_NAME = args.config_name
    SPLIT_NAME = args.split

    # Use the provided language, or extract it from config_name
    if args.language:
        LANGUAGE = args.language
    else:
        # Extract language from config_name (e.g., "fleurs_en" -> "en")
        try:
            LANGUAGE = CONFIG_NAME.split('_', 1)[1]
        except IndexError:
            LANGUAGE = "en"  # Default fallback

    print(f"Detected language: {LANGUAGE}")

    CACHE_DIR = os.path.join(DATA_CACHE_DIR, CONFIG_NAME, SPLIT_NAME)
    os.makedirs(CACHE_DIR, exist_ok=True)

    if args.device >= 0:
        device = torch.device(f"cuda:{args.device}")
        compute_dtype = torch.bfloat16
    else:
        device = torch.device("cpu")
        compute_dtype = torch.float32

    # Load the ASR model from a local .nemo checkpoint or a pretrained model name
    if args.model_id.endswith(".nemo"):
        asr_model = ASRModel.restore_from(args.model_id, map_location=device)
    else:
        asr_model = ASRModel.from_pretrained(args.model_id, map_location=device)

    asr_model.to(compute_dtype)
    asr_model.eval()

    # Load the dataset from the HuggingFace dataset repository
    print(f"Loading dataset: {args.dataset} with config: {CONFIG_NAME}")

    dataset = load_dataset(args.dataset, CONFIG_NAME, split=SPLIT_NAME, streaming=args.streaming)

    if args.max_eval_samples is not None and args.max_eval_samples > 0:
        print(f"Subsampling dataset to first {args.max_eval_samples} samples!")
        if args.streaming:
            # Streaming datasets have no len() or select(); take the first N samples instead
            dataset = dataset.take(args.max_eval_samples)
        else:
            dataset = dataset.select(range(min(args.max_eval_samples, len(dataset))))

    # Configure decoding strategy
    if asr_model.cfg.decoding.strategy != "beam":
        asr_model.cfg.decoding.strategy = "greedy_batch"
        asr_model.change_decoding_strategy(asr_model.cfg.decoding)
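
    # Greedy batched decoding (configured above) is used because it is much
    # faster than beam search, which matters for the timed RTFx run below.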

    def download_audio_files(batch):
        """Write audio samples to the local cache and prepare them for evaluation."""
        audio_paths = []
        durations = []

        for i, (file_name, sample, duration, text) in enumerate(zip(
            batch["file_name"], batch["audio"], batch["duration"], batch["text"]
        )):
            # Create a unique filename using the index to avoid conflicts
            unique_id = f"{CONFIG_NAME}_{i}_{os.path.basename(file_name).replace('.wav', '')}"
            audio_path = os.path.join(CACHE_DIR, f"{unique_id}.wav")

            if "array" in sample:
                audio_array = np.float32(sample["array"])
                sample_rate = sample.get("sampling_rate", 16000)
            elif "bytes" in sample:
                with io.BytesIO(sample["bytes"]) as audio_file:
                    audio_array, sample_rate = soundfile.read(audio_file, dtype="float32")
            else:
                raise ValueError("Sample must have either 'array' or 'bytes' key")

            if not os.path.exists(audio_path):
                os.makedirs(os.path.dirname(audio_path), exist_ok=True)
                soundfile.write(audio_path, audio_array, sample_rate)

            audio_paths.append(audio_path)
            # Use the duration from the dataset if available, otherwise compute it
            if duration is not None:
                durations.append(duration)
            else:
                durations.append(len(audio_array) / sample_rate)

        batch["references"] = list(batch["text"])
        batch["audio_filepaths"] = audio_paths
        batch["durations"] = durations

        return batch

    # Process the dataset
    print("Processing audio files...")
    dataset = dataset.map(
        download_audio_files,
        batch_size=args.batch_size,
        batched=True,
        remove_columns=["audio"],
    )

    # Collect all data
    all_data = {
        "audio_filepaths": [],
        "durations": [],
        "references": [],
    }

    print("Collecting data...")
    for data in tqdm(dataset, desc="Collecting samples"):
        all_data["audio_filepaths"].append(data["audio_filepaths"])
        all_data["durations"].append(data["durations"])
        all_data["references"].append(data["references"])

    # Sort by duration (longest first) for efficient batching
    print("Sorting by duration...")
    sorted_indices = sorted(range(len(all_data["durations"])), key=lambda k: all_data["durations"][k], reverse=True)
    all_data["audio_filepaths"] = [all_data["audio_filepaths"][i] for i in sorted_indices]
    all_data["references"] = [all_data["references"][i] for i in sorted_indices]
    all_data["durations"] = [all_data["durations"][i] for i in sorted_indices]

    # Run evaluation with warmup
    total_time = 0
    for warmup_round in range(2):  # round 0 warms up the model, round 1 is timed
        if warmup_round == 0:
            audio_files = all_data["audio_filepaths"][:args.batch_size * 4]  # warmup with 4 batches
            print("Running warmup...")
        else:
            audio_files = all_data["audio_filepaths"]
            print("Running full evaluation...")

        start_time = time.time()
        # autocast is enabled only on CUDA; it is a no-op on the CPU path
        with torch.autocast(device_type=device.type, dtype=compute_dtype, enabled=device.type == "cuda"), torch.inference_mode():
            # canary-1b and canary-1b-flash expect pnc='nopnc' for English and
            # pnc='pnc' for other languages; canary-1b-v2 uses pnc='pnc' for all languages
            if 'canary' in args.model_id and 'v2' not in args.model_id:
                pnc = 'nopnc' if LANGUAGE == "en" else 'pnc'
            else:
                pnc = 'pnc'

            if 'canary' in args.model_id:
                transcriptions = asr_model.transcribe(audio_files, batch_size=args.batch_size, verbose=False, pnc=pnc, num_workers=1, source_lang=LANGUAGE, target_lang=LANGUAGE)
            else:
                transcriptions = asr_model.transcribe(audio_files, batch_size=args.batch_size, verbose=False, num_workers=1)
        end_time = time.time()

        if warmup_round == 1:
            total_time = end_time - start_time

    # Process transcriptions
    if isinstance(transcriptions, tuple) and len(transcriptions) == 2:
        transcriptions = transcriptions[0]

    references = [normalize_text(ref) for ref in all_data["references"]]
    predictions = [normalize_text(pred.text) for pred in transcriptions]

    avg_time = total_time / len(all_data["audio_filepaths"])

    # Write results using data_utils.write_manifest
    manifest_path = data_utils.write_manifest(
        references,
        predictions,
        args.model_id,
        args.dataset,  # dataset_path for the filename
        CONFIG_NAME,  # dataset_name
        SPLIT_NAME,
        audio_length=all_data["durations"],
        transcription_time=[avg_time] * len(all_data["audio_filepaths"]),
    )

    print("Results saved at path:", os.path.abspath(manifest_path))

    # Calculate metrics
    wer = wer_metric.compute(references=references, predictions=predictions)
    wer = round(100 * wer, 2)

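    # RTFx (inverse real-time factor): total audio duration divided by
    # wall-clock transcription time; values above 1 mean faster than real time.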
    audio_length = sum(all_data["durations"])
    rtfx = round(audio_length / total_time, 2)

    print(f"Dataset: {args.dataset}")
    print(f"Language: {LANGUAGE}")
    print(f"Config: {CONFIG_NAME}")
    print(f"Model: {args.model_id}")
    print(f"RTFX: {rtfx}")
    print(f"WER: {wer}%")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--model_id", type=str, required=True, help="Model identifier. Should be loadable with NVIDIA NeMo.",
    )
    parser.add_argument(
        "--dataset",
        type=str,
        default="nithinraok/asr-leaderboard-datasets",
        help="Dataset name. Default is 'nithinraok/asr-leaderboard-datasets'.",
    )
    parser.add_argument(
        "--config_name",
        type=str,
        required=True,
        help="Config name in the format <dataset>_<lang> (e.g., fleurs_en, mcv_de, mls_es).",
    )
    parser.add_argument(
        "--language",
        type=str,
        default=None,
        help="Language code (e.g., en, de, es). If not provided, it is extracted from config_name.",
    )
    parser.add_argument(
        "--split",
        type=str,
        default="test",
        help="Split of the dataset. Default is 'test'.",
    )
    parser.add_argument(
        "--device",
        type=int,
        default=-1,
        help="The device to run the pipeline on. -1 for CPU (default), 0 for the first GPU, and so on.",
    )
    parser.add_argument(
        "--batch_size", type=int, default=32, help="Number of samples to go through in each batch.",
    )
    parser.add_argument(
        "--max_eval_samples",
        type=int,
        default=None,
        help="Number of samples to be evaluated. Use a low number, e.g. 64, when testing this script.",
    )
    parser.add_argument(
        "--no-streaming",
        dest="streaming",
        action="store_false",
        help="Download the entire dataset instead of streaming it during evaluation.",
    )
    # set_defaults must be called before parse_args to take effect
    parser.set_defaults(streaming=True)
    args = parser.parse_args()

    main(args)
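
A hypothetical invocation of the new script (the model id and config name are illustrative; the exact Hugging Face id for canary-v2 is not stated in this commit):

    python nemo_asr/run_eval_ml.py \
        --model_id nvidia/canary-1b-v2 \
        --config_name fleurs_de \
        --device 0 \
        --batch_size 32 \
        --max_eval_samples 64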

nemo_asr/run_fast_conformer_ctc.sh

Lines changed: 0 additions & 95 deletions
This file was deleted.

0 commit comments