
Commit ca92caf

Add support for ctranslate2 whisper models (#10)
* add ctranslate2 eval
* correct calc_rtf name for all files
* add missing regex dependency
* fix cuda index
* add DEVICE_INDEX as variable
* add hub login
* remove regex included in transformers
* fix inference
* use_auth_token -> token
* add compute type for evaluation
* scripts corrections
* fix typo
* update shell script url
1 parent d9c2518 commit ca92caf
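
For orientation, a minimal sketch (not part of the commit) of the faster-whisper API that the new evaluation code builds on; the model name, audio file, and compute type below are illustrative choices, not values from the diff:

# Sketch: loading a CTranslate2 Whisper conversion with faster-whisper.
from faster_whisper import WhisperModel

model = WhisperModel(
    "guillaumekln/faster-whisper-tiny.en",  # CTranslate2 conversion hosted on the Hub
    device="cuda",
    device_index=0,
    compute_type="float16",  # precision/quantization, e.g. "int8", "int8_float16"
)
segments, info = model.transcribe("sample.wav", language="en")
for segment in segments:  # inference runs lazily while iterating over segments
    print(segment.start, segment.end, segment.text)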

8 files changed (+294, −4 lines)

README.md

Lines changed: 2 additions & 2 deletions

@@ -10,7 +10,7 @@
 2) Install PyTorch by following the instructions here: https://pytorch.org/get-started/locally/
 3) Install the common requirements for all libraries by running `pip install -r requirements/requirements.txt`.
 4) Install the requirements for each library you wish to evaluate by running `pip install -r requirements/requirements_<library_name>.txt`.
-
+5) Connect your Hugging Face account by running `huggingface-cli login`.
 
 # Evaluate a model
 
@@ -32,7 +32,7 @@
 4) Create one bash file per model type following the convention `run_<model_type>.sh`.
 - The bash script should follow the same steps as other libraries.
 - Different model sizes of the same type should share the script. For example, `Wav2Vec` and `Wav2Vec2` would be two separate scripts, but different sizes of `Wav2Vec2` would be part of the same script.
-5) (Optional) You could also add a `compute_rtf.py` script for your library to evaluate the Real Time Factor of the model.
+5) (Optional) You could also add a `calc_rtf.py` script for your library to evaluate the Real Time Factor of the model.
 
 6) Submit a PR for your changes.
 
 # Add a new model
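
A note on the new step 5: `huggingface-cli login` is interactive. In CI or other non-interactive environments the same credential can be stored programmatically; a minimal sketch, assuming the token is exported as HF_TOKEN (the variable name is our choice, not the repo's):

# Sketch: programmatic alternative to `huggingface-cli login`, useful in CI.
import os
from huggingface_hub import login

login(token=os.environ["HF_TOKEN"])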

ctranslate2/calc_rtf.py

Lines changed: 61 additions & 0 deletions

import time

import librosa
from faster_whisper import WhisperModel

device = "cuda"
device_index = 0

models = [
    "guillaumekln/faster-whisper-tiny.en",
    "guillaumekln/faster-whisper-small.en",
    "guillaumekln/faster-whisper-base.en",
    "guillaumekln/faster-whisper-medium.en",
    "guillaumekln/faster-whisper-large-v1",
    "guillaumekln/faster-whisper-large-v2",
]

n_batches = 3
warmup_batches = 5

audio_file = "4469669.mp3"
max_len = 600  # 10 minutes


def pre_process_audio(audio_file, sr, max_len):
    """Load the audio file and truncate it to the first `max_len` seconds."""
    audio_arr, _sr = librosa.load(audio_file, sr=sr)
    audio_len = int(max_len * _sr)
    audio_arr = audio_arr[:audio_len]
    return {"raw": audio_arr, "sampling_rate": _sr}, audio_len


audio_dict, audio_len = pre_process_audio(audio_file, 16000, max_len)

rtfs = []

for model in models[:1]:  # note: only the first model in the list is benchmarked
    asr_model = WhisperModel(
        model_size_or_path=model,
        device=device,
        device_index=device_index,
        compute_type="float16",
    )

    for i in range(3):
        print(f"outer_loop -> {i}")
        total_time = 0.0
        for batch_num in range(n_batches + warmup_batches):
            print(f"batch_num -> {batch_num}")
            start = time.time()
            segments, _ = asr_model.transcribe(audio_dict["raw"], language="en")
            _ = [segment._asdict() for segment in segments]  # iterate over segments to run inference
            end = time.time()
            if batch_num >= warmup_batches:  # only time the batches after warmup
                total_time += end - start

        rtf = (total_time / n_batches) / (audio_len / 16000)
        rtfs.append(rtf)

    print(f"all RTFs: {model}: {rtfs}")
    rtf_val = sum(rtfs) / len(rtfs)
    print(f"avg. RTF: {model}: {rtf_val}")

ctranslate2/run_eval.py

Lines changed: 123 additions & 0 deletions

"""Run evaluation for ctranslate2 whisper models."""
import argparse
import os

import evaluate
from faster_whisper import WhisperModel
from tqdm import tqdm

from normalizer import data_utils

wer_metric = evaluate.load("wer")


def dataset_iterator(dataset):
    """
    Iterate over the dataset and yield a dictionary with the audio and reference text.

    Args:
        dataset: dataset to iterate over

    Yields:
        dictionary: {"audio": audio, "reference": reference}
    """
    for item in dataset:
        yield {**item["audio"], "reference": item["norm_text"]}


def main(args) -> None:
    """Main function to run evaluation on a dataset."""
    asr_model = WhisperModel(
        model_size_or_path=args.model_id,
        compute_type="float16",
        device="cuda",
        device_index=args.device,
    )

    dataset = data_utils.load_data(args)

    if args.max_eval_samples is not None and args.max_eval_samples > 0:
        print(f"Subsampling dataset to first {args.max_eval_samples} samples!")
        dataset = dataset.take(args.max_eval_samples)

    dataset = data_utils.prepare_data(dataset)

    predictions = []
    references = []

    # Run inference
    for batch in tqdm(dataset_iterator(dataset), desc=f"Evaluating {args.model_id}"):
        segments, _ = asr_model.transcribe(batch["array"], language="en")
        outputs = [segment._asdict() for segment in segments]
        # Append one normalized hypothesis/reference string per utterance.
        predictions.append(
            data_utils.normalizer(
                "".join(segment["text"] for segment in outputs)
            ).strip()
        )
        references.append(batch["reference"])

    # Write manifest results
    manifest_path = data_utils.write_manifest(
        references, predictions, args.model_id, args.dataset_path, args.dataset, args.split
    )
    print("Results saved at path:", os.path.abspath(manifest_path))

    wer = wer_metric.compute(references=references, predictions=predictions)
    wer = round(100 * wer, 2)

    print("WER:", wer, "%")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--model_id",
        type=str,
        required=True,
        help="Model identifier. Should be loadable with faster-whisper.",
    )
    parser.add_argument(
        "--dataset_path",
        type=str,
        default="esb/datasets",
        help="Dataset path. By default, it is `esb/datasets`.",
    )
    parser.add_argument(
        "--dataset",
        type=str,
        required=True,
        help="Dataset name. E.g. `librispeech_asr` for the LibriSpeech ASR dataset, or `common_voice` for Common Voice. The full list of dataset names "
        "can be found at https://huggingface.co/datasets/esb/datasets",
    )
    parser.add_argument(
        "--split",
        type=str,
        default="test",
        help="Split of the dataset. E.g. `validation` for the dev split, or `test` for the test split.",
    )
    parser.add_argument(
        "--device",
        type=int,
        default=-1,
        help="The device to run the pipeline on. -1 for CPU (default), 0 for the first GPU and so on.",
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        default=16,
        help="Number of samples to go through each streamed batch.",
    )
    parser.add_argument(
        "--max_eval_samples",
        type=int,
        default=None,
        help="Number of samples to be evaluated. Put a lower number e.g. 64 for testing this script.",
    )
    parser.add_argument(
        "--no-streaming",
        dest="streaming",
        action="store_false",
        help="Choose whether you'd like to download the entire dataset or stream it during the evaluation.",
    )
    parser.set_defaults(streaming=True)  # must precede parse_args() to take effect
    args = parser.parse_args()

    main(args)
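
The script is launched once per dataset by run_whisper.sh (e.g. `python run_eval.py --model_id=guillaumekln/faster-whisper-tiny.en --dataset=ami --split=test --device=0`). The final score comes from the `wer` metric of the `evaluate` library; a self-contained sketch of that last step on toy strings:

# Sketch of the scoring step in isolation, on toy data.
import evaluate

wer_metric = evaluate.load("wer")
references = ["the cat sat on the mat"]
predictions = ["the cat sat on a mat"]
wer = wer_metric.compute(references=references, predictions=predictions)
print(round(100 * wer, 2), "%")  # 1 substitution / 6 reference words -> 16.67 %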

ctranslate2/run_whisper.sh

Lines changed: 102 additions & 0 deletions

#!/bin/bash

export PYTHONPATH="..":$PYTHONPATH

MODEL_IDs=("guillaumekln/faster-whisper-tiny.en" "guillaumekln/faster-whisper-small.en" "guillaumekln/faster-whisper-base.en" "guillaumekln/faster-whisper-medium.en" "guillaumekln/faster-whisper-large-v1" "guillaumekln/faster-whisper-large-v2")
BATCH_SIZE=1
DEVICE_INDEX=0

num_models=${#MODEL_IDs[@]}

for (( i=0; i<${num_models}; i++ ));
do
    MODEL_ID=${MODEL_IDs[$i]}

    python run_eval.py \
        --model_id=${MODEL_ID} \
        --dataset_path="https://huggingface.co/datasets/hf-audio/esb-datasets-test-only" \
        --dataset="ami" \
        --split="test" \
        --device=${DEVICE_INDEX} \
        --batch_size=${BATCH_SIZE} \
        --max_eval_samples=-1

    python run_eval.py \
        --model_id=${MODEL_ID} \
        --dataset_path="https://huggingface.co/datasets/hf-audio/esb-datasets-test-only" \
        --dataset="earnings22" \
        --split="test" \
        --device=${DEVICE_INDEX} \
        --batch_size=${BATCH_SIZE} \
        --max_eval_samples=-1

    python run_eval.py \
        --model_id=${MODEL_ID} \
        --dataset_path="https://huggingface.co/datasets/hf-audio/esb-datasets-test-only" \
        --dataset="gigaspeech" \
        --split="test" \
        --device=${DEVICE_INDEX} \
        --batch_size=${BATCH_SIZE} \
        --max_eval_samples=-1

    python run_eval.py \
        --model_id=${MODEL_ID} \
        --dataset_path="https://huggingface.co/datasets/hf-audio/esb-datasets-test-only" \
        --dataset="librispeech" \
        --split="test.clean" \
        --device=${DEVICE_INDEX} \
        --batch_size=${BATCH_SIZE} \
        --max_eval_samples=-1

    python run_eval.py \
        --model_id=${MODEL_ID} \
        --dataset_path="https://huggingface.co/datasets/hf-audio/esb-datasets-test-only" \
        --dataset="librispeech" \
        --split="test.other" \
        --device=${DEVICE_INDEX} \
        --batch_size=${BATCH_SIZE} \
        --max_eval_samples=-1

    python run_eval.py \
        --model_id=${MODEL_ID} \
        --dataset_path="https://huggingface.co/datasets/hf-audio/esb-datasets-test-only" \
        --dataset="spgispeech" \
        --split="test" \
        --device=${DEVICE_INDEX} \
        --batch_size=${BATCH_SIZE} \
        --max_eval_samples=-1

    python run_eval.py \
        --model_id=${MODEL_ID} \
        --dataset_path="https://huggingface.co/datasets/hf-audio/esb-datasets-test-only" \
        --dataset="tedlium" \
        --split="test" \
        --device=${DEVICE_INDEX} \
        --batch_size=${BATCH_SIZE} \
        --max_eval_samples=-1

    python run_eval.py \
        --model_id=${MODEL_ID} \
        --dataset_path="https://huggingface.co/datasets/hf-audio/esb-datasets-test-only" \
        --dataset="voxpopuli" \
        --split="test" \
        --device=${DEVICE_INDEX} \
        --batch_size=${BATCH_SIZE} \
        --max_eval_samples=-1

    python run_eval.py \
        --model_id=${MODEL_ID} \
        --dataset_path="https://huggingface.co/datasets/hf-audio/esb-datasets-test-only" \
        --dataset="common_voice" \
        --split="test" \
        --device=${DEVICE_INDEX} \
        --batch_size=${BATCH_SIZE} \
        --max_eval_samples=-1

    # Evaluate results
    RUNDIR=`pwd` && \
    cd ../normalizer && \
    python -c "import eval_utils; eval_utils.score_results('${RUNDIR}/results', '${MODEL_ID}')" && \
    cd $RUNDIR

done
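
Since run_eval.py hardcodes `device="cuda"`, DEVICE_INDEX must name a visible GPU. A quick pre-flight check, a sketch assuming PyTorch is installed as the README's step 2 requires:

# Sanity check before launching the sweep: DEVICE_INDEX must be < device_count().
import torch

assert torch.cuda.is_available(), "run_eval.py hardcodes device='cuda'"
print("visible GPUs:", torch.cuda.device_count())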

normalizer/data_utils.py

Lines changed: 1 addition & 1 deletion

@@ -43,7 +43,7 @@ def load_data(args):
         args.dataset,
         split=args.split,
         streaming=args.streaming,
-        use_auth_token=True,
+        token=True,
     )
 
     return dataset
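
The rename tracks the `datasets` library's deprecation of `use_auth_token` in favor of `token`; both forward the credential saved by `huggingface-cli login`. A sketch of the renamed call in isolation, assuming a recent `datasets` release and using the repo-id form of the URL passed by run_whisper.sh:

from datasets import load_dataset

# token=True reuses the locally stored Hub credential, as use_auth_token=True did.
dataset = load_dataset(
    "hf-audio/esb-datasets-test-only",
    "ami",
    split="test",
    streaming=True,
    token=True,
)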

requirements/requirements.txt

Lines changed: 0 additions & 1 deletion

@@ -4,4 +4,3 @@ evaluate
 datasets
 librosa
 jiwer
-regex
requirements/requirements_ctranslate2.txt

Lines changed: 5 additions & 0 deletions

datasets
evaluate
faster-whisper>=0.8.0
jiwer
librosa
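
With this file in place, step 4 of the README maps, for this library, to `pip install -r requirements/requirements_ctranslate2.txt`.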
