
Commit b1f5dbf

Add evals script for liteASR
1 parent b4de2c9 commit b1f5dbf

2 files changed: +334 -0 lines changed

liteASR/run_eval.py

Lines changed: 235 additions & 0 deletions
@@ -0,0 +1,235 @@
import argparse
import os
import torch
from torch.nn.attention import sdpa_kernel, SDPBackend
from transformers import AutoConfig, AutoModel, AutoModelForCTC, AutoProcessor, MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING
import evaluate
from normalizer import data_utils
import time
from tqdm import tqdm

wer_metric = evaluate.load("wer")
torch.set_float32_matmul_precision('high')

def main(args):
    model = AutoModel.from_pretrained(args.model_id, torch_dtype=torch.bfloat16, trust_remote_code=True).to(args.device)
    processor = AutoProcessor.from_pretrained("openai/whisper-large-v3")
    model_input_name = processor.model_input_names[0]

    if model.can_generate():
        gen_kwargs = {"max_new_tokens": args.max_new_tokens}
        # for multilingual Whisper checkpoints we see a definite WER boost by setting the language and task args
        # print(model.generation_config)
        # if getattr(model.generation_config, "is_multilingual"):
        #     gen_kwargs["language"] = "en"
        #     gen_kwargs["task"] = "transcribe"
    elif args.max_new_tokens:
        raise ValueError("`max_new_tokens` should only be set for auto-regressive models, but got a CTC model.")

    if args.torch_compile:
        model.forward = torch.compile(model.forward, mode=args.compile_mode, fullgraph=True)
        if model.can_generate():
            # enable static k/v cache for auto-regressive models
            model.generation_config.cache_implementation = "static"

    def benchmark(batch, min_new_tokens=None):
        # Load audio inputs
        audios = [audio["array"] for audio in batch["audio"]]
        minibatch_size = len(audios)

        # START TIMING
        start_time = time.time()

        # 1. Pre-processing
        # 1.1 Pad audios to the max batch size when using torch compile, to prevent re-compilations
        padding_size = None
        if minibatch_size != args.batch_size and args.torch_compile:
            padding_size = args.batch_size - minibatch_size
            padding_audios = [audios[-1] for _ in range(padding_size)]
            audios.extend(padding_audios)

        if not model.can_generate():  # or len(audios[0]) > processor.feature_extractor.n_samples:
            # 1.2 Either CTC pre-processing (normalize to mean 0, std 1), or long-form Whisper processing
            inputs = processor(
                audios,
                sampling_rate=16_000,
                truncation=False,
                padding="longest",
                return_tensors="pt",
                return_attention_mask=True,
            )
        else:
            # 1.3 Standard Whisper processing: pad audios to 30 seconds and convert to log-mel spectrograms
            inputs = processor(audios, sampling_rate=16_000, return_tensors="pt", device=args.device)

        inputs = inputs.to(args.device)
        inputs[model_input_name] = inputs[model_input_name].to(torch.bfloat16)

        # 2. Model inference
        with sdpa_kernel(SDPBackend.MATH if args.torch_compile else SDPBackend.FLASH_ATTENTION):
            if model.can_generate():
                # 2.1 Auto-regressive generation for encoder-decoder models
                pred_ids = model.generate(**inputs, **gen_kwargs, min_new_tokens=min_new_tokens)
            else:
                # 2.2 Single forward pass for CTC models
                with torch.no_grad():
                    logits = model(**inputs).logits
                    pred_ids = logits.argmax(-1)

        # 3. Post-processing
        # 3.1 Strip padded ids from the predictions
        if padding_size is not None:
            pred_ids = pred_ids[:-padding_size, ...]

        # 3.2 Convert token ids to text transcriptions
        pred_text = processor.batch_decode(pred_ids, skip_special_tokens=True)

        # END TIMING
        runtime = time.time() - start_time

        # normalize by minibatch size since we want the per-sample time
        batch["transcription_time_s"] = minibatch_size * [runtime / minibatch_size]

        # normalize transcriptions with the English normalizer
        batch["predictions"] = [data_utils.normalizer(pred) for pred in pred_text]
        batch["references"] = batch["norm_text"]
        return batch

    if args.warmup_steps is not None:
        dataset = data_utils.load_data(args)
        dataset = data_utils.prepare_data(dataset)

        num_warmup_samples = args.warmup_steps * args.batch_size
        if args.streaming:
            warmup_dataset = dataset.take(num_warmup_samples)
        else:
            warmup_dataset = dataset.select(range(min(num_warmup_samples, len(dataset))))
        warmup_dataset = iter(warmup_dataset.map(benchmark, batch_size=args.batch_size, batched=True, fn_kwargs={"min_new_tokens": args.max_new_tokens}))

        for _ in tqdm(warmup_dataset, desc="Warming up..."):
            continue

    dataset = data_utils.load_data(args)
    if args.max_eval_samples is not None and args.max_eval_samples > 0:
        print(f"Subsampling dataset to first {args.max_eval_samples} samples!")
        if args.streaming:
            dataset = dataset.take(args.max_eval_samples)
        else:
            dataset = dataset.select(range(min(args.max_eval_samples, len(dataset))))
    dataset = data_utils.prepare_data(dataset)

    dataset = dataset.map(
        benchmark, batch_size=args.batch_size, batched=True, remove_columns=["audio"],
    )

    all_results = {
        "audio_length_s": [],
        "transcription_time_s": [],
        "predictions": [],
        "references": [],
    }
    result_iter = iter(dataset)
    for result in tqdm(result_iter, desc="Samples..."):
        for key in all_results:
            all_results[key].append(result[key])

    # Write manifest results (WER and RTFx)
    manifest_path = data_utils.write_manifest(
        all_results["references"],
        all_results["predictions"],
        args.model_id,
        args.dataset_path,
        args.dataset,
        args.split,
        audio_length=all_results["audio_length_s"],
        transcription_time=all_results["transcription_time_s"],
    )
    print("Results saved at path:", os.path.abspath(manifest_path))

    wer = wer_metric.compute(
        references=all_results["references"], predictions=all_results["predictions"]
    )
    wer = round(100 * wer, 2)
    rtfx = round(sum(all_results["audio_length_s"]) / sum(all_results["transcription_time_s"]), 2)
    print("WER:", wer, "%", "RTFx:", rtfx)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--model_id",
        type=str,
        required=True,
        help="Model identifier. Should be loadable with 🤗 Transformers",
    )
    parser.add_argument(
        "--dataset_path",
        type=str,
        default="esb/datasets",
        help="Dataset path. By default, it is `esb/datasets`",
    )
    parser.add_argument(
        "--dataset",
        type=str,
        required=True,
        help="Dataset name. *E.g.* `'librispeech_asr'` for the LibriSpeech ASR dataset, or `'common_voice'` for Common Voice. The full list of dataset names "
        "can be found at `https://huggingface.co/datasets/esb/datasets`",
    )
    parser.add_argument(
        "--split",
        type=str,
        default="test",
        help="Split of the dataset. *E.g.* `'validation'` for the dev split, or `'test'` for the test split.",
    )
    parser.add_argument(
        "--device",
        type=int,
        default=-1,
        help="The device to run the pipeline on. -1 for CPU (default), 0 for the first GPU and so on.",
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        default=16,
        help="Number of samples to go through each streamed batch.",
    )
    parser.add_argument(
        "--max_eval_samples",
        type=int,
        default=None,
        help="Number of samples to be evaluated. Put a lower number e.g. 64 for testing this script.",
    )
    parser.add_argument(
        "--no-streaming",
        dest="streaming",
        action="store_false",
        help="Choose whether you'd like to download the entire dataset or stream it during the evaluation.",
    )
    parser.add_argument(
        "--max_new_tokens",
        type=int,
        default=None,
        help="Maximum number of tokens to generate (for auto-regressive models).",
    )
    parser.add_argument(
        "--torch_compile",
        action="store_true",
        help="Whether to JIT compile the forward pass of the model.",
    )
    parser.add_argument(
        "--compile_mode",
        type=str,
        default="max-autotune",
        help="Mode for torch compiling model forward pass. Can be either 'default', 'reduce-overhead', 'max-autotune' or 'max-autotune-no-cudagraphs'.",
    )
    parser.add_argument(
        "--warmup_steps",
        type=int,
        default=10,
        help="Number of warm-up steps to run before launching the timed runs.",
    )
    # register the default for `streaming` before parsing args; calling set_defaults afterwards has no effect
    parser.set_defaults(streaming=False)
    args = parser.parse_args()

    main(args)
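
As a usage sketch (not part of this commit), the script above can be invoked directly with the flags it defines; the dataset path, split, and sample count below are illustrative values borrowed from run_liteasr.sh and the argparse help text:

PYTHONPATH=.. python run_eval.py \
    --model_id="efficient-speech/lite-whisper-large-v3" \
    --dataset_path="hf-audio/esb-datasets-test-only-sorted" \
    --dataset="librispeech" \
    --split="test.clean" \
    --device=0 \
    --batch_size=16 \
    --max_eval_samples=64

`--torch_compile`, `--max_new_tokens` and `--warmup_steps` can additionally be passed for compiled or auto-regressive runs, as described in the argument help strings.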

liteASR/run_liteasr.sh

Lines changed: 99 additions & 0 deletions
@@ -0,0 +1,99 @@
#!/bin/bash

export PYTHONPATH="..":$PYTHONPATH

MODEL_IDs=(
    "efficient-speech/lite-whisper-large-v3-acc"
    "efficient-speech/lite-whisper-large-v3"
    "efficient-speech/lite-whisper-large-v3-fast"
    "efficient-speech/lite-whisper-large-v3-turbo-acc"
    "efficient-speech/lite-whisper-large-v3-turbo"
    "efficient-speech/lite-whisper-large-v3-turbo-fast"
)
BATCH_SIZE=64

num_models=${#MODEL_IDs[@]}

for (( i=0; i<${num_models}; i++ ));
do
    MODEL_ID=${MODEL_IDs[$i]}

    # python run_eval.py \
    #     --model_id=${MODEL_ID} \
    #     --dataset_path="hf-audio/esb-datasets-test-only-sorted" \
    #     --dataset="voxpopuli" \
    #     --split="test" \
    #     --device=0 \
    #     --batch_size=${BATCH_SIZE} \
    #     --max_eval_samples=-1

    # python run_eval.py \
    #     --model_id=${MODEL_ID} \
    #     --dataset_path="hf-audio/esb-datasets-test-only-sorted" \
    #     --dataset="ami" \
    #     --split="test" \
    #     --device=0 \
    #     --batch_size=${BATCH_SIZE} \
    #     --max_eval_samples=-1

    # python run_eval.py \
    #     --model_id=${MODEL_ID} \
    #     --dataset_path="hf-audio/esb-datasets-test-only-sorted" \
    #     --dataset="earnings22" \
    #     --split="test" \
    #     --device=0 \
    #     --batch_size=${BATCH_SIZE} \
    #     --max_eval_samples=-1

    # python run_eval.py \
    #     --model_id=${MODEL_ID} \
    #     --dataset_path="hf-audio/esb-datasets-test-only-sorted" \
    #     --dataset="gigaspeech" \
    #     --split="test" \
    #     --device=0 \
    #     --batch_size=${BATCH_SIZE} \
    #     --max_eval_samples=-1

    # python run_eval.py \
    #     --model_id=${MODEL_ID} \
    #     --dataset_path="hf-audio/esb-datasets-test-only-sorted" \
    #     --dataset="librispeech" \
    #     --split="test.clean" \
    #     --device=0 \
    #     --batch_size=${BATCH_SIZE} \
    #     --max_eval_samples=-1

    # python run_eval.py \
    #     --model_id=${MODEL_ID} \
    #     --dataset_path="hf-audio/esb-datasets-test-only-sorted" \
    #     --dataset="librispeech" \
    #     --split="test.other" \
    #     --device=0 \
    #     --batch_size=${BATCH_SIZE} \
    #     --max_eval_samples=-1

    # python run_eval.py \
    #     --model_id=${MODEL_ID} \
    #     --dataset_path="hf-audio/esb-datasets-test-only-sorted" \
    #     --dataset="spgispeech" \
    #     --split="test" \
    #     --device=0 \
    #     --batch_size=${BATCH_SIZE} \
    #     --max_eval_samples=-1

    # python run_eval.py \
    #     --model_id=${MODEL_ID} \
    #     --dataset_path="hf-audio/esb-datasets-test-only-sorted" \
    #     --dataset="tedlium" \
    #     --split="test" \
    #     --device=0 \
    #     --batch_size=${BATCH_SIZE} \
    #     --max_eval_samples=-1

    # Evaluate results
    RUNDIR=`pwd` && \
    cd ../normalizer && \
    python -c "import eval_utils; eval_utils.score_results('${RUNDIR}/results', '${MODEL_ID}')" && \
    cd $RUNDIR

done
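
Note that the per-dataset run_eval.py invocations are committed commented-out, so as written the loop only re-scores existing manifests via eval_utils.score_results. A hedged launch sketch, assuming the desired blocks above have been uncommented and the script is run from the liteASR directory (so that `${RUNDIR}/results` resolves to liteASR/results):

cd liteASR
bash run_liteasr.sh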
