@@ -1,7 +1,8 @@
 import argparse
 import os
 import torch
-from transformers import AutoConfig, AutoModelForSpeechSeq2Seq, AutoProcessor, PreTrainedTokenizerFast
+from transformers import MoonshineForConditionalGeneration, AutoProcessor
+
 import evaluate
 from normalizer import data_utils
 import time
@@ -12,9 +13,9 @@
 torch.set_float32_matmul_precision('high')
 
 def main(args):
-    config = AutoConfig.from_pretrained(args.model_id, trust_remote_code=True)
-    model = AutoModelForSpeechSeq2Seq.from_pretrained(args.model_id, torch_dtype=torch.bfloat16, trust_remote_code=True).to(args.device)
-    tokenizer = PreTrainedTokenizerFast.from_pretrained(args.model_id, trust_remote_code=True)
+    torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
+    model = MoonshineForConditionalGeneration.from_pretrained(args.model_id).to(args.device).to(torch_dtype)
+    processor = AutoProcessor.from_pretrained(args.model_id)
 
     if args.torch_compile:
         model.forward = torch.compile(model.forward, mode=args.compile_mode, fullgraph=True)
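
(Aside: a minimal sketch of the native loading path this hunk switches to, assuming a transformers release with Moonshine support and using the published UsefulSensors/moonshine-tiny checkpoint as a stand-in for args.model_id; the silent clip is purely illustrative.)

import numpy as np
import torch
from transformers import MoonshineForConditionalGeneration, AutoProcessor

device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float16 if torch.cuda.is_available() else torch.float32
model = MoonshineForConditionalGeneration.from_pretrained("UsefulSensors/moonshine-tiny").to(device).to(dtype)
processor = AutoProcessor.from_pretrained("UsefulSensors/moonshine-tiny")

waveform = np.zeros(16000, dtype=np.float32)  # one second of silence as a dummy clip
inputs = processor(waveform, return_tensors="pt", sampling_rate=16000).to(device).to(dtype)
pred_ids = model.generate(**inputs, max_new_tokens=8)
print(processor.batch_decode(pred_ids, skip_special_tokens=True))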
@@ -30,16 +31,30 @@ def benchmark(batch, min_new_tokens=None):
         # START TIMING
         start_time = time.time()
 
-        np_arr = np.array(audios)
-        input_tensor = torch.FloatTensor(np_arr)
-        moonshine_min_input_size = 1024
-        padding = moonshine_min_input_size - input_tensor.size()[1]
-        if padding > 0:
-            input_tensor = torch.nn.functional.pad(input_tensor, (0, padding))
-        pred_ids = model(input_tensor.to(args.device).to(torch.bfloat16))
+        # 1. Pre-Processing
+        # 1.1 Pad audios to max batch size if using torch compile to prevent re-compilations
+        padding_size = 0
+        if minibatch_size != args.batch_size and args.torch_compile:
+            padding_size = args.batch_size - minibatch_size
+            padding_audios = [audios[-1] for _ in range(padding_size)]
+            audios.extend(padding_audios)
+
+        inputs = processor(audios, return_tensors="pt", padding=True, sampling_rate=16000).to(args.device).to(torch_dtype)
+
+        # Create a mask for output tokens to limit length based on input audio clip length.
+        # Add 2 to the token limits to account for <sot> and <eot>.
+        token_generation_limits = [int(len(clip) * 6.5 // 16000) + 2 for clip in audios]
+        max_new_tokens = torch.tensor(token_generation_limits).reshape((-1, 1)).to(args.device)
+
+        pred_ids = model.generate(**inputs, max_new_tokens=int(max_new_tokens.max()))
+        output_mask = torch.arange(pred_ids.shape[-1]).repeat((pred_ids.shape[0], 1)).to(args.device)
+        output_mask = output_mask > max_new_tokens
+
+        eot_token = model.config.eos_token_id
+        # masked_fill is out-of-place; reassign so the EOT fill takes effect
+        pred_ids = pred_ids.masked_fill(output_mask, eot_token)
 
         # 3.2 Convert token ids to text transcription
-        pred_text = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
+        pred_text = processor.batch_decode(pred_ids, skip_special_tokens=True)
 
         # END TIMING
         runtime = time.time() - start_time
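
(Aside: the per-clip budget mask built above can be sanity-checked in isolation; a toy sketch with made-up ids and budgets, not the benchmark's real tensors.)

import torch

pred_ids = torch.arange(12).reshape(2, 6)  # pretend generated ids, shape (B=2, T=6)
max_new_tokens = torch.tensor([[3], [5]])  # per-clip token budgets, shape (B, 1)
positions = torch.arange(pred_ids.shape[-1]).repeat((pred_ids.shape[0], 1))
mask = positions > max_new_tokens          # broadcasts (B, T) against (B, 1)
print(pred_ids.masked_fill(mask, 2))       # 2 stands in for model.config.eos_token_id
# tensor([[ 0,  1,  2,  3,  2,  2],
#         [ 6,  7,  8,  9, 10, 11]])  -> row 0 cut after position 3, row 1 untouched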
@@ -48,6 +63,7 @@ def benchmark(batch, min_new_tokens=None):
         batch["transcription_time_s"] = minibatch_size * [runtime / minibatch_size]
 
         # normalize transcriptions with English normalizer
+        pred_text = pred_text if padding_size == 0 else pred_text[:-padding_size]
         batch["predictions"] = [data_utils.normalizer(pred) for pred in pred_text]
         batch["references"] = batch["norm_text"]
         return batch
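
(Aside: a toy round trip of the compile-padding trick, showing why the slice above restores the true minibatch; the strings are stand-ins for waveforms.)

audios = ["clip_a", "clip_b", "clip_c"]
batch_size, minibatch_size = 4, len(audios)
padding_size = batch_size - minibatch_size     # 1: the final minibatch runs short
audios.extend(audios[-1:] * padding_size)      # duplicate the last clip so the batch shape stays static
pred_text = [f"text for {a}" for a in audios]  # pretend decoded output, length 4
pred_text = pred_text if padding_size == 0 else pred_text[:-padding_size]
assert len(pred_text) == minibatch_size        # padded duplicates dropped before scoring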