44# This source code is licensed under the BSD-style license found in the
55# LICENSE file in the root directory of this source tree.
66
7+ import argparse
78import logging
89
910import torch
@@ -168,8 +169,9 @@ def forward(self, waveform: torch.Tensor) -> torch.Tensor:
168169 return log_spec .unsqueeze (0 )
169170
170171
171- def export_processor ():
172- model = WhisperAudioProcessor ()
172+ def export_processor (model = None , output_file = "whisper_preprocess.pte" ):
173+ if model is None :
174+ model = WhisperAudioProcessor ()
173175 audio_tensor = torch .randn (480000 )
174176 chunk_tensor = audio_tensor [:93680 ]
175177 with torch .no_grad ():
@@ -191,15 +193,61 @@ def export_processor():
191193
192194 # to executorch
193195 exec_prog = edge .to_executorch ()
194- output_file = "whisper_preprocess.pte"
195196 with open (output_file , "wb" ) as file :
196197 exec_prog .write_to_file (file )
197198
198199 logging .debug ("Done" )
199200
200201
def main(argv=None) -> None:
    """Parse command-line options and export a WhisperAudioProcessor.

    Builds a WhisperAudioProcessor from the parsed options and hands it,
    together with the requested output path, to export_processor.

    Args:
        argv: Optional list of argument strings to parse. When None (the
            default) argparse falls back to sys.argv[1:], so a plain
            ``main()`` call behaves like a normal command-line invocation.
    """
    parser = argparse.ArgumentParser(
        description="Export WhisperAudioProcessor to ExecuTorch"
    )
    parser.add_argument(
        "--feature_size",
        type=int,
        default=80,
        help="The feature dimension of the extracted features",
    )
    parser.add_argument(
        "--sampling_rate",
        type=int,
        default=16000,
        help="The sampling rate at which audio files should be digitalized (Hz)",
    )
    parser.add_argument(
        "--hop_length",
        type=int,
        default=160,
        help="Length of overlapping windows for STFT",
    )
    parser.add_argument(
        "--chunk_length",
        type=int,
        default=30,
        help="Maximum number of chunks of sampling_rate samples",
    )
    parser.add_argument(
        "--n_fft", type=int, default=400, help="Size of the Fourier transform"
    )
    parser.add_argument(
        "--output_file",
        type=str,
        default="whisper_preprocess.pte",
        help="Output file path for the exported model",
    )

    # Passing argv through lets callers drive the exporter programmatically;
    # None keeps the original behaviour of reading the process arguments.
    args = parser.parse_args(argv)

    model = WhisperAudioProcessor(
        feature_size=args.feature_size,
        sampling_rate=args.sampling_rate,
        hop_length=args.hop_length,
        chunk_length=args.chunk_length,
        n_fft=args.n_fft,
    )

    export_processor(model, args.output_file)
203251
204252
205253if __name__ == "__main__" :
0 commit comments