
Commit 9bdea31

Add UsefulSensors Moonshine benchmark (#43)
* Add UsefulSensors Moonshine benchmark. Because of the trainable in-model preprocessor, and therefore the lack of a spectrogram preprocessor, we have opted against wrapping the tokenizer as a processor. Further, substantial changes are required compared with the existing transformer models, so we decided to create a separate benchmark.
* Add moonshine-specific requirements.txt. Adds the `einops` package, which our HF hub repo requires.
1 parent bb39153 commit 9bdea31
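
The pattern the commit message describes can be sketched as follows (a rough, illustrative example, not part of the commit): because Moonshine's trainable in-model preprocessor consumes raw waveforms, there is no spectrogram feature extractor or processor wrapper, and a standalone tokenizer decodes the token ids returned by the model's forward pass. The model id and the one-second dummy waveform are placeholders; run_eval.py below is the actual benchmark.

import torch
from transformers import AutoModelForSpeechSeq2Seq, PreTrainedTokenizerFast

model_id = "usefulsensors/moonshine-tiny"  # illustrative; the benchmark also covers moonshine-base
model = AutoModelForSpeechSeq2Seq.from_pretrained(model_id, trust_remote_code=True)
tokenizer = PreTrainedTokenizerFast.from_pretrained(model_id, trust_remote_code=True)

# Raw 16 kHz audio goes straight into the model; no AutoProcessor is involved.
waveform = torch.zeros(1, 16000)  # one second of silence as a stand-in for real audio
pred_ids = model(waveform)        # the remote-code forward returns predicted token ids
print(tokenizer.batch_decode(pred_ids, skip_special_tokens=True))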

File tree

moonshine/run_eval.py
moonshine/run_moonshine.sh
moonshine/requirements.txt

3 files changed: 291 additions & 0 deletions


moonshine/run_eval.py

Lines changed: 192 additions & 0 deletions
@@ -0,0 +1,192 @@
import argparse
import os
import torch
from transformers import AutoConfig, AutoModelForSpeechSeq2Seq, AutoProcessor, PreTrainedTokenizerFast
import evaluate
from normalizer import data_utils
import time
from tqdm import tqdm
import numpy as np

wer_metric = evaluate.load("wer")
torch.set_float32_matmul_precision('high')

def main(args):
    config = AutoConfig.from_pretrained(args.model_id, trust_remote_code=True)
    model = AutoModelForSpeechSeq2Seq.from_pretrained(args.model_id, torch_dtype=torch.bfloat16, trust_remote_code=True).to(args.device)
    tokenizer = PreTrainedTokenizerFast.from_pretrained(args.model_id, trust_remote_code=True)

    if args.torch_compile:
        model.forward = torch.compile(model.forward, mode=args.compile_mode, fullgraph=True)
        if model.can_generate():
            # enable static k/v cache for autoregressive models
            model.generation_config.cache_implementation = "static"

    def benchmark(batch, min_new_tokens=None):
        # Load audio inputs
        audios = [audio["array"] for audio in batch["audio"]]
        minibatch_size = len(audios)

        # START TIMING
        start_time = time.time()

        # Moonshine consumes raw waveforms; pad short clips up to the model's minimum input size
        np_arr = np.array(audios)
        input_tensor = torch.FloatTensor(np_arr)
        moonshine_min_input_size = 1024
        padding = moonshine_min_input_size - input_tensor.size()[1]
        if padding > 0:
            input_tensor = torch.nn.functional.pad(input_tensor, (0, padding))
        pred_ids = model(input_tensor.to(args.device).to(torch.bfloat16))

        # Convert token ids to text transcription
        pred_text = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)

        # END TIMING
        runtime = time.time() - start_time

        # normalize by minibatch size since we want the per-sample time
        batch["transcription_time_s"] = minibatch_size * [runtime / minibatch_size]

        # normalize transcriptions with English normalizer
        batch["predictions"] = [data_utils.normalizer(pred) for pred in pred_text]
        batch["references"] = batch["norm_text"]
        return batch

    if args.warmup_steps is not None:
        dataset = data_utils.load_data(args)
        dataset = data_utils.prepare_data(dataset)

        num_warmup_samples = args.warmup_steps * args.batch_size
        if args.streaming:
            warmup_dataset = dataset.take(num_warmup_samples)
        else:
            warmup_dataset = dataset.select(range(min(num_warmup_samples, len(dataset))))
        warmup_dataset = iter(warmup_dataset.map(benchmark, batch_size=args.batch_size, batched=True, fn_kwargs={"min_new_tokens": args.max_new_tokens}))

        for _ in tqdm(warmup_dataset, desc="Warming up..."):
            continue

    dataset = data_utils.load_data(args)
    if args.max_eval_samples is not None and args.max_eval_samples > 0:
        print(f"Subsampling dataset to first {args.max_eval_samples} samples!")
        if args.streaming:
            dataset = dataset.take(args.max_eval_samples)
        else:
            dataset = dataset.select(range(min(args.max_eval_samples, len(dataset))))
    dataset = data_utils.prepare_data(dataset)

    dataset = dataset.map(
        benchmark, batch_size=args.batch_size, batched=True, remove_columns=["audio"],
    )

    all_results = {
        "audio_length_s": [],
        "transcription_time_s": [],
        "predictions": [],
        "references": [],
    }
    result_iter = iter(dataset)
    for result in tqdm(result_iter, desc="Samples..."):
        for key in all_results:
            all_results[key].append(result[key])

    # Write manifest results (WER and RTFX)
    manifest_path = data_utils.write_manifest(
        all_results["references"],
        all_results["predictions"],
        args.model_id,
        args.dataset_path,
        args.dataset,
        args.split,
        audio_length=all_results["audio_length_s"],
        transcription_time=all_results["transcription_time_s"],
    )
    print("Results saved at path:", os.path.abspath(manifest_path))

    wer = wer_metric.compute(
        references=all_results["references"], predictions=all_results["predictions"]
    )
    wer = round(100 * wer, 2)
    rtfx = round(sum(all_results["audio_length_s"]) / sum(all_results["transcription_time_s"]), 2)
    print("WER:", wer, "%", "RTFx:", rtfx)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--model_id",
        type=str,
        required=True,
        help="Model identifier. Should be loadable with 🤗 Transformers",
    )
    parser.add_argument(
        "--dataset_path",
        type=str,
        default="esb/datasets",
        help="Dataset path. By default, it is `esb/datasets`",
    )
    parser.add_argument(
        "--dataset",
        type=str,
        required=True,
        help="Dataset name. *E.g.* `'librispeech_asr'` for the LibriSpeech ASR dataset, or `'common_voice'` for Common Voice. The full list of dataset names "
        "can be found at `https://huggingface.co/datasets/esb/datasets`",
    )
    parser.add_argument(
        "--split",
        type=str,
        default="test",
        help="Split of the dataset. *E.g.* `'validation'` for the dev split, or `'test'` for the test split.",
    )
    parser.add_argument(
        "--device",
        type=int,
        default=-1,
        help="The device to run the pipeline on. -1 for CPU (default), 0 for the first GPU and so on.",
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        default=16,
        help="Number of samples to go through each streamed batch.",
    )
    parser.add_argument(
        "--max_eval_samples",
        type=int,
        default=None,
        help="Number of samples to be evaluated. Put a lower number e.g. 64 for testing this script.",
    )
    parser.add_argument(
        "--no-streaming",
        dest="streaming",
        action="store_false",
        help="Choose whether you'd like to download the entire dataset or stream it during the evaluation.",
    )
    parser.add_argument(
        "--max_new_tokens",
        type=int,
        default=None,
        help="Maximum number of tokens to generate (for auto-regressive models).",
    )
    parser.add_argument(
        "--torch_compile",
        action="store_true",
        help="Whether to JIT compile the forward pass of the model.",
    )
    parser.add_argument(
        "--compile_mode",
        type=str,
        default="max-autotune",
        help="Mode for torch compiling model forward pass. Can be either 'default', 'reduce-overhead', 'max-autotune' or 'max-autotune-no-cudagraphs'.",
    )
    parser.add_argument(
        "--warmup_steps",
        type=int,
        default=10,
        help="Number of warm-up steps to run before launching the timed runs.",
    )
    # set_defaults must run before parse_args() for the streaming default to take effect
    parser.set_defaults(streaming=False)
    args = parser.parse_args()

    main(args)
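
For orientation, the two numbers printed at the end are the word error rate as a percentage of the normalized references and RTFx, the total audio duration divided by the total transcription time; for example, 3,600 s of audio transcribed in 120 s gives an RTFx of 30 (higher means faster than real time).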

moonshine/run_moonshine.sh

Lines changed: 92 additions & 0 deletions
@@ -0,0 +1,92 @@
#!/bin/bash

export PYTHONPATH="..":$PYTHONPATH

MODEL_IDs=("usefulsensors/moonshine-base" "usefulsensors/moonshine-tiny")
BATCH_SIZE=1

num_models=${#MODEL_IDs[@]}

for (( i=0; i<${num_models}; i++ ));
do
    MODEL_ID=${MODEL_IDs[$i]}

    python run_eval.py \
        --model_id=${MODEL_ID} \
        --dataset_path="hf-audio/esb-datasets-test-only-sorted" \
        --dataset="voxpopuli" \
        --split="test" \
        --device=0 \
        --batch_size=${BATCH_SIZE} \
        --max_eval_samples=-1

    python run_eval.py \
        --model_id=${MODEL_ID} \
        --dataset_path="hf-audio/esb-datasets-test-only-sorted" \
        --dataset="ami" \
        --split="test" \
        --device=0 \
        --batch_size=${BATCH_SIZE} \
        --max_eval_samples=-1

    python run_eval.py \
        --model_id=${MODEL_ID} \
        --dataset_path="hf-audio/esb-datasets-test-only-sorted" \
        --dataset="earnings22" \
        --split="test" \
        --device=0 \
        --batch_size=${BATCH_SIZE} \
        --max_eval_samples=-1

    python run_eval.py \
        --model_id=${MODEL_ID} \
        --dataset_path="hf-audio/esb-datasets-test-only-sorted" \
        --dataset="gigaspeech" \
        --split="test" \
        --device=0 \
        --batch_size=${BATCH_SIZE} \
        --max_eval_samples=-1

    python run_eval.py \
        --model_id=${MODEL_ID} \
        --dataset_path="hf-audio/esb-datasets-test-only-sorted" \
        --dataset="librispeech" \
        --split="test.clean" \
        --device=0 \
        --batch_size=${BATCH_SIZE} \
        --max_eval_samples=-1

    python run_eval.py \
        --model_id=${MODEL_ID} \
        --dataset_path="hf-audio/esb-datasets-test-only-sorted" \
        --dataset="librispeech" \
        --split="test.other" \
        --device=0 \
        --batch_size=${BATCH_SIZE} \
        --max_eval_samples=-1

    python run_eval.py \
        --model_id=${MODEL_ID} \
        --dataset_path="hf-audio/esb-datasets-test-only-sorted" \
        --dataset="spgispeech" \
        --split="test" \
        --device=0 \
        --batch_size=${BATCH_SIZE} \
        --max_eval_samples=-1

    python run_eval.py \
        --model_id=${MODEL_ID} \
        --dataset_path="hf-audio/esb-datasets-test-only-sorted" \
        --dataset="tedlium" \
        --split="test" \
        --device=0 \
        --batch_size=${BATCH_SIZE} \
        --max_eval_samples=-1

    # Evaluate results
    RUNDIR=`pwd` && \
    cd ../normalizer && \
    python -c "import eval_utils; eval_utils.score_results('${RUNDIR}/results', '${MODEL_ID}')" && \
    cd $RUNDIR

done

moonshine/requirements.txt

Lines changed: 7 additions & 0 deletions
@@ -0,0 +1,7 @@
torch
transformers
evaluate
datasets
librosa
jiwer
einops
