
Commit 1d8f058

Author: Eugene (committed)
Adding Kyutai-STT on OpenASR leaderboard
Implementing via Moshi as an external library for inference speed.
1 parent d2167fb commit 1d8f058

File tree: 3 files changed, +365 -0 lines changed

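For orientation before the full diff: the sketch below condenses the Moshi-based inference path that run_eval.py implements, reusing only the moshi and julius calls that appear in the commit. It is an illustrative single-utterance version, not part of the change itself; the dummy one-second input and the frame-size rounding shorthand are assumptions, and it needs a CUDA device like the script does.

# Hedged sketch of the streaming STT path used by run_eval.py below.
# Dummy input and padding arithmetic are illustrative, not from the commit.
import torch
import julius
from moshi import models

info = models.loaders.CheckpointInfo.from_hf_repo("kyutai/stt-2.6b-en")
mimi = info.get_mimi(device="cuda")                  # streaming audio codec
tokenizer = info.get_text_tokenizer()
lm = info.get_moshi(device="cuda", dtype=torch.bfloat16)
lm_gen = models.LMGen(lm, temp=0, temp_text=0.0)     # greedy text decoding

padding_token_id = info.raw_config.get("text_padding_token_id", 3)
prefix_s = info.stt_config.get("audio_silence_prefix_seconds", 1.0)
delay_s = info.stt_config.get("audio_delay_seconds", 5.0)

# Dummy 16 kHz input; in run_eval.py this comes from the evaluation dataset.
sr, audio = 16_000, torch.zeros(16_000)
audio = julius.resample.resample_frac(audio, old_sr=sr, new_sr=24_000)
audio = torch.nn.functional.pad(
    audio, (int(prefix_s * 24_000), int(delay_s * 24_000))
)
pad = (-audio.shape[-1]) % mimi.frame_size           # round up to whole codec frames
audio = torch.nn.functional.pad(audio, (0, pad)).float()

text_tokens = []
with torch.inference_mode(), mimi.streaming(1), lm_gen.streaming(1):
    for offset in range(0, audio.shape[-1], mimi.frame_size):
        chunk = audio[None, None, offset : offset + mimi.frame_size].cuda()
        codes = mimi.encode(chunk)                   # audio frame -> codec tokens
        text_tokens.append(lm_gen.step(codes))       # codec tokens -> text tokens

tokens = torch.concat(text_tokens, dim=-1).squeeze()
print(tokenizer.decode(tokens[tokens > padding_token_id].cpu().numpy().tolist()))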

kyutai/run_eval.py

Lines changed: 270 additions & 0 deletions
@@ -0,0 +1,270 @@
import argparse
import os
import torch
import evaluate
from normalizer import data_utils
import time
from tqdm import tqdm
import julius
from moshi import models


wer_metric = evaluate.load("wer")
torch.set_float32_matmul_precision("high")


def load_model(model_path):

    info = models.loaders.CheckpointInfo.from_hf_repo(model_path)

    # The Mimi audio codec, the text tokenizer and the STT language model all
    # come from the same Moshi checkpoint.
    mimi = info.get_mimi(device="cuda")
    tokenizer = info.get_text_tokenizer()
    lm = info.get_moshi(
        device="cuda",
        dtype=torch.bfloat16,
    )
    lm_gen = models.LMGen(lm, temp=0, temp_text=0.0)

    padding_token_id = info.raw_config.get("text_padding_token_id", 3)
    # Putting in some conservative defaults
    audio_silence_prefix_seconds = info.stt_config.get(
        "audio_silence_prefix_seconds", 1.0
    )
    audio_delay_seconds = info.stt_config.get("audio_delay_seconds", 5.0)

    return (
        mimi,
        tokenizer,
        lm,
        lm_gen,
        padding_token_id,
        audio_silence_prefix_seconds,
        audio_delay_seconds,
    )


@torch.inference_mode
def get_padded_batch(
    audios, sample_rates, before_padding: float, after_padding: float, frame_size: int
):
    sample_rate = 24_000

    batch = []
    max_len = -1

    for audio, sr in zip(audios, sample_rates):
        # Resample to the 24 kHz expected by Mimi and surround each utterance
        # with the silence prefix and the decoding delay.
        audio = julius.resample.resample_frac(audio, old_sr=sr, new_sr=sample_rate)
        audio = torch.nn.functional.pad(
            audio, (int(before_padding * sample_rate), int(after_padding * sample_rate))
        )
        max_len = max(max_len, audio.shape[-1])
        batch.append(audio)

    # Round the padded length up to a whole number of codec frames.
    target = max_len
    if target % frame_size != 0:
        target = target + (frame_size - max_len % frame_size)

    batch = torch.stack(
        [
            torch.nn.functional.pad(audio, (0, target - audio.shape[-1]))
            for audio in batch
        ]
    )
    return batch


def main(args):
    (
        mimi,
        tokenizer,
        _lm,
        lm_gen,
        padding_token_id,
        audio_silence_prefix_seconds,
        audio_delay_seconds,
    ) = load_model(args.model_id)

    mimi_frame_size = mimi.frame_size

    def benchmark(batch):
        # Load audio inputs
        audios = [torch.from_numpy(audio["array"]) for audio in batch["audio"]]
        sample_rates = [ex["sampling_rate"] for ex in batch["audio"]]

        batch["audio_length_s"] = [
            len(audio) / sr for audio, sr in zip(audios, sample_rates)
        ]
        minibatch_size = len(audios)

        # Start timing
        start_time = time.time()

        padded_batch = get_padded_batch(
            audios,
            sample_rates,
            before_padding=audio_silence_prefix_seconds,
            after_padding=audio_delay_seconds,
            frame_size=mimi_frame_size,
        )
        padded_batch = padded_batch.to(args.device).float()

        bsz = padded_batch.shape[0]

        text_tokens_acc = []

        # Stream the batch through the Mimi encoder and the LM one codec frame
        # at a time, collecting the predicted text tokens.
        with mimi.streaming(bsz), lm_gen.streaming(bsz):
            for offset in range(0, padded_batch.shape[-1], mimi.frame_size):
                audio_chunk = padded_batch[:, offset : offset + mimi.frame_size].cuda()
                tokens = mimi.encode(audio_chunk[:, None, :])
                text_tokens = lm_gen.step(tokens)
                text_tokens_acc.append(text_tokens)

        pred_tokens = torch.concat(text_tokens_acc, axis=-1).squeeze(dim=1)
        pred_tokens = torch.unbind(pred_tokens, dim=0)

        pred_text = [
            tokenizer.decode(t[t > padding_token_id].cpu().numpy().tolist())
            for t in pred_tokens
        ]

        # End timing
        runtime = time.time() - start_time

        # normalize by minibatch size since we want the per-sample time
        batch["transcription_time_s"] = minibatch_size * [runtime / minibatch_size]

        # normalize transcriptions with English normalizer
        batch["predictions"] = [data_utils.normalizer(pred) for pred in pred_text]
        batch["references"] = batch["norm_text"]
        return batch

    if args.warmup_steps is not None:
        warmup_dataset = data_utils.load_data(args)
        warmup_dataset = data_utils.prepare_data(warmup_dataset)

        num_warmup_samples = args.warmup_steps * args.batch_size
        if args.streaming:
            warmup_dataset = warmup_dataset.take(num_warmup_samples)
        else:
            warmup_dataset = warmup_dataset.select(
                range(min(num_warmup_samples, len(warmup_dataset)))
            )
        warmup_dataset = iter(
            warmup_dataset.map(benchmark, batch_size=args.batch_size, batched=True)
        )

        for _ in tqdm(warmup_dataset, desc="Warming up..."):
            continue

    dataset = data_utils.load_data(args)
    dataset = data_utils.prepare_data(dataset)

    if args.max_eval_samples is not None and args.max_eval_samples > 0:
        print(f"Subsampling dataset to first {args.max_eval_samples} samples!")
        if args.streaming:
            dataset = dataset.take(args.max_eval_samples)
        else:
            dataset = dataset.select(range(min(args.max_eval_samples, len(dataset))))

    dataset = dataset.map(
        benchmark,
        batch_size=args.batch_size,
        batched=True,
        remove_columns=["audio"],
    )

    all_results = {
        "audio_length_s": [],
        "transcription_time_s": [],
        "predictions": [],
        "references": [],
    }
    result_iter = iter(dataset)
    for result in tqdm(result_iter, desc="Samples..."):
        for key in all_results:
            all_results[key].append(result[key])

    # Write manifest results (WER and RTFX)
    manifest_path = data_utils.write_manifest(
        all_results["references"],
        all_results["predictions"],
        args.model_id,
        args.dataset_path,
        args.dataset,
        args.split,
        audio_length=all_results["audio_length_s"],
        transcription_time=all_results["transcription_time_s"],
    )
    print("Results saved at path:", os.path.abspath(manifest_path))

    wer = wer_metric.compute(
        references=all_results["references"], predictions=all_results["predictions"]
    )
    wer = round(100 * wer, 2)
    rtfx = round(
        sum(all_results["audio_length_s"]) / sum(all_results["transcription_time_s"]), 2
    )
    print("WER:", wer, "%", "RTFx:", rtfx)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--model_id",
        type=str,
        required=True,
        help="Model identifier. Should be loadable with the moshi package, e.g. `kyutai/stt-2.6b-en`.",
    )
    parser.add_argument(
        "--dataset_path",
        type=str,
        default="esb/datasets",
        help="Dataset path. By default, it is `esb/datasets`.",
    )
    parser.add_argument(
        "--dataset",
        type=str,
        required=True,
        help="Dataset name. *E.g.* `librispeech_asr` for the LibriSpeech ASR dataset, or `common_voice` for Common Voice. The full list of dataset names "
        "can be found at `https://huggingface.co/datasets/esb/datasets`",
    )
    parser.add_argument(
        "--split",
        type=str,
        default="test",
        help="Split of the dataset. *E.g.* `validation` for the dev split, or `test` for the test split.",
    )
    parser.add_argument(
        "--device",
        type=int,
        default=-1,
        help="The device to run the pipeline on. -1 for CPU (default), 0 for the first GPU and so on.",
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        default=1,
        help="Number of samples to go through each streamed batch.",
    )
    parser.add_argument(
        "--max_eval_samples",
        type=int,
        default=None,
        help="Number of samples to be evaluated. Put a lower number e.g. 64 for testing this script.",
    )
    parser.add_argument(
        "--no-streaming",
        dest="streaming",
        action="store_false",
        help="Choose whether you'd like to download the entire dataset or stream it during the evaluation.",
    )
    parser.add_argument(
        "--warmup_steps",
        type=int,
        default=10,
        help="Number of warm-up steps to run before launching the timed runs.",
    )
    # Set the default before parsing so that it actually takes effect.
    parser.set_defaults(streaming=False)
    args = parser.parse_args()

    main(args)

kyutai/run_kyutai.sh

Lines changed: 92 additions & 0 deletions
@@ -0,0 +1,92 @@
#!/bin/bash

export PYTHONPATH="..":$PYTHONPATH

MODEL_IDs=("kyutai/stt-2.6b-en")
BATCH_SIZE=384

num_models=${#MODEL_IDs[@]}
for (( i=0; i<${num_models}; i++ ));
do
    MODEL_ID=${MODEL_IDs[$i]}

    python run_eval.py \
        --model_id=${MODEL_ID} \
        --dataset_path="hf-audio/esb-datasets-test-only-sorted" \
        --dataset="voxpopuli" \
        --split="test" \
        --device=0 \
        --batch_size=${BATCH_SIZE} \
        --max_eval_samples=-1

    python run_eval.py \
        --model_id=${MODEL_ID} \
        --dataset_path="hf-audio/esb-datasets-test-only-sorted" \
        --dataset="ami" \
        --split="test" \
        --device=0 \
        --batch_size=${BATCH_SIZE} \
        --max_eval_samples=-1

    python run_eval.py \
        --model_id=${MODEL_ID} \
        --dataset_path="hf-audio/esb-datasets-test-only-sorted" \
        --dataset="earnings22" \
        --split="test" \
        --device=0 \
        --batch_size=${BATCH_SIZE} \
        --max_eval_samples=-1

    python run_eval.py \
        --model_id=${MODEL_ID} \
        --dataset_path="hf-audio/esb-datasets-test-only-sorted" \
        --dataset="gigaspeech" \
        --split="test" \
        --device=0 \
        --batch_size=${BATCH_SIZE} \
        --max_eval_samples=-1

    python run_eval.py \
        --model_id=${MODEL_ID} \
        --dataset_path="hf-audio/esb-datasets-test-only-sorted" \
        --dataset="librispeech" \
        --split="test.clean" \
        --device=0 \
        --batch_size=${BATCH_SIZE} \
        --max_eval_samples=-1

    python run_eval.py \
        --model_id=${MODEL_ID} \
        --dataset_path="hf-audio/esb-datasets-test-only-sorted" \
        --dataset="librispeech" \
        --split="test.other" \
        --device=0 \
        --batch_size=${BATCH_SIZE} \
        --max_eval_samples=-1

    python run_eval.py \
        --model_id=${MODEL_ID} \
        --dataset_path="hf-audio/esb-datasets-test-only-sorted" \
        --dataset="spgispeech" \
        --split="test" \
        --device=0 \
        --batch_size=${BATCH_SIZE} \
        --max_eval_samples=-1

    python run_eval.py \
        --model_id=${MODEL_ID} \
        --dataset_path="hf-audio/esb-datasets-test-only-sorted" \
        --dataset="tedlium" \
        --split="test" \
        --device=0 \
        --batch_size=${BATCH_SIZE} \
        --max_eval_samples=-1


    # Evaluate results
    RUNDIR=`pwd` && \
    cd ../normalizer && \
    python -c "import eval_utils; eval_utils.score_results('${RUNDIR}/results', '${MODEL_ID}')" && \
    cd $RUNDIR

done
Lines changed: 3 additions & 0 deletions
@@ -0,0 +1,3 @@
torch>=2.6.0
julius
moshi>=0.2.6

0 commit comments
