Commit 11543d8

Merge pull request #34 from sanchit-gandhi/update-other-libs
Propagate RTFx updates to other libs
2 parents c1c6dc1 + 1c5b1ba commit 11543d8

11 files changed: +318 −341 lines

ctranslate2/calc_rtf.py

Lines changed: 0 additions & 62 deletions
This file was deleted.

ctranslate2/run_eval.py

Lines changed: 60 additions & 42 deletions
@@ -1,6 +1,7 @@
 """Run evaluation for ctranslate2 whisper models."""
 import argparse
 import os
+import time
 
 import evaluate
 from faster_whisper import WhisperModel
@@ -11,20 +12,6 @@
 wer_metric = evaluate.load("wer")
 
 
-def dataset_iterator(dataset) -> dict:
-    """
-    Iterate over the dataset and yield a dictionary with the audio and reference text.
-
-    Args:
-        dataset: dataset to iterate over
-
-    Returns:
-        dictionary: {"audio": audio, "reference": reference}
-    """
-    for item in dataset:
-        yield {**item["audio"], "reference": item["norm_text"]}
-
-
 def main(args) -> None:
     """Main function to run evaluation on a dataset."""
     asr_model = WhisperModel(
@@ -34,38 +21,69 @@ def main(args) -> None:
         device_index=args.device
     )
 
-    dataset = data_utils.load_data(args)
+    def benchmark(batch):
+        start_time = time.time()
+        segments, _ = asr_model.transcribe(batch["audio"]["array"], language="en")
+        outputs = [segment._asdict() for segment in segments]
+        batch["transcription_time_s"] = time.time() - start_time
+        batch["predictions"] = data_utils.normalizer("".join([segment["text"] for segment in outputs])).strip()
+        batch["references"] = batch["norm_text"]
+        return batch
 
-    if args.max_eval_samples is not None and args.max_eval_samples > 0:
-        print(f"Subsampling dataset to first {args.max_eval_samples} samples !")
-        dataset = dataset.take(args.max_eval_samples)
+    if args.warmup_steps is not None:
+        dataset = data_utils.load_data(args)
+        dataset = data_utils.prepare_data(dataset)
 
-    dataset = data_utils.prepare_data(dataset)
+        if args.streaming:
+            warmup_dataset = dataset.take(args.warmup_steps)
+        else:
+            warmup_dataset = dataset.select(range(min(args.warmup_steps, len(dataset))))
+        warmup_dataset = iter(warmup_dataset.map(benchmark, remove_columns=["audio"]))
 
-    predictions = []
-    references = []
+        for _ in tqdm(warmup_dataset, desc="Warming up..."):
+            continue
 
-    # Run inference
-    for batch in tqdm(dataset_iterator(dataset), desc=f"Evaluating {args.model_id}"):
-        segments, _ = asr_model.transcribe(batch["array"], language="en")
-        outputs = [segment._asdict() for segment in segments]
-        transcription = data_utils.normalizer(
-            "".join([segment["text"] for segment in outputs])
-        ).strip()
+    dataset = data_utils.load_data(args)
+    if args.max_eval_samples is not None and args.max_eval_samples > 0:
+        print(f"Subsampling dataset to first {args.max_eval_samples} samples!")
+        if args.streaming:
+            dataset = dataset.take(args.max_eval_samples)
+        else:
+            dataset = dataset.select(range(min(args.max_eval_samples, len(dataset))))
+    dataset = data_utils.prepare_data(dataset)
+
+    dataset = dataset.map(benchmark, remove_columns=["audio"])
 
-        predictions.append(transcription)
-        references.append(batch["reference"])
+    all_results = {
+        "audio_length_s": [],
+        "transcription_time_s": [],
+        "predictions": [],
+        "references": [],
+    }
+    result_iter = iter(dataset)
+    for result in tqdm(result_iter, desc="Samples..."):
+        for key in all_results:
+            all_results[key].append(result[key])
 
-    # Write manifest results
+    # Write manifest results (WER and RTFX)
     manifest_path = data_utils.write_manifest(
-        references, predictions, args.model_id, args.dataset_path, args.dataset, args.split
+        all_results["references"],
+        all_results["predictions"],
+        args.model_id,
+        args.dataset_path,
+        args.dataset,
+        args.split,
+        audio_length=all_results["audio_length_s"],
+        transcription_time=all_results["transcription_time_s"],
    )
    print("Results saved at path:", os.path.abspath(manifest_path))
 
-    wer = wer_metric.compute(references=references, predictions=predictions)
+    wer = wer_metric.compute(
+        references=all_results["references"], predictions=all_results["predictions"]
+    )
    wer = round(100 * wer, 2)
-
-    print("WER:", wer, "%")
+    rtfx = round(sum(all_results["audio_length_s"]) / sum(all_results["transcription_time_s"]), 2)
+    print("WER:", wer, "%", "RTFx:", rtfx)
 
 
 if __name__ == "__main__":
@@ -75,7 +93,7 @@ def main(args) -> None:
         "--model_id",
         type=str,
         required=True,
-        help="Model identifier. Should be loadable with 🤗 Transformers",
+        help="Model identifier. Should be loadable with faster-whisper",
     )
     parser.add_argument(
         '--dataset_path', type=str, default='esb/datasets', help='Dataset path. By default, it is `esb/datasets`'
@@ -99,12 +117,6 @@ def main(args) -> None:
         default=-1,
         help="The device to run the pipeline on. -1 for CPU (default), 0 for the first GPU and so on.",
     )
-    parser.add_argument(
-        "--batch_size",
-        type=int,
-        default=16,
-        help="Number of samples to go through each streamed batch.",
-    )
     parser.add_argument(
         "--max_eval_samples",
         type=int,
@@ -117,6 +129,12 @@ def main(args) -> None:
         action="store_false",
         help="Choose whether you'd like to download the entire dataset or stream it during the evaluation.",
     )
+    parser.add_argument(
+        "--warmup_steps",
+        type=int,
+        default=5,
+        help="Number of warm-up steps to run before launching the timed runs.",
+    )
     args = parser.parse_args()
     parser.set_defaults(streaming=False)
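
Note: the core of this change is the new benchmark map function, which times each transcribe call and stores the wall-clock duration alongside the prediction so that RTFx (total seconds of audio divided by total seconds of transcription time; higher is faster) can be computed at the end. A minimal, self-contained sketch of the same measurement pattern, with a hypothetical fake_transcribe standing in for the faster-whisper call:

import time

def rtfx(audio_lengths_s, transcription_times_s):
    # RTFx = seconds of audio processed per second of wall-clock time.
    # An RTFx of 100 means transcription runs 100x faster than real time.
    return round(sum(audio_lengths_s) / sum(transcription_times_s), 2)

def fake_transcribe(duration_s):
    # Hypothetical stand-in for asr_model.transcribe(...): pretend the
    # model runs roughly 100x faster than real time.
    time.sleep(duration_s / 100)
    return "transcript"

audio_lengths_s, transcription_times_s = [], []
for duration_s in [3.0, 5.5, 7.2]:  # dummy clip durations in seconds
    start_time = time.time()
    fake_transcribe(duration_s)
    transcription_times_s.append(time.time() - start_time)
    audio_lengths_s.append(duration_s)

print("RTFx:", rtfx(audio_lengths_s, transcription_times_s))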

ctranslate2/run_whisper.sh

Lines changed: 0 additions & 9 deletions
@@ -3,7 +3,6 @@
 export PYTHONPATH="..":$PYTHONPATH
 
 MODEL_IDs=("tiny.en" "small.en" "base.en" "medium.en" "large-v1" "large-v2" "large-v3")
-BATCH_SIZE=1
 DEVICE_INDEX=0
 
 num_models=${#MODEL_IDs[@]}
@@ -18,7 +17,6 @@ do
     --dataset="ami" \
     --split="test" \
     --device=${DEVICE_INDEX} \
-    --batch_size=${BATCH_SIZE} \
     --max_eval_samples=-1
 
 python run_eval.py \
@@ -27,7 +25,6 @@ do
     --dataset="earnings22" \
     --split="test" \
     --device=${DEVICE_INDEX} \
-    --batch_size=${BATCH_SIZE} \
     --max_eval_samples=-1
 
 python run_eval.py \
@@ -36,7 +33,6 @@ do
     --dataset="gigaspeech" \
     --split="test" \
     --device=${DEVICE_INDEX} \
-    --batch_size=${BATCH_SIZE} \
     --max_eval_samples=-1
 
 python run_eval.py \
@@ -45,7 +41,6 @@ do
     --dataset="librispeech" \
     --split="test.clean" \
     --device=${DEVICE_INDEX} \
-    --batch_size=${BATCH_SIZE} \
     --max_eval_samples=-1
 
 python run_eval.py \
@@ -54,7 +49,6 @@ do
     --dataset="librispeech" \
     --split="test.other" \
     --device=${DEVICE_INDEX} \
-    --batch_size=${BATCH_SIZE} \
     --max_eval_samples=-1
 
 python run_eval.py \
@@ -63,7 +57,6 @@ do
     --dataset="spgispeech" \
     --split="test" \
     --device=${DEVICE_INDEX} \
-    --batch_size=${BATCH_SIZE} \
     --max_eval_samples=-1
 
 python run_eval.py \
@@ -72,7 +65,6 @@ do
     --dataset="tedlium" \
     --split="test" \
     --device=${DEVICE_INDEX} \
-    --batch_size=${BATCH_SIZE} \
     --max_eval_samples=-1
 
 python run_eval.py \
@@ -81,7 +73,6 @@ do
     --dataset="voxpopuli" \
     --split="test" \
     --device=${DEVICE_INDEX} \
-    --batch_size=${BATCH_SIZE} \
     --max_eval_samples=-1
 
 # Evaluate results
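
Note: the --batch_size flag disappears from these invocations because run_eval.py now transcribes one sample at a time inside benchmark. The timed runs are also preceded by the new warm-up phase; a rough sketch of that pattern (generic process callable and in-memory samples, not the leaderboard's actual helpers) looks like this:

from itertools import islice
import time

def run_with_warmup(samples, process, warmup_steps=5):
    # Run the first few samples untimed so one-off costs (model load,
    # kernel compilation, filesystem caches) don't inflate the timings
    # measured over the remaining samples.
    iterator = iter(samples)
    for sample in islice(iterator, warmup_steps):
        process(sample)  # results discarded on purpose

    timings = []
    for sample in iterator:
        start_time = time.time()
        process(sample)
        timings.append(time.time() - start_time)
    return timings

print(run_with_warmup(range(10), lambda x: x * x, warmup_steps=3))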

speechbrain/run_conformer.sh

Lines changed: 16 additions & 16 deletions
@@ -5,24 +5,24 @@ export PYTHONPATH="..":$PYTHONPATH
 SOURCE="speechbrain/asr-conformer-transformerlm-librispeech"
 
 python run_eval.py \
-    --source=$SOURCE \
-    --speechbrain_pretrained_class_name="EncoderDecoderASR" \
-    --dataset_path="librispeech_asr" \
-    --dataset="clean" \
-    --split="test" \
-    --device=0 \
-    --batch_size=4 \
-    --max_eval_samples=-1
+    --source=$SOURCE \
+    --speechbrain_pretrained_class_name="EncoderDecoderASR" \
+    --dataset_path="hf-audio/esb-datasets-test-only-sorted" \
+    --dataset="librispeech" \
+    --split="test.clean" \
+    --device=0 \
+    --batch_size=4 \
+    --max_eval_samples=-1
 
 python run_eval.py \
-    --source=$SOURCE \
-    --speechbrain_pretrained_class_name="EncoderDecoderASR" \
-    --dataset_path="librispeech_asr" \
-    --dataset="other" \
-    --split="test" \
-    --device=0 \
-    --batch_size=4 \
-    --max_eval_samples=-1
+    --source=$SOURCE \
+    --speechbrain_pretrained_class_name="EncoderDecoderASR" \
+    --dataset_path="hf-audio/esb-datasets-test-only-sorted" \
+    --dataset="librispeech" \
+    --split="test.other" \
+    --device=0 \
+    --batch_size=4 \
+    --max_eval_samples=-1
 
 # Evaluate results
 RUNDIR=`pwd` && \
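
Note: both SpeechBrain scripts now read the LibriSpeech test sets from the sorted ESB dataset mirror instead of librispeech_asr, matching the other libraries in this PR. To inspect what the scripts will consume, something along these lines should work (the repo name, config, and split come from the script above; streaming=True and the column names are assumptions based on run_eval.py):

from datasets import load_dataset

dataset = load_dataset(
    "hf-audio/esb-datasets-test-only-sorted",
    "librispeech",
    split="test.clean",
    streaming=True,  # iterate without downloading the whole split
)
sample = next(iter(dataset))
print(sample["audio"]["sampling_rate"], sample["norm_text"])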

speechbrain/run_conformersmall.sh

Lines changed: 16 additions & 16 deletions
@@ -5,24 +5,24 @@ export PYTHONPATH="..":$PYTHONPATH
 SOURCE="speechbrain/asr-conformersmall-transformerlm-librispeech"
 
 python run_eval.py \
-    --source=$SOURCE \
-    --speechbrain_pretrained_class_name="EncoderDecoderASR" \
-    --dataset_path="librispeech_asr" \
-    --dataset="clean" \
-    --split="test" \
-    --device=0 \
-    --batch_size=4 \
-    --max_eval_samples=-1
+    --source=$SOURCE \
+    --speechbrain_pretrained_class_name="EncoderDecoderASR" \
+    --dataset_path="hf-audio/esb-datasets-test-only-sorted" \
+    --dataset="librispeech" \
+    --split="test.clean" \
+    --device=0 \
+    --batch_size=4 \
+    --max_eval_samples=-1
 
 python run_eval.py \
-    --source=$SOURCE \
-    --speechbrain_pretrained_class_name="EncoderDecoderASR" \
-    --dataset_path="librispeech_asr" \
-    --dataset="other" \
-    --split="test" \
-    --device=0 \
-    --batch_size=4 \
-    --max_eval_samples=-1
+    --source=$SOURCE \
+    --speechbrain_pretrained_class_name="EncoderDecoderASR" \
+    --dataset_path="hf-audio/esb-datasets-test-only-sorted" \
+    --dataset="librispeech" \
+    --split="test.other" \
+    --device=0 \
+    --batch_size=4 \
+    --max_eval_samples=-1
 
 # Evaluate results
 RUNDIR=`pwd` && \
