import os

import torch
+from torch.nn.attention import sdpa_kernel, SDPBackend
from transformers import AutoConfig, AutoModelForSpeechSeq2Seq, AutoModelForCTC, AutoProcessor, MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING
import evaluate
from normalizer import data_utils
import time
from tqdm import tqdm

wer_metric = evaluate.load("wer")

+torch.set_float32_matmul_precision('high')
+torch._logging.set_logs(graph_breaks=True, recompiles=True)
+

def main(args):
    config = AutoConfig.from_pretrained(args.model_id)
    cls_model = AutoModelForSpeechSeq2Seq if type(config) in MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING else AutoModelForCTC
-    model = cls_model.from_pretrained(args.model_id, torch_dtype=torch.float16).to(args.device)
+    model = cls_model.from_pretrained(args.model_id, torch_dtype=torch.bfloat16, attn_implementation="sdpa").to(args.device)
    processor = AutoProcessor.from_pretrained(args.model_id)
    model_input_name = processor.model_input_names[0]

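Two changes in this first hunk: the checkpoint now loads in bfloat16 rather than float16 (bf16 keeps float32's exponent range, so it is less prone to overflow at inference), and `attn_implementation="sdpa"` routes attention through `torch.nn.functional.scaled_dot_product_attention`. At module level, `torch.set_float32_matmul_precision('high')` permits TF32 matmuls on Ampere-class GPUs, and `torch._logging.set_logs(graph_breaks=True, recompiles=True)` surfaces `torch.compile` diagnostics. A minimal standalone sketch of the new loading path (the checkpoint name and device are placeholders, and transformers v4.36+ with SDPA support is assumed):

```python
import torch
from transformers import AutoModelForSpeechSeq2Seq

# Placeholder checkpoint; any seq2seq ASR model supported by transformers works.
model = AutoModelForSpeechSeq2Seq.from_pretrained(
    "openai/whisper-tiny",
    torch_dtype=torch.bfloat16,    # bf16: float32 exponent range at half the memory
    attn_implementation="sdpa",    # PyTorch scaled-dot-product attention kernels
).to("cuda")                       # assumes a CUDA device
```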
@@ -25,13 +29,11 @@ def main(args):
        gen_kwargs["language"] = "en"
        gen_kwargs["task"] = "transcribe"

-    dataset = data_utils.load_data(args)
-
-    if args.max_eval_samples is not None and args.max_eval_samples > 0:
-        print(f"Subsampling dataset to first {args.max_eval_samples} samples!")
-        dataset = dataset.take(args.max_eval_samples)
-
-    dataset = data_utils.prepare_data(dataset)
+    if args.torch_compile:
+        model.forward = torch.compile(model.forward, mode=args.compile_mode, fullgraph=True)
+        if model.can_generate():
+            # enable static k/v cache for autoregressive models
+            model.generation_config.cache_implementation = "static"

    def benchmark(batch):
        # Load audio inputs
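The dataset-loading code that used to live here moves below the `benchmark` definition (see the warm-up hunk further down). In its place, the forward pass is compiled with `fullgraph=True`, and generative models are switched to a static KV cache: the default dynamic cache grows the key/value tensors by one position per decoding step, and every new shape would trigger a recompile, whereas a static cache pre-allocates them at a fixed size. The shape-sensitivity that motivates this is easy to reproduce on a toy function (an illustrative sketch, not this repo's code; the recompile is visible with the logging enabled at the top of the file):

```python
import torch

@torch.compile(fullgraph=True)
def step(x):
    return torch.relu(x) @ x.T

step(torch.randn(4, 16))  # first call: compiles a graph for shape (4, 16)
step(torch.randn(4, 16))  # same shape: cached graph, no recompile
step(torch.randn(3, 16))  # new shape: triggers a recompile
```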
@@ -42,8 +44,15 @@ def benchmark(batch):
        start_time = time.time()

        # 1. Pre-Processing
-        if not model.can_generate() or len(audios[0]) > processor.feature_extractor.n_samples:
-            # 1.1 Either CTC pre-processing (normalize to mean 0, std 1), or long-form Whisper processing
+        # 1.1 Pad audios to max batch size if using torch compile to prevent re-compilations
+        padding_size = None
+        if minibatch_size != args.batch_size and args.torch_compile:
+            padding_size = args.batch_size - minibatch_size
+            padding_audios = [audios[-1] for _ in range(padding_size)]
+            audios.extend(padding_audios)
+
+        if not model.can_generate():  # or len(audios[0]) > processor.feature_extractor.n_samples:
+            # 1.2 Either CTC pre-processing (normalize to mean 0, std 1), or long-form Whisper processing
            inputs = processor(
                audios,
                sampling_rate=16_000,
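The padding logic guards against the other classic recompilation trigger: the final minibatch of a dataset is usually smaller than `batch_size`, so its tensors arrive with a new shape. The hunk pads the audio list back up to the full batch size by repeating the last sample (`minibatch_size` comes from unchanged code just above this hunk), and the matching predictions are stripped again in step 3.1 below. A toy illustration, assuming `batch_size` is 4:

```python
import numpy as np

batch_size = 4
audios = [np.zeros(16_000, dtype=np.float32) for _ in range(3)]  # ragged final batch

padding_size = batch_size - len(audios)
if padding_size > 0:
    audios.extend([audios[-1]] * padding_size)  # repeat last sample to fix the shape

assert len(audios) == batch_size  # downstream tensors keep a constant batch dimension
```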
@@ -53,23 +62,29 @@ def benchmark(batch):
                return_attention_mask=True,
            )
        else:
-            # 1.2 Standard Whisper processing: pad audios to 30 seconds and convert to log-mel
-            inputs = processor(audios, sampling_rate=16_000, return_tensors="pt")
+            # 1.3 Standard Whisper processing: pad audios to 30 seconds and convert to log-mel
+            inputs = processor(audios, sampling_rate=16_000, return_tensors="pt", device=args.device)

        inputs = inputs.to(args.device)
-        inputs[model_input_name] = inputs[model_input_name].to(torch.float16)
+        inputs[model_input_name] = inputs[model_input_name].to(torch.bfloat16)

        # 2. Model Inference
-        if model.can_generate():
-            # 2.1 Auto-regressive generation for encoder-decoder models
-            pred_ids = model.generate(**inputs, **gen_kwargs)
-        else:
-            # 2.2 Single forward pass for CTC
-            with torch.no_grad():
-                logits = model(**inputs)
-                pred_ids = logits.argmax(-1)
-
-        # 3. Post-processing: convert token ids to text transcription
+        with sdpa_kernel(SDPBackend.MATH if args.torch_compile else SDPBackend.FLASH_ATTENTION):
+            if model.can_generate():
+                # 2.1 Auto-regressive generation for encoder-decoder models
+                pred_ids = model.generate(**inputs, **gen_kwargs)
+            else:
+                # 2.2 Single forward pass for CTC
+                with torch.no_grad():
+                    logits = model(**inputs).logits
+                    pred_ids = logits.argmax(-1)
+
+        # 3. Post-processing
+        # 3.1 Strip padded ids from predictions
+        if padding_size is not None:
+            pred_ids = pred_ids[:-padding_size, ...]
+
+        # 3.2 Convert token ids to text transcription
        pred_text = processor.batch_decode(pred_ids, skip_special_tokens=True)

        # END TIMING
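Inference now runs under an `sdpa_kernel` context that pins the attention backend: flash attention in eager mode, but the math backend when compiling, presumably because the fused flash kernel interacted badly with `fullgraph=True` at the time this was written (the math backend lowers to plain matmul/softmax ops that Inductor can fuse itself). The CTC branch also fixes a latent bug: the model returns a `ModelOutput`, so `argmax` must be taken over `.logits` rather than the output object. A minimal standalone use of the context manager (assumes a CUDA device that supports flash attention):

```python
import torch
from torch.nn.attention import sdpa_kernel, SDPBackend

q = k = v = torch.randn(1, 8, 128, 64, device="cuda", dtype=torch.bfloat16)

# Only the flash-attention kernel is allowed inside this block; an unsupported
# configuration raises an error instead of silently falling back.
with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
    out = torch.nn.functional.scaled_dot_product_attention(q, k, v)
```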
@@ -83,8 +98,31 @@ def benchmark(batch):
        batch["references"] = batch["norm_text"]
        return batch

+    if args.warmup_steps is not None:
+        dataset = data_utils.load_data(args)
+        dataset = data_utils.prepare_data(dataset)
+
+        num_warmup_samples = args.warmup_steps * args.batch_size
+        if args.streaming:
+            warmup_dataset = dataset.take(num_warmup_samples)
+        else:
+            warmup_dataset = dataset.select(range(min(num_warmup_samples, len(dataset))))
+        warmup_dataset = iter(warmup_dataset.map(benchmark, batch_size=args.batch_size, batched=True))
+
+        for _ in tqdm(warmup_dataset, desc="Warming up..."):
+            continue
+
+    dataset = data_utils.load_data(args)
+    if args.max_eval_samples is not None and args.max_eval_samples > 0:
+        print(f"Subsampling dataset to first {args.max_eval_samples} samples!")
+        if args.streaming:
+            dataset = dataset.take(args.max_eval_samples)
+        else:
+            dataset = dataset.select(range(min(args.max_eval_samples, len(dataset))))
+    dataset = data_utils.prepare_data(dataset)
+
    dataset = dataset.map(
-        benchmark, batch_size=args.batch_size, batched=True, remove_columns=["audio"]
+        benchmark, batch_size=args.batch_size, batched=True, remove_columns=["audio"],
    )

    all_results = {
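The relocated data pipeline gains a warm-up phase: the first `warmup_steps * batch_size` samples go through the same `benchmark` function so that compilation and autotuning are paid before anything is measured, then the dataset is reloaded from scratch for the timed run. Subsampling also gains a streaming branch, since an `IterableDataset` only supports `.take()` while a map-style dataset uses `.select()`. The cost the warm-up hides is easy to see in isolation (toy sketch, not this repo's code):

```python
import time
import torch

fn = torch.compile(lambda x: (x @ x).relu())
x = torch.randn(256, 256)

t0 = time.time()
fn(x)  # cold call: pays the one-off compilation cost
print(f"cold: {time.time() - t0:.3f}s")

t0 = time.time()
fn(x)  # warm call: cached graph, representative latency
print(f"warm: {time.time() - t0:.6f}s")
```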
@@ -94,7 +132,7 @@ def benchmark(batch):
        "references": [],
    }
    result_iter = iter(dataset)
-    for result in tqdm(result_iter, desc="Samples"):
+    for result in tqdm(result_iter, desc="Samples..."):
        for key in all_results:
            all_results[key].append(result[key])

@@ -171,6 +209,23 @@ def benchmark(batch):
        action="store_false",
        help="Choose whether you'd like to download the entire dataset or stream it during the evaluation.",
    )
+    parser.add_argument(
+        "--torch_compile",
+        action="store_true",
+        help="Whether to JIT compile the forward pass of the model.",
+    )
+    parser.add_argument(
+        "--compile_mode",
+        type=str,
+        default="max-autotune",
+        help="Mode for torch compiling the model forward pass. Can be one of 'default', 'reduce-overhead', 'max-autotune' or 'max-autotune-no-cudagraphs'.",
+    )
+    parser.add_argument(
+        "--warmup_steps",
+        type=int,
+        default=10,
+        help="Number of warm-up steps to run before launching the timed runs.",
+    )
    args = parser.parse_args()
    parser.set_defaults(streaming=False)

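With the new flags, a compiled run looks roughly like `python run_eval.py --model_id=openai/whisper-tiny --torch_compile --compile_mode=max-autotune --warmup_steps=10 ...` (the script name and the pre-existing model/dataset flags are assumed from the rest of the repo). Note that `--warmup_steps` defaults to 10, so the warm-up pass also runs for eager-mode baselines.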