Commit b994241

Author: sanchit-gandhi
Commit message: update speechbrain script
1 parent: 347784d

File tree: 1 file changed (+82 additions, -113 deletions)

speechbrain/run_eval.py

Lines changed: 82 additions & 113 deletions
@@ -2,18 +2,18 @@
 
 Authors
 * Adel Moumen 2023 <[email protected]>
+* Sanchit Gandhi 2024 <[email protected]>
 """
 import argparse
+import time
 
 import evaluate
 from normalizer import data_utils
 from tqdm import tqdm
 import torch
-import speechbrain.pretrained as pretrained
+import speechbrain.inference.ASR as ASR
 from speechbrain.utils.data_utils import batch_pad_right
-from datasets import Dataset
-from typing import List, Union
-import os
+import os
 
 def get_model(
     speechbrain_repository: str,
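
This first hunk tracks SpeechBrain's module reorganization: the deprecated speechbrain.pretrained entry point is replaced by speechbrain.inference.ASR. As a minimal sketch of loading a model through the new module (assuming SpeechBrain >= 1.0; the model id, savedir, and audio path below are illustrative placeholders, not values from this commit):

    import speechbrain.inference.ASR as ASR

    # Hypothetical model choice, for illustration only.
    model = ASR.EncoderDecoderASR.from_hparams(
        source="speechbrain/asr-crdnn-rnnlm-librispeech",
        savedir="pretrained_models/asr-crdnn-rnnlm-librispeech",
        run_opts={"device": "cpu"},  # or "cuda:0"
    )
    print(model.transcribe_file("example.wav"))  # placeholder audio path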
@@ -61,7 +61,7 @@ def get_model(
     }
 
     try:
-        model_class = getattr(pretrained, speechbrain_pretrained_class_name)
+        model_class = getattr(ASR, speechbrain_pretrained_class_name)
     except AttributeError:
         raise AttributeError(
             f"SpeechBrain Pretrained class: {speechbrain_pretrained_class_name} not found in pretrained.py"
@@ -70,137 +70,100 @@ def get_model(
     return model_class.from_hparams(**kwargs)
 
 
-def dataset_iterator(dataset: Dataset):
-    """Iterate over the dataset and yield the audio and reference text.
-
-    Arguments
-    ---------
-    dataset : Dataset
-        The dataset to iterate over.
-
-    Yields
-    ------
-    dict
-        A dictionary containing the audio and reference text.
-    """
-    for i, item in enumerate(dataset):
-        yield {
-            **item["audio"],
-            "reference": item["norm_text"],
-            "audio_filename": f"file_{i}",
-            "sample_rate": 16_000,
-            "sample_id": i,
-        }
+def main(args):
+    """Run the evaluation script."""
+    if args.device == -1:
+        device = "cpu"
+    else:
+        device = f"cuda:{args.device}"
 
+    model = get_model(
+        args.source, args.speechbrain_pretrained_class_name, device=device
+    )
 
-def evaluate_batch(model, buffer: List, predictions: List, device: str) -> None:
-    """Evaluate a batch of audio samples.
+    def benchmark(batch):
+        # Load audio inputs
+        audios = [torch.from_numpy(sample["array"]) for sample in batch["audio"]]
+        minibatch_size = len(audios)
 
-    Arguments
-    ---------
-    model : Pretrained
-        The SpeechBrain pretrained model.
-    buffer : List
-        A list of audio samples.
-    predictions : List
-        A list of predictions.
-    device : str
-        The device to run the model on.
-    """
-    wavs = [torch.from_numpy(sample["array"]) for sample in buffer]
-    wavs, wav_lens = batch_pad_right(wavs)
-    wavs = wavs.to(device)
-    wav_lens = wav_lens.to(device)
-    predicted_words, _ = model.transcribe_batch(wavs, wav_lens)
+        # START TIMING
+        start_time = time.time()
 
-    for result in predicted_words:
-        result = data_utils.normalizer(result)
-        predictions.append(result)
-    buffer.clear()
+        audios, audio_lens = batch_pad_right(audios)
+        audios = audios.to(device)
+        audio_lens = audio_lens.to(device)
+        predictions, _ = model.transcribe_batch(audios, audio_lens)
 
+        # END TIMING
+        runtime = time.time() - start_time
 
-def evaluate_dataset(
-    model, dataset: Dataset, device: str, batch_size: int, verbose: bool = True
-) -> Union[List, List]:
-    """Evaluate a dataset the SpeechBrain pretrained model.
+        # normalize by minibatch size since we want the per-sample time
+        batch["transcription_time_s"] = minibatch_size * [runtime / minibatch_size]
 
-    Arguments
-    ---------
-    model : Pretrained
-        The SpeechBrain pretrained model.
-    dataset : Dataset
-        The dataset to evaluate.
-    device : str
-        The device to run the model on.
-    batch_size : int
-        The batch size to use.
-    verbose : bool, optional
-        Whether to print progress information.
+        # normalize transcriptions with English normalizer
+        batch["predictions"] = [data_utils.normalizer(pred) for pred in predictions]
+        batch["references"] = batch["norm_text"]
+        return batch
 
-    Returns
-    -------
-    references : List
-        A list of references.
-    predictions : List
-        A list of predictions.
-    """
-    references = []
-    predictions = []
-    buffer = []
-    for sample in tqdm(
-        dataset_iterator(dataset),
-        desc="Evaluating: Sample id",
-        unit="",
-        disable=not verbose,
-    ):
-        buffer.append(sample)
-        references.append(sample["reference"])
-        if len(buffer) == batch_size:
-            evaluate_batch(model, buffer, predictions, device)
-
-    if len(buffer) > 0:
-        evaluate_batch(model, buffer, predictions, device)
-
-    return references, predictions
 
+    if args.warmup_steps is not None:
+        dataset = data_utils.load_data(args)
+        dataset = data_utils.prepare_data(dataset)
 
-def main(args):
-    """Run the evaluation script."""
-    if args.device == -1:
-        device = "cpu"
-    else:
-        device = f"cuda:{args.device}"
+        num_warmup_samples = args.warmup_steps * args.batch_size
+        if args.streaming:
+            warmup_dataset = dataset.take(num_warmup_samples)
+        else:
+            warmup_dataset = dataset.select(range(min(num_warmup_samples, len(dataset))))
+        warmup_dataset = iter(warmup_dataset.map(benchmark, batch_size=args.batch_size, batched=True))
 
-    asr_model = get_model(
-        args.source, args.speechbrain_pretrained_class_name, device=device
-    )
+        for _ in tqdm(warmup_dataset, desc="Warming up..."):
+            continue
 
     dataset = data_utils.load_data(args)
-
     if args.max_eval_samples is not None and args.max_eval_samples > 0:
-        print(f"Subsampling dataset to first {args.max_eval_samples} samples !")
-        dataset = dataset.take(args.max_eval_samples)
-
+        print(f"Subsampling dataset to first {args.max_eval_samples} samples!")
+        if args.streaming:
+            dataset = dataset.take(args.max_eval_samples)
+        else:
+            dataset = dataset.select(range(min(args.max_eval_samples, len(dataset))))
     dataset = data_utils.prepare_data(dataset)
 
-    predictions = []
-    references = []
-
-    references, predictions = evaluate_dataset(
-        asr_model, dataset, device, args.batch_size, verbose=True
+    dataset = dataset.map(
+        benchmark, batch_size=args.batch_size, batched=True, remove_columns=["audio"],
     )
 
-    # Write manifest results
+    all_results = {
+        "audio_length_s": [],
+        "transcription_time_s": [],
+        "predictions": [],
+        "references": [],
+    }
+    result_iter = iter(dataset)
+    for result in tqdm(result_iter, desc="Samples..."):
+        for key in all_results:
+            all_results[key].append(result[key])
+
+    # Write manifest results (WER and RTFX)
     manifest_path = data_utils.write_manifest(
-        references, predictions, args.source, args.dataset_path, args.dataset, args.split
+        all_results["references"],
+        all_results["predictions"],
+        args.model_id,
+        args.dataset_path,
+        args.dataset,
+        args.split,
+        audio_length=all_results["audio_length_s"],
+        transcription_time=all_results["transcription_time_s"],
    )
     print("Results saved at path:", os.path.abspath(manifest_path))
-
+
     wer_metric = evaluate.load("wer")
-    wer = wer_metric.compute(references=references, predictions=predictions)
+    wer = wer_metric.compute(
+        references=all_results["references"], predictions=all_results["predictions"]
+    )
     wer = round(100 * wer, 2)
-
-    print("WER:", wer, "%")
+    rtfx = round(sum(all_results["audio_length_s"]) / sum(all_results["transcription_time_s"]), 2)
+    print("WER:", wer, "%", "RTFx:", rtfx)
 
 
 if __name__ == "__main__":
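
The new benchmark function times transcribe_batch around each mini-batch, then divides the runtime evenly across the samples in that batch so each row carries a per-sample transcription time; the final RTFx is total audio duration over total transcription time. A self-contained sketch of that bookkeeping (the sleep stands in for real model inference, and all numbers are made up):

    import time

    def fake_transcribe(batch):
        # Stand-in for model.transcribe_batch: pretend inference takes 10 ms.
        time.sleep(0.01)
        return ["hello world"] * len(batch)

    audio_length_s = [3.0, 5.0, 4.0, 8.0]  # per-sample audio durations
    batch = ["waveform"] * len(audio_length_s)

    start_time = time.time()
    predictions = fake_transcribe(batch)
    runtime = time.time() - start_time

    # Spread the batch runtime evenly over its samples, as the script does.
    transcription_time_s = len(batch) * [runtime / len(batch)]

    # RTFx: seconds of audio transcribed per second of compute (higher is faster).
    rtfx = round(sum(audio_length_s) / sum(transcription_time_s), 2)
    print("RTFx:", rtfx)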
@@ -263,6 +226,12 @@ def main(args):
         action="store_false",
         help="Choose whether you'd like to download the entire dataset or stream it during the evaluation.",
     )
+    parser.add_argument(
+        "--warmup_steps",
+        type=int,
+        default=5,
+        help="Number of warm-up steps to run before launching the timed runs.",
+    )
     args = parser.parse_args()
     parser.set_defaults(streaming=True)
 
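
With the new --warmup_steps flag in place, an invocation of the updated script might look like the following; the flag names come from the argparse usage visible in the diff, while every value shown is an illustrative placeholder rather than a command from this commit:

    python run_eval.py \
        --source speechbrain/asr-crdnn-rnnlm-librispeech \
        --speechbrain_pretrained_class_name EncoderDecoderASR \
        --dataset_path esb/datasets \
        --dataset librispeech \
        --split test.clean \
        --device 0 \
        --batch_size 16 \
        --warmup_steps 5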
