Commit 40d97dc

Author: sanchit-gandhi
Message: finalise
1 parent 523c249 commit 40d97dc

File tree

1 file changed (+12, -8 lines)


transformers/run_eval.py

Lines changed: 12 additions & 8 deletions
@@ -1,6 +1,5 @@
 import argparse
 import os
-
 import torch
 from torch.nn.attention import sdpa_kernel, SDPBackend
 from transformers import AutoConfig, AutoModelForSpeechSeq2Seq, AutoModelForCTC, AutoProcessor, MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING
@@ -10,10 +9,7 @@
 from tqdm import tqdm
 
 wer_metric = evaluate.load("wer")
-
 torch.set_float32_matmul_precision('high')
-torch._logging.set_logs(graph_breaks=True, recompiles=True)
-
 
 def main(args):
     config = AutoConfig.from_pretrained(args.model_id)
@@ -23,19 +19,21 @@ def main(args):
     model_input_name = processor.model_input_names[0]
 
     if model.can_generate():
-        gen_kwargs = {"max_new_tokens": 256}
+        gen_kwargs = {"max_new_tokens": args.max_new_tokens}
         # for multilingual Whisper-checkpoints we see a definitive WER boost by setting the language and task args
         if getattr(model.generation_config, "is_multilingual"):
             gen_kwargs["language"] = "en"
             gen_kwargs["task"] = "transcribe"
+    elif args.max_new_tokens:
+        raise ValueError("`max_new_tokens` should only be set for auto-regressive models, but got a CTC model.")
 
     if args.torch_compile:
         model.forward = torch.compile(model.forward, mode=args.compile_mode, fullgraph=True)
         if model.can_generate():
             # enable static k/v cache for autoregressive models
             model.generation_config.cache_implementation = "static"
 
-    def benchmark(batch):
+    def benchmark(batch, min_new_tokens=None):
         # Load audio inputs
         audios = [audio["array"] for audio in batch["audio"]]
         minibatch_size = len(audios)
@@ -72,7 +70,7 @@ def benchmark(batch):
         with sdpa_kernel(SDPBackend.MATH if args.torch_compile else SDPBackend.FLASH_ATTENTION):
             if model.can_generate():
                 # 2.1 Auto-regressive generation for encoder-decoder models
-                pred_ids = model.generate(**inputs, **gen_kwargs)
+                pred_ids = model.generate(**inputs, **gen_kwargs, min_new_tokens=min_new_tokens)
             else:
                 # 2.2. Single forward pass for CTC
                 with torch.no_grad():
@@ -107,7 +105,7 @@ def benchmark(batch):
         warmup_dataset = dataset.take(num_warmup_samples)
     else:
         warmup_dataset = dataset.select(range(min(num_warmup_samples, len(dataset))))
-    warmup_dataset = iter(warmup_dataset.map(benchmark, batch_size=args.batch_size, batched=True))
+    warmup_dataset = iter(warmup_dataset.map(benchmark, batch_size=args.batch_size, batched=True, fn_kwargs={"min_new_tokens": args.max_new_tokens}))
 
     for _ in tqdm(warmup_dataset, desc="Warming up..."):
         continue
@@ -209,6 +207,12 @@ def benchmark(batch):
         action="store_false",
         help="Choose whether you'd like to download the entire dataset or stream it during the evaluation.",
     )
+    parser.add_argument(
+        "--max_new_tokens",
+        type=int,
+        default=None,
+        help="Maximum number of tokens to generate (for auto-regressive models).",
+    )
     parser.add_argument(
         "--torch_compile",
         action="store_true",

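The warmup change is the interesting part: with cache_implementation = "static" and a compiled forward, generation compiles specialized graphs the first time each shape is seen, so the warmup map passes fn_kwargs={"min_new_tokens": args.max_new_tokens} to force decoding to run to the full cache length up front. A minimal sketch of that idea, assuming a small Whisper checkpoint (an illustration under stated assumptions, not the repository's code):

    import numpy as np
    import torch
    from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor

    model_id = "openai/whisper-tiny"  # illustrative checkpoint
    processor = AutoProcessor.from_pretrained(model_id)
    model = AutoModelForSpeechSeq2Seq.from_pretrained(model_id)

    # Same setup as the script: static k/v cache + compiled forward
    model.generation_config.cache_implementation = "static"
    model.forward = torch.compile(model.forward, fullgraph=True)

    # One second of silence stands in for a real audio batch
    inputs = processor(np.zeros(16000, dtype=np.float32), sampling_rate=16000, return_tensors="pt")

    # Warmup: min_new_tokens == max_new_tokens forces decoding to run to full
    # length, so compilation happens here, at the final shapes
    model.generate(**inputs, max_new_tokens=64, min_new_tokens=64)

    # Timed calls can then stop early at EOS without triggering a recompile
    model.generate(**inputs, max_new_tokens=64)

This mirrors the benchmark(batch, min_new_tokens=None) signature above: the timed map calls leave min_new_tokens unset, while the warmup map pins it to args.max_new_tokens.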