import os

import torch
-from transformers import pipeline
+from transformers import AutoConfig, AutoModelForSpeechSeq2Seq, AutoModelForCTC, AutoProcessor, MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING
import evaluate
from normalizer import data_utils
import time


def main(args):
-    asr_pipe = pipeline(
-        "automatic-speech-recognition",
-        model=args.model_id,
-        device=args.device,
-        batch_size=args.batch_size,
-        torch_dtype=torch.float16,
-    )
+    config = AutoConfig.from_pretrained(args.model_id)
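+    # dispatch on the config class: seq2seq checkpoints (e.g. Whisper) are decoded auto-regressively, all others are treated as CTC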
+    cls_model = AutoModelForSpeechSeq2Seq if type(config) in MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING else AutoModelForCTC
+    model = cls_model.from_pretrained(args.model_id, torch_dtype=torch.float16).to(args.device)
+    processor = AutoProcessor.from_pretrained(args.model_id)
+    model_input_name = processor.model_input_names[0]
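+    # e.g. "input_features" for Whisper, "input_values" for CTC models such as Wav2Vec2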

-    if asr_pipe.model.can_generate():
+    if model.can_generate():
        gen_kwargs = {"max_new_tokens": 256}
        # for multilingual Whisper checkpoints we see a definitive WER boost by setting the language and task args
-        if getattr(asr_pipe.model.generation_config, "is_multilingual"):
+        if getattr(model.generation_config, "is_multilingual", False):
            gen_kwargs["language"] = "en"
            gen_kwargs["task"] = "transcribe"
-    else:
-        gen_kwargs = None

    dataset = data_utils.load_data(args)

@@ -38,19 +34,52 @@ def main(args):
    dataset = data_utils.prepare_data(dataset)

    def benchmark(batch):
-        # get audio stats
-        audio = [sample["array"] for sample in batch["audio"]]
-        batch["audio_length"] = [len(sample) / 16_000 for sample in audio]
-        minibatch_size = len(audio)
+        # Load audio inputs
+        audios = [audio["array"] for audio in batch["audio"]]
+        minibatch_size = len(audios)

-        # timing step
+        # START TIMING
        start_time = time.time()
-        result = asr_pipe(batch["audio"], generate_kwargs=gen_kwargs)
+
+        # 1. Pre-Processing
+        if not model.can_generate() or len(audios[0]) > processor.feature_extractor.n_samples:
+            # 1.1 Either CTC pre-processing (normalize to mean 0, std 1) or long-form Whisper processing
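+            # (Whisper's feature extractor expects 30-second windows, i.e. n_samples = 480,000 at 16 kHz, so longer inputs take this path)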
+            inputs = processor(
+                audios,
+                sampling_rate=16_000,
+                truncation=False,
+                padding="longest",
+                return_tensors="pt",
+                return_attention_mask=True,
+            )
+        else:
+            # 1.2 Standard Whisper processing: pad audios to 30 seconds and convert to a log-mel spectrogram
+            inputs = processor(audios, sampling_rate=16_000, return_tensors="pt")
+
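+        # move the batch to the target device and cast the audio input to fp16 to match the model weights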
+        inputs = inputs.to(args.device)
+        inputs[model_input_name] = inputs[model_input_name].to(torch.float16)
+
+        # 2. Model Inference
+        if model.can_generate():
+            # 2.1 Auto-regressive generation for encoder-decoder models
+            pred_ids = model.generate(**inputs, **gen_kwargs)
+        else:
+            # 2.2 Single forward pass for CTC models
+            with torch.no_grad():
+                logits = model(**inputs).logits
+            pred_ids = logits.argmax(-1)
+
+        # 3. Post-processing: convert token ids to text transcriptions
+        pred_text = processor.batch_decode(pred_ids, skip_special_tokens=True)
+
+        # END TIMING
+        runtime = time.time() - start_time
+
        # normalize by minibatch size since we want the per-sample time
-        batch["transcription_time"] = minibatch_size * [(time.time() - start_time) / minibatch_size]
+        batch["transcription_time_s"] = minibatch_size * [runtime / minibatch_size]

        # normalize transcriptions with English normalizer
-        batch["predictions"] = [data_utils.normalizer(pred["text"]) for pred in result]
+        batch["predictions"] = [data_utils.normalizer(pred) for pred in pred_text]
        batch["references"] = batch["norm_text"]
        return batch

@@ -59,8 +88,8 @@ def benchmark(batch):
    )

    all_results = {
-        "audio_length": [],
-        "transcription_time": [],
+        "audio_length_s": [],
+        "transcription_time_s": [],
        "predictions": [],
        "references": [],
    }
@@ -77,16 +106,16 @@ def benchmark(batch):
        args.dataset_path,
        args.dataset,
        args.split,
-        audio_length=all_results["audio_length"],
-        transcription_time=all_results["transcription_time"],
+        audio_length=all_results["audio_length_s"],
+        transcription_time=all_results["transcription_time_s"],
    )
    print("Results saved at path:", os.path.abspath(manifest_path))

    wer = wer_metric.compute(
        references=all_results["references"], predictions=all_results["predictions"]
    )
    wer = round(100 * wer, 2)
-    rtfx = round(sum(all_results["audio_length"]) / sum(all_results["transcription_time"]), 2)
+    rtfx = round(sum(all_results["audio_length_s"]) / sum(all_results["transcription_time_s"]), 2)
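+    # RTFx = total audio duration / total transcription time; higher means faster than real time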
    print("WER:", wer, "%", "RTFx:", rtfx)
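
For reference, the new loading path dispatches on the checkpoint's config class rather than on a pipeline task string. A minimal standalone sketch of that check (the two checkpoint names are illustrative examples and assume access to the Hugging Face Hub):

```python
from transformers import AutoConfig, MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING

# Whisper configs map to an encoder-decoder (seq2seq) class; Wav2Vec2 configs fall through to CTC
for model_id in ("openai/whisper-tiny.en", "facebook/wav2vec2-base-960h"):
    config = AutoConfig.from_pretrained(model_id)
    is_seq2seq = type(config) in MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING
    print(model_id, "->", "seq2seq (model.generate)" if is_seq2seq else "CTC (logits.argmax)")
```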