 import argparse
 import os
-
 import torch
 from torch.nn.attention import sdpa_kernel, SDPBackend
 from transformers import AutoConfig, AutoModelForSpeechSeq2Seq, AutoModelForCTC, AutoProcessor, MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING
 import evaluate
 from tqdm import tqdm
 
 wer_metric = evaluate.load("wer")
-
 torch.set_float32_matmul_precision('high')
-torch._logging.set_logs(graph_breaks=True, recompiles=True)
-
 
 def main(args):
     config = AutoConfig.from_pretrained(args.model_id)
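The deleted `torch._logging.set_logs(graph_breaks=True, recompiles=True)` call above is torch.compile diagnostic logging rather than anything benchmark-specific. A minimal standalone sketch of what it surfaces (the toy function `f` is hypothetical, not from this script):

    import torch

    # log a message whenever dynamo hits a graph break or recompiles
    torch._logging.set_logs(graph_breaks=True, recompiles=True)

    @torch.compile
    def f(x):
        return x * 2

    f(torch.randn(4))  # first call compiles
    f(torch.randn(8))  # new input shape: logs a recompile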
@@ -23,19 +19,21 @@ def main(args):
     model_input_name = processor.model_input_names[0]
 
     if model.can_generate():
-        gen_kwargs = {"max_new_tokens": 256}
+        gen_kwargs = {"max_new_tokens": args.max_new_tokens}
         # for multilingual Whisper checkpoints we see a definitive WER boost by setting the language and task args
         if getattr(model.generation_config, "is_multilingual"):
             gen_kwargs["language"] = "en"
             gen_kwargs["task"] = "transcribe"
+    elif args.max_new_tokens:
+        raise ValueError("`max_new_tokens` should only be set for auto-regressive models, but got a CTC model.")
 
     if args.torch_compile:
         model.forward = torch.compile(model.forward, mode=args.compile_mode, fullgraph=True)
         if model.can_generate():
             # enable static k/v cache for auto-regressive models
             model.generation_config.cache_implementation = "static"
 
-    def benchmark(batch):
+    def benchmark(batch, min_new_tokens=None):
         # Load audio inputs
         audios = [audio["array"] for audio in batch["audio"]]
         minibatch_size = len(audios)
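The new `min_new_tokens` parameter exists so the warmup pass (further down) can pin the decode length: when `min_new_tokens == max_new_tokens`, generation runs for exactly that many steps, so the shapes seen under torch.compile with a static cache stay constant across calls. A minimal sketch of the idea, assuming a Whisper-style checkpoint (`openai/whisper-tiny.en` is an illustrative choice, not from the diff):

    import numpy as np
    import torch
    from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor

    model_id = "openai/whisper-tiny.en"  # illustrative checkpoint
    model = AutoModelForSpeechSeq2Seq.from_pretrained(model_id)
    processor = AutoProcessor.from_pretrained(model_id)

    model.generation_config.cache_implementation = "static"  # fixed-size k/v cache
    model.forward = torch.compile(model.forward, fullgraph=True)

    audio = np.zeros(16000, dtype=np.float32)  # 1 s of dummy audio
    inputs = processor(audio, sampling_rate=16000, return_tensors="pt")

    # min == max pins the number of decoding steps, so every warmup call
    # exercises the same shapes the benchmark will later run against
    pred_ids = model.generate(**inputs, min_new_tokens=64, max_new_tokens=64)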
@@ -72,7 +70,7 @@ def benchmark(batch):
         with sdpa_kernel(SDPBackend.MATH if args.torch_compile else SDPBackend.FLASH_ATTENTION):
             if model.can_generate():
                 # 2.1 Auto-regressive generation for encoder-decoder models
-                pred_ids = model.generate(**inputs, **gen_kwargs)
+                pred_ids = model.generate(**inputs, **gen_kwargs, min_new_tokens=min_new_tokens)
             else:
                 # 2.2 Single forward pass for CTC
                 with torch.no_grad():
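On the backend switch: `sdpa_kernel` restricts which scaled-dot-product-attention implementation PyTorch may use inside the block. The math backend traces cleanly under torch.compile with `fullgraph=True`, while flash attention is the faster choice in eager mode. A standalone sketch of the mechanism (assumes a CUDA device with flash-attention support):

    import torch
    import torch.nn.functional as F
    from torch.nn.attention import sdpa_kernel, SDPBackend

    q = k = v = torch.randn(1, 8, 128, 64, device="cuda", dtype=torch.float16)

    # eager path: allow only the fused flash-attention kernel
    with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
        out = F.scaled_dot_product_attention(q, k, v)

    # compiled path: the math backend avoids graph breaks under fullgraph=True
    with sdpa_kernel(SDPBackend.MATH):
        out = F.scaled_dot_product_attention(q, k, v)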
@@ -107,7 +105,7 @@ def benchmark(batch):
         warmup_dataset = dataset.take(num_warmup_samples)
     else:
         warmup_dataset = dataset.select(range(min(num_warmup_samples, len(dataset))))
-    warmup_dataset = iter(warmup_dataset.map(benchmark, batch_size=args.batch_size, batched=True))
+    warmup_dataset = iter(warmup_dataset.map(benchmark, batch_size=args.batch_size, batched=True, fn_kwargs={"min_new_tokens": args.max_new_tokens}))
 
     for _ in tqdm(warmup_dataset, desc="Warming up..."):
         continue
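`fn_kwargs` is how the extra argument reaches `benchmark`: `datasets.Dataset.map` forwards that dict as keyword arguments on every call of the mapped function, which is what lets the warmup pass force `min_new_tokens = args.max_new_tokens` without touching the measured runs. A toy illustration of the mechanism (the names `add_offset` and `offset` are hypothetical):

    from datasets import Dataset

    ds = Dataset.from_dict({"x": [1, 2, 3, 4]})

    def add_offset(batch, offset=0):
        return {"y": [v + offset for v in batch["x"]]}

    # fn_kwargs is forwarded to add_offset on every batched call
    ds = ds.map(add_offset, batched=True, batch_size=2, fn_kwargs={"offset": 10})
    print(ds["y"])  # [11, 12, 13, 14]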
@@ -209,6 +207,12 @@ def benchmark(batch):
         action="store_false",
         help="Choose whether you'd like to download the entire dataset or stream it during the evaluation.",
     )
+    parser.add_argument(
+        "--max_new_tokens",
+        type=int,
+        default=None,
+        help="Maximum number of tokens to generate (for auto-regressive models).",
+    )
     parser.add_argument(
         "--torch_compile",
         action="store_true",