Commit 4d90150

Merge pull request #88 from nithinraok/add_canary-v2
Add canary v2
2 parents 64a4f34 + b921f7a

File tree

9 files changed: +398 −103 lines

nemo_asr/run_canary.sh

Lines changed: 1 addition & 1 deletion

@@ -2,7 +2,7 @@
 
 export PYTHONPATH="..":$PYTHONPATH
 
-MODEL_IDs=("nvidia/canary-1b-flash") # options: "nvidia/canary-1b" "nvidia/canary-1b-flash" "nvidia/canary-180m-flash"
+MODEL_IDs=("nvidia/canary-1b-v2") # options: "nvidia/canary-1b" "nvidia/canary-1b-flash" "nvidia/canary-180m-flash" "nvidia/canary-1b-v2"
 BATCH_SIZE=128
 DEVICE_ID=0

nemo_asr/run_eval.py

Lines changed: 8 additions & 2 deletions

@@ -127,9 +127,15 @@ def download_audio_files(batch):
         else:
             audio_files = all_data["audio_filepaths"]
         start_time = time.time()
-        with torch.cuda.amp.autocast(enabled=False, dtype=compute_dtype), torch.inference_mode(), torch.no_grad():
+        with torch.inference_mode(), torch.no_grad():
+
+            if 'canary' in args.model_id and 'v2' not in args.model_id:
+                pnc = 'nopnc'
+            else:
+                pnc = 'pnc'
+
             if 'canary' in args.model_id:
-                transcriptions = asr_model.transcribe(audio_files, batch_size=args.batch_size, verbose=False, pnc='no', num_workers=1)
+                transcriptions = asr_model.transcribe(audio_files, batch_size=args.batch_size, verbose=False, pnc=pnc, num_workers=1)
             else:
                 transcriptions = asr_model.transcribe(audio_files, batch_size=args.batch_size, verbose=False, num_workers=1)
         end_time = time.time()
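
The change replaces the hard-coded pnc='no' with a model-dependent value: Canary v1 checkpoints are still scored without punctuation and capitalization, while canary-1b-v2 keeps both. A minimal standalone sketch of the same check (the helper name select_pnc is ours, purely illustrative):

def select_pnc(model_id: str) -> str:
    # Canary v1 checkpoints (canary-1b, canary-1b-flash, canary-180m-flash)
    # are scored without punctuation and capitalization ('nopnc');
    # canary-1b-v2 is scored with both ('pnc').
    if 'canary' in model_id and 'v2' not in model_id:
        return 'nopnc'
    return 'pnc'

assert select_pnc("nvidia/canary-1b-flash") == 'nopnc'
assert select_pnc("nvidia/canary-1b-v2") == 'pnc'

The new multilingual script below applies the same rule, except that v1 models drop punctuation only for English.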

nemo_asr/run_eval_ml.py

Lines changed: 264 additions & 0 deletions

New file:

# This script is used to evaluate NeMo ASR models on the Multi-Lingual datasets

import argparse
import io
import os
import torch
import evaluate
import soundfile
import numpy as np
from tqdm import tqdm
from datasets import load_dataset
from normalizer import data_utils
from nemo.collections.asr.models import ASRModel
import time


wer_metric = evaluate.load("wer")


def main(args):
    DATA_CACHE_DIR = os.path.join(os.getcwd(), "audio_cache")
    CONFIG_NAME = args.config_name
    SPLIT_NAME = args.split

    # Extract language from config_name if not provided
    if args.language:
        LANGUAGE = args.language
    else:
        # Extract language from config_name (e.g., "fleurs_en" -> "en")
        try:
            LANGUAGE = CONFIG_NAME.split('_', 1)[1]
        except IndexError:
            LANGUAGE = "en"  # Default fallback

    print(f"Detected language: {LANGUAGE}")

    CACHE_DIR = os.path.join(DATA_CACHE_DIR, CONFIG_NAME, SPLIT_NAME)
    if not os.path.exists(CACHE_DIR):
        os.makedirs(CACHE_DIR)

    if args.device >= 0:
        device = torch.device(f"cuda:{args.device}")
        compute_dtype = torch.bfloat16
    else:
        device = torch.device("cpu")
        compute_dtype = torch.float32

    # Load ASR model
    if args.model_id.endswith(".nemo"):
        asr_model = ASRModel.restore_from(args.model_id, map_location=device)
    else:
        asr_model = ASRModel.from_pretrained(args.model_id, map_location=device)

    asr_model.to(compute_dtype)
    asr_model.eval()

    # Load dataset using the HuggingFace dataset repository
    print(f"Loading dataset: {args.dataset} with config: {CONFIG_NAME}")

    dataset = load_dataset(args.dataset, CONFIG_NAME, split=SPLIT_NAME, streaming=args.streaming)

    if args.max_eval_samples is not None and args.max_eval_samples > 0:
        print(f"Subsampling dataset to first {args.max_eval_samples} samples!")
        # select()/len() require a non-streaming dataset (pass --no-streaming)
        dataset = dataset.select(range(min(args.max_eval_samples, len(dataset))))

    # Configure decoding strategy
    if asr_model.cfg.decoding.strategy != "beam":
        asr_model.cfg.decoding.strategy = "greedy_batch"
        asr_model.change_decoding_strategy(asr_model.cfg.decoding)

    def download_audio_files(batch):
        """Process audio files and prepare them for evaluation."""
        audio_paths = []
        durations = []

        for i, (file_name, sample, duration, text) in enumerate(zip(
            batch["file_name"], batch["audio"], batch["duration"], batch["text"]
        )):
            # Create unique filename using index to avoid conflicts
            unique_id = f"{CONFIG_NAME}_{i}_{os.path.basename(file_name).replace('.wav', '')}"
            audio_path = os.path.join(CACHE_DIR, f"{unique_id}.wav")

            if "array" in sample:
                audio_array = np.float32(sample["array"])
                sample_rate = sample.get("sampling_rate", 16000)
            elif "bytes" in sample:
                with io.BytesIO(sample["bytes"]) as audio_file:
                    audio_array, sample_rate = soundfile.read(audio_file, dtype="float32")
            else:
                raise ValueError("Sample must have either 'array' or 'bytes' key")

            if not os.path.exists(audio_path):
                os.makedirs(os.path.dirname(audio_path), exist_ok=True)
                soundfile.write(audio_path, audio_array, sample_rate)

            audio_paths.append(audio_path)
            # Use duration from dataset if available, otherwise calculate
            if duration is not None:
                durations.append(duration)
            else:
                durations.append(len(audio_array) / sample_rate)

        batch["references"] = [text for text in batch["text"]]
        batch["audio_filepaths"] = audio_paths
        batch["durations"] = durations

        return batch

    # Process the dataset
    print("Processing audio files...")
    dataset = dataset.map(
        download_audio_files,
        batch_size=args.batch_size,
        batched=True,
        remove_columns=["audio"]
    )

    # Collect all data
    all_data = {
        "audio_filepaths": [],
        "durations": [],
        "references": [],
    }

    print("Collecting data...")
    for data in tqdm(dataset, desc="Collecting samples"):
        all_data["audio_filepaths"].append(data["audio_filepaths"])
        all_data["durations"].append(data["durations"])
        all_data["references"].append(data["references"])

    # Sort by duration for efficient batch processing
    print("Sorting by duration...")
    sorted_indices = sorted(range(len(all_data["durations"])), key=lambda k: all_data["durations"][k], reverse=True)
    all_data["audio_filepaths"] = [all_data["audio_filepaths"][i] for i in sorted_indices]
    all_data["references"] = [all_data["references"][i] for i in sorted_indices]
    all_data["durations"] = [all_data["durations"][i] for i in sorted_indices]

    # Run evaluation with warmup
    total_time = 0
    for warmup_round in range(2):  # warmup once, then time the full pass for RTFX
        if warmup_round == 0:
            audio_files = all_data["audio_filepaths"][:args.batch_size * 4]  # warmup with 4 batches
            print("Running warmup...")
        else:
            audio_files = all_data["audio_filepaths"]
            print("Running full evaluation...")

        start_time = time.time()
        with torch.inference_mode(), torch.no_grad():
            # canary-1b and canary-1b-flash: pnc='nopnc' for English, pnc='pnc' for
            # other languages; canary-1b-v2: pnc='pnc' for all languages
            if 'canary' in args.model_id and 'v2' not in args.model_id:
                pnc = 'nopnc' if LANGUAGE == "en" else 'pnc'
            else:
                pnc = 'pnc'

            if 'canary' in args.model_id:
                transcriptions = asr_model.transcribe(audio_files, batch_size=args.batch_size, verbose=False, pnc=pnc, num_workers=1, source_lang=LANGUAGE, target_lang=LANGUAGE)
            else:
                transcriptions = asr_model.transcribe(audio_files, batch_size=args.batch_size, verbose=False, num_workers=1)
        end_time = time.time()

        if warmup_round == 1:
            total_time = end_time - start_time

    # Process transcriptions
    if isinstance(transcriptions, tuple) and len(transcriptions) == 2:
        transcriptions = transcriptions[0]

    references = all_data["references"]
    if LANGUAGE == "en":  # English is handled by the English normalizer
        references = [data_utils.normalizer(ref) for ref in references]
        predictions = [data_utils.normalizer(pred.text) for pred in transcriptions]
    else:
        references = [data_utils.ml_normalizer(ref) for ref in references]
        predictions = [data_utils.ml_normalizer(pred.text) for pred in transcriptions]

    avg_time = total_time / len(all_data["audio_filepaths"])

    # Write results using data_utils.write_manifest
    manifest_path = data_utils.write_manifest(
        references,
        predictions,
        args.model_id,
        args.dataset,  # dataset_path for filename
        CONFIG_NAME,  # dataset_name
        SPLIT_NAME,
        audio_length=all_data["durations"],
        transcription_time=[avg_time] * len(all_data["audio_filepaths"]),
    )

    print("Results saved at path:", os.path.abspath(manifest_path))

    # Calculate metrics
    wer = wer_metric.compute(references=references, predictions=predictions)
    wer = round(100 * wer, 2)

    audio_length = sum(all_data["durations"])
    rtfx = audio_length / total_time
    rtfx = round(rtfx, 2)

    print(f"Dataset: {args.dataset}")
    print(f"Language: {LANGUAGE}")
    print(f"Config: {CONFIG_NAME}")
    print(f"Model: {args.model_id}")
    print(f"RTFX: {rtfx}")
    print(f"WER: {wer}%")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--model_id", type=str, required=True, help="Model identifier. Should be loadable with NVIDIA NeMo.",
    )
    parser.add_argument(
        "--dataset",
        type=str,
        default="nithinraok/asr-leaderboard-datasets",
        help="Dataset name. Default is 'nithinraok/asr-leaderboard-datasets'",
    )
    parser.add_argument(
        "--config_name",
        type=str,
        required=True,
        help="Config name in format <dataset>_<lang> (e.g., fleurs_en, mcv_de, mls_es)",
    )
    parser.add_argument(
        "--language",
        type=str,
        default=None,
        help="Language code (e.g., en, de, es). If not provided, it is extracted from config_name.",
    )
    parser.add_argument(
        "--split",
        type=str,
        default="test",
        help="Split of the dataset. Default is 'test'.",
    )
    parser.add_argument(
        "--device",
        type=int,
        default=-1,
        help="The device to run the pipeline on. -1 for CPU (default), 0 for the first GPU and so on.",
    )
    parser.add_argument(
        "--batch_size", type=int, default=32, help="Number of samples to go through each streamed batch.",
    )
    parser.add_argument(
        "--max_eval_samples",
        type=int,
        default=None,
        help="Number of samples to be evaluated. Put a lower number e.g. 64 for testing this script.",
    )
    parser.add_argument(
        "--no-streaming",
        dest='streaming',
        action="store_false",
        help="Choose whether you'd like to download the entire dataset or stream it during the evaluation.",
    )
    parser.set_defaults(streaming=True)
    args = parser.parse_args()

    main(args)
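
For reference, a hypothetical invocation of the new script from Python (the config name, device, and batch size are illustrative values; any config in the <dataset>_<lang> form shown in the option help strings works the same way):

import subprocess

# Evaluate canary-1b-v2 on the German FLEURS config of the leaderboard datasets.
subprocess.run([
    "python", "run_eval_ml.py",
    "--model_id", "nvidia/canary-1b-v2",
    "--config_name", "fleurs_de",
    "--device", "0",
    "--batch_size", "32",
], check=True)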

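The RTFX the script prints is the inverse real-time factor: total evaluated audio duration divided by the wall-clock time of the second, timed round (the first round is an untimed warmup over four batches). A quick worked example with made-up numbers:

# RTFX = evaluated audio seconds / wall-clock transcription seconds (higher is faster).
audio_seconds = 1800.0       # hypothetical: 30 minutes of evaluated audio
transcribe_seconds = 20.0    # hypothetical wall-clock time for the timed round
rtfx = round(audio_seconds / transcribe_seconds, 2)
print(rtfx)  # 90.0 -> the model transcribes 90x faster than real time
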
nemo_asr/run_fast_conformer_ctc.sh

Lines changed: 0 additions & 95 deletions
This file was deleted.
