
Commit 3953c59

Merge pull request #52 from huggingface/liteASR
Add evals script for liteASR
2 parents 0dc4559 + 2582589 commit 3953c59
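For context on what these scripts evaluate: run_eval.py below loads a lite-whisper checkpoint via AutoModel with trust_remote_code=True and pairs it with the openai/whisper-large-v3-turbo processor. A minimal loading-and-transcription sketch following the same pattern (the silent dummy waveform and the "cuda" device are illustrative assumptions, not part of this commit):

import numpy as np
import torch
from transformers import AutoModel, AutoProcessor

# Stand-in input: 1 second of silence at 16 kHz (replace with real speech).
audio = np.zeros(16_000, dtype=np.float32)

# Same loading pattern as the eval script: custom lite-whisper code via
# trust_remote_code, paired with the whisper-large-v3-turbo processor.
model = AutoModel.from_pretrained(
    "efficient-speech/lite-whisper-large-v3-turbo",
    torch_dtype=torch.float16,
    trust_remote_code=True,
).to("cuda")
processor = AutoProcessor.from_pretrained("openai/whisper-large-v3-turbo")

inputs = processor(audio, sampling_rate=16_000, return_tensors="pt").to("cuda")
inputs["input_features"] = inputs["input_features"].to(torch.float16)
pred_ids = model.generate(**inputs, max_new_tokens=224)
print(processor.batch_decode(pred_ids, skip_special_tokens=True)[0])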

File tree

liteASR/run_eval.py
liteASR/run_liteasr.sh

2 files changed: +330 −0 lines changed

liteASR/run_eval.py

Lines changed: 231 additions & 0 deletions
@@ -0,0 +1,231 @@
import argparse
import os
import torch
from torch.nn.attention import sdpa_kernel, SDPBackend
from transformers import AutoConfig, AutoModel, AutoModelForCTC, AutoProcessor, MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING
import evaluate
from normalizer import data_utils
import time
from tqdm import tqdm

wer_metric = evaluate.load("wer")
torch.set_float32_matmul_precision('high')


def main(args):
    model = AutoModel.from_pretrained(args.model_id, torch_dtype=torch.float16, trust_remote_code=True, force_download=True).to(args.device)
    processor = AutoProcessor.from_pretrained("openai/whisper-large-v3-turbo", force_download=True)
    model_input_name = processor.model_input_names[0]

    if model.can_generate():
        gen_kwargs = {"max_new_tokens": 224}
    elif args.max_new_tokens:
        raise ValueError("`max_new_tokens` should only be set for auto-regressive models, but got a CTC model.")

    if args.torch_compile:
        model.forward = torch.compile(model.forward, mode=args.compile_mode, fullgraph=True)
        if model.can_generate():
            # enable static k/v cache for autoregressive models
            model.generation_config.cache_implementation = "static"

    def benchmark(batch, min_new_tokens=None):
        # Load audio inputs
        audios = [audio["array"] for audio in batch["audio"]]
        minibatch_size = len(audios)

        # START TIMING
        start_time = time.time()

        # 1. Pre-Processing
        # 1.1 Pad audios to max batch size if using torch compile to prevent re-compilations
        padding_size = None
        if minibatch_size != args.batch_size and args.torch_compile:
            padding_size = args.batch_size - minibatch_size
            padding_audios = [audios[-1] for _ in range(padding_size)]
            audios.extend(padding_audios)

        if not model.can_generate():  # or len(audios[0]) > processor.feature_extractor.n_samples:
            # 1.2 Either CTC pre-processing (normalize to mean 0, std 1), or long-form Whisper processing
            inputs = processor(
                audios,
                sampling_rate=16_000,
                truncation=False,
                padding="longest",
                return_tensors="pt",
                return_attention_mask=True,
            )
        else:
            # 1.3 Standard Whisper processing: pad audios to 30 seconds and convert to log-mel
            inputs = processor(audios, sampling_rate=16_000, return_tensors="pt", device=args.device)

        inputs = inputs.to(args.device)
        inputs[model_input_name] = inputs[model_input_name].to(torch.float16)

        # 2. Model Inference
        with sdpa_kernel(SDPBackend.MATH if args.torch_compile else SDPBackend.FLASH_ATTENTION):
            forced_decoder_ids = processor.get_decoder_prompt_ids(language="english", task="transcribe")
            if model.can_generate():
                # 2.1 Auto-regressive generation for encoder-decoder models
                pred_ids = model.generate(**inputs, **gen_kwargs, min_new_tokens=min_new_tokens, forced_decoder_ids=forced_decoder_ids)
            else:
                # 2.2 Single forward pass for CTC
                with torch.no_grad():
                    logits = model(**inputs).logits
                pred_ids = logits.argmax(-1)

        # 3. Post-processing
        # 3.1 Strip padded ids from predictions
        if padding_size is not None:
            pred_ids = pred_ids[:-padding_size, ...]

        # 3.2 Convert token ids to text transcription
        pred_text = processor.batch_decode(pred_ids, skip_special_tokens=True)

        # END TIMING
        runtime = time.time() - start_time

        # normalize by minibatch size since we want the per-sample time
        batch["transcription_time_s"] = minibatch_size * [runtime / minibatch_size]

        # normalize transcriptions with English normalizer
        batch["predictions"] = [data_utils.normalizer(pred) for pred in pred_text]
        batch["references"] = batch["norm_text"]
        return batch

    if args.warmup_steps is not None:
        dataset = data_utils.load_data(args)
        dataset = data_utils.prepare_data(dataset)

        num_warmup_samples = args.warmup_steps * args.batch_size
        if args.streaming:
            warmup_dataset = dataset.take(num_warmup_samples)
        else:
            warmup_dataset = dataset.select(range(min(num_warmup_samples, len(dataset))))
        warmup_dataset = iter(warmup_dataset.map(benchmark, batch_size=args.batch_size, batched=True, fn_kwargs={"min_new_tokens": args.max_new_tokens}))

        for _ in tqdm(warmup_dataset, desc="Warming up..."):
            continue

    dataset = data_utils.load_data(args)
    if args.max_eval_samples is not None and args.max_eval_samples > 0:
        print(f"Subsampling dataset to first {args.max_eval_samples} samples!")
        if args.streaming:
            dataset = dataset.take(args.max_eval_samples)
        else:
            dataset = dataset.select(range(min(args.max_eval_samples, len(dataset))))
    dataset = data_utils.prepare_data(dataset)

    dataset = dataset.map(
        benchmark, batch_size=args.batch_size, batched=True, remove_columns=["audio"],
    )

    all_results = {
        "audio_length_s": [],
        "transcription_time_s": [],
        "predictions": [],
        "references": [],
    }
    result_iter = iter(dataset)
    for result in tqdm(result_iter, desc="Samples..."):
        for key in all_results:
            all_results[key].append(result[key])

    # Write manifest results (WER and RTFx)
    manifest_path = data_utils.write_manifest(
        all_results["references"],
        all_results["predictions"],
        args.model_id,
        args.dataset_path,
        args.dataset,
        args.split,
        audio_length=all_results["audio_length_s"],
        transcription_time=all_results["transcription_time_s"],
    )
    print("Results saved at path:", os.path.abspath(manifest_path))

    wer = wer_metric.compute(
        references=all_results["references"], predictions=all_results["predictions"]
    )
    wer = round(100 * wer, 2)
    rtfx = round(sum(all_results["audio_length_s"]) / sum(all_results["transcription_time_s"]), 2)
    print("WER:", wer, "%", "RTFx:", rtfx)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--model_id",
        type=str,
        required=True,
        help="Model identifier. Should be loadable with 🤗 Transformers",
    )
    parser.add_argument(
        "--dataset_path",
        type=str,
        default="esb/datasets",
        help="Dataset path. By default, it is `esb/datasets`",
    )
    parser.add_argument(
        "--dataset",
        type=str,
        required=True,
        help="Dataset name. *E.g.* `'librispeech_asr'` for the LibriSpeech ASR dataset, or `'common_voice'` for Common Voice. The full list of dataset names "
        "can be found at `https://huggingface.co/datasets/esb/datasets`",
    )
    parser.add_argument(
        "--split",
        type=str,
        default="test",
        help="Split of the dataset. *E.g.* `'validation'` for the dev split, or `'test'` for the test split.",
    )
    parser.add_argument(
        "--device",
        type=int,
        default=-1,
        help="The device to run the pipeline on. -1 for CPU (default), 0 for the first GPU and so on.",
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        default=16,
        help="Number of samples to go through each streamed batch.",
    )
    parser.add_argument(
        "--max_eval_samples",
        type=int,
        default=None,
        help="Number of samples to be evaluated. Put a lower number e.g. 64 for testing this script.",
    )
    parser.add_argument(
        "--no-streaming",
        dest="streaming",
        action="store_false",
        help="Choose whether you'd like to download the entire dataset or stream it during the evaluation.",
    )
    parser.add_argument(
        "--max_new_tokens",
        type=int,
        default=None,
        help="Maximum number of tokens to generate (for auto-regressive models).",
    )
    parser.add_argument(
        "--torch_compile",
        action="store_true",
        help="Whether to JIT compile the forward pass of the model.",
    )
    parser.add_argument(
        "--compile_mode",
        type=str,
        default="max-autotune",
        help="Mode for torch compiling model forward pass. Can be either 'default', 'reduce-overhead', 'max-autotune' or 'max-autotune-no-cudagraphs'.",
    )
    parser.add_argument(
        "--warmup_steps",
        type=int,
        default=10,
        help="Number of warm-up steps to run before launching the timed runs.",
    )
    args = parser.parse_args()
    # NOTE: set_defaults() is called after parse_args() here, so it has no effect;
    # without --no-streaming on the command line, args.streaming defaults to True.
    parser.set_defaults(streaming=False)

    main(args)

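As a quick reference for the two numbers run_eval.py prints: WER is the corpus-level word error rate computed by the evaluate library's "wer" metric over normalized text, and RTFx is the total audio duration divided by the total transcription time. A small illustrative computation with made-up values (not results from this commit):

import evaluate

wer_metric = evaluate.load("wer")

# Made-up values, purely to illustrate the two printed metrics.
references = ["the cat sat on the mat"]
predictions = ["the cat sat on a mat"]
wer = round(100 * wer_metric.compute(references=references, predictions=predictions), 2)

audio_length_s = [3600.0]      # one hour of audio ...
transcription_time_s = [36.0]  # ... transcribed in 36 seconds
rtfx = round(sum(audio_length_s) / sum(transcription_time_s), 2)

print("WER:", wer, "%", "RTFx:", rtfx)  # WER: 16.67 % RTFx: 100.0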
liteASR/run_liteasr.sh

Lines changed: 99 additions & 0 deletions
@@ -0,0 +1,99 @@
#!/bin/bash

export PYTHONPATH="..":$PYTHONPATH

MODEL_IDs=(
    "efficient-speech/lite-whisper-large-v3-acc"
    "efficient-speech/lite-whisper-large-v3"
    "efficient-speech/lite-whisper-large-v3-fast"
    "efficient-speech/lite-whisper-large-v3-turbo-acc"
    "efficient-speech/lite-whisper-large-v3-turbo"
    "efficient-speech/lite-whisper-large-v3-turbo-fast"
)
BATCH_SIZE=64

num_models=${#MODEL_IDs[@]}

for (( i=0; i<${num_models}; i++ ));
do
    MODEL_ID=${MODEL_IDs[$i]}

    python run_eval.py \
        --model_id=${MODEL_ID} \
        --dataset_path="hf-audio/esb-datasets-test-only-sorted" \
        --dataset="voxpopuli" \
        --split="test" \
        --device=0 \
        --batch_size=${BATCH_SIZE} \
        --max_eval_samples=-1

    python run_eval.py \
        --model_id=${MODEL_ID} \
        --dataset_path="hf-audio/esb-datasets-test-only-sorted" \
        --dataset="ami" \
        --split="test" \
        --device=0 \
        --batch_size=${BATCH_SIZE} \
        --max_eval_samples=-1

    python run_eval.py \
        --model_id=${MODEL_ID} \
        --dataset_path="hf-audio/esb-datasets-test-only-sorted" \
        --dataset="earnings22" \
        --split="test" \
        --device=0 \
        --batch_size=${BATCH_SIZE} \
        --max_eval_samples=-1

    python run_eval.py \
        --model_id=${MODEL_ID} \
        --dataset_path="hf-audio/esb-datasets-test-only-sorted" \
        --dataset="gigaspeech" \
        --split="test" \
        --device=0 \
        --batch_size=${BATCH_SIZE} \
        --max_eval_samples=-1

    python run_eval.py \
        --model_id=${MODEL_ID} \
        --dataset_path="hf-audio/esb-datasets-test-only-sorted" \
        --dataset="librispeech" \
        --split="test.clean" \
        --device=0 \
        --batch_size=${BATCH_SIZE} \
        --max_eval_samples=-1

    python run_eval.py \
        --model_id=${MODEL_ID} \
        --dataset_path="hf-audio/esb-datasets-test-only-sorted" \
        --dataset="librispeech" \
        --split="test.other" \
        --device=0 \
        --batch_size=${BATCH_SIZE} \
        --max_eval_samples=-1

    python run_eval.py \
        --model_id=${MODEL_ID} \
        --dataset_path="hf-audio/esb-datasets-test-only-sorted" \
        --dataset="spgispeech" \
        --split="test" \
        --device=0 \
        --batch_size=${BATCH_SIZE} \
        --max_eval_samples=-1

    python run_eval.py \
        --model_id=${MODEL_ID} \
        --dataset_path="hf-audio/esb-datasets-test-only-sorted" \
        --dataset="tedlium" \
        --split="test" \
        --device=0 \
        --batch_size=${BATCH_SIZE} \
        --max_eval_samples=-1

    # Evaluate results
    RUNDIR=`pwd` && \
    cd ../normalizer && \
    python -c "import eval_utils; eval_utils.score_results('${RUNDIR}/results', '${MODEL_ID}')" && \
    cd $RUNDIR

done

0 commit comments
