Commit d58797c

Author: sanchit-gandhi (committed)
pipeline -> model api
1 parent 94d0d6a commit d58797c

File tree: 1 file changed (+54, -25 lines)


transformers/run_eval.py

Lines changed: 54 additions & 25 deletions
@@ -2,7 +2,7 @@
 import os
 
 import torch
-from transformers import pipeline
+from transformers import AutoConfig, AutoModelForSpeechSeq2Seq, AutoModelForCTC, AutoProcessor, MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING
 import evaluate
 from normalizer import data_utils
 import time
@@ -12,22 +12,18 @@
 
 
 def main(args):
-    asr_pipe = pipeline(
-        "automatic-speech-recognition",
-        model=args.model_id,
-        device=args.device,
-        batch_size=args.batch_size,
-        torch_dtype=torch.float16,
-    )
+    config = AutoConfig.from_pretrained(args.model_id)
+    cls_model = AutoModelForSpeechSeq2Seq if type(config) in MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING else AutoModelForCTC
+    model = cls_model.from_pretrained(args.model_id, torch_dtype=torch.float16).to(args.device)
+    processor = AutoProcessor.from_pretrained(args.model_id)
+    model_input_name = processor.model_input_names[0]
 
-    if asr_pipe.model.can_generate():
+    if model.can_generate():
         gen_kwargs = {"max_new_tokens": 256}
         # for multilingual Whisper-checkpoints we see a definitive WER boost by setting the language and task args
-        if getattr(asr_pipe.model.generation_config, "is_multilingual"):
+        if getattr(model.generation_config, "is_multilingual"):
             gen_kwargs["language"] = "en"
             gen_kwargs["task"] = "transcribe"
-    else:
-        gen_kwargs = None
 
     dataset = data_utils.load_data(args)
 
@@ -38,19 +34,52 @@ def main(args):
     dataset = data_utils.prepare_data(dataset)
 
     def benchmark(batch):
-        # get audio stats
-        audio = [sample["array"] for sample in batch["audio"]]
-        batch["audio_length"] = [len(sample) / 16_000 for sample in audio]
-        minibatch_size = len(audio)
+        # Load audio inputs
+        audios = [audio["array"] for audio in batch["audio"]]
+        minibatch_size = len(audios)
 
-        # timing step
+        # START TIMING
         start_time = time.time()
-        result = asr_pipe(batch["audio"], generate_kwargs=gen_kwargs)
+
+        # 1. Pre-Processing
+        if not model.can_generate() or len(audios[0]) > processor.feature_extractor.n_samples:
+            # 1.1 Either CTC pre-processing (normalize to mean 0, std 1), or long-form Whisper processing
+            inputs = processor(
+                audios,
+                sampling_rate=16_000,
+                truncation=False,
+                padding="longest",
+                return_tensors="pt",
+                return_attention_mask=True,
+            )
+        else:
+            # 1.2 Standard Whisper processing: pad audios to 30 seconds and convert to log-mel
+            inputs = processor(audios, sampling_rate=16_000, return_tensors="pt")
+
+        inputs = inputs.to(args.device)
+        inputs[model_input_name] = inputs[model_input_name].to(torch.float16)
+
+        # 2. Model Inference
+        if model.can_generate():
+            # 2.1 Auto-regressive generation for encoder-decoder models
+            pred_ids = model.generate(**inputs, **gen_kwargs)
+        else:
+            # 2.2 Single forward pass for CTC
+            with torch.no_grad():
+                logits = model(**inputs).logits
+            pred_ids = logits.argmax(-1)
+
+        # 3. Post-processing: convert token ids to text transcription
+        pred_text = processor.batch_decode(pred_ids, skip_special_tokens=True)
+
+        # END TIMING
+        runtime = time.time() - start_time
+
         # normalize by minibatch size since we want the per-sample time
-        batch["transcription_time"] = minibatch_size * [(time.time() - start_time) / minibatch_size]
+        batch["transcription_time_s"] = minibatch_size * [runtime / minibatch_size]
 
         # normalize transcriptions with English normalizer
-        batch["predictions"] = [data_utils.normalizer(pred["text"]) for pred in result]
+        batch["predictions"] = [data_utils.normalizer(pred) for pred in pred_text]
         batch["references"] = batch["norm_text"]
         return batch
 
@@ -59,8 +88,8 @@ def benchmark(batch):
     )
 
     all_results = {
-        "audio_length": [],
-        "transcription_time": [],
+        "audio_length_s": [],
+        "transcription_time_s": [],
         "predictions": [],
         "references": [],
     }
@@ -77,16 +106,16 @@ def benchmark(batch):
         args.dataset_path,
         args.dataset,
         args.split,
-        audio_length=all_results["audio_length"],
-        transcription_time=all_results["transcription_time"],
+        audio_length=all_results["audio_length_s"],
+        transcription_time=all_results["transcription_time_s"],
     )
     print("Results saved at path:", os.path.abspath(manifest_path))
 
     wer = wer_metric.compute(
         references=all_results["references"], predictions=all_results["predictions"]
     )
     wer = round(100 * wer, 2)
-    rtfx = round(sum(all_results["audio_length"]) / sum(all_results["transcription_time"]), 2)
+    rtfx = round(sum(all_results["audio_length_s"]) / sum(all_results["transcription_time_s"]), 2)
     print("WER:", wer, "%", "RTFx:", rtfx)
 
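For context, the commit swaps the `pipeline("automatic-speech-recognition")` wrapper for the bare model API: the config selects between `AutoModelForSpeechSeq2Seq` and `AutoModelForCTC`, the processor handles feature extraction and decoding, and the script branches on `model.can_generate()` to choose between auto-regressive generation and a single CTC forward pass. The snippet below is a minimal sketch of that path on one clip; it is not part of the commit, and the checkpoint name and the one-second dummy waveform are illustrative assumptions.

# Hypothetical standalone sketch of the new model-API path (not part of the diff).
# Assumes a Whisper checkpoint ("openai/whisper-tiny.en" is an arbitrary choice)
# and a 16 kHz mono waveform; CPU and float32 keep the example simple.
import numpy as np
import torch
from transformers import (
    AutoConfig,
    AutoModelForSpeechSeq2Seq,
    AutoModelForCTC,
    AutoProcessor,
    MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING,
)

model_id = "openai/whisper-tiny.en"

# Pick the model class from the config, mirroring the check added in the diff
config = AutoConfig.from_pretrained(model_id)
cls_model = AutoModelForSpeechSeq2Seq if type(config) in MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING else AutoModelForCTC
model = cls_model.from_pretrained(model_id)
processor = AutoProcessor.from_pretrained(model_id)

# One second of silence stands in for a real utterance
audio = np.zeros(16_000, dtype=np.float32)
inputs = processor(audio, sampling_rate=16_000, return_tensors="pt")

if model.can_generate():
    # Encoder-decoder (Whisper): auto-regressive generation
    pred_ids = model.generate(**inputs, max_new_tokens=256)
else:
    # CTC: single forward pass, then greedy argmax over the vocabulary
    with torch.no_grad():
        logits = model(**inputs).logits
    pred_ids = logits.argmax(-1)

print(processor.batch_decode(pred_ids, skip_special_tokens=True))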
