4
4
# This source code is licensed under the BSD-style license found in the
5
5
# LICENSE file in the root directory of this source tree.
6
6
7
+ import argparse
7
8
import logging
8
9
9
10
import torch
@@ -168,8 +169,9 @@ def forward(self, waveform: torch.Tensor) -> torch.Tensor:
168
169
return log_spec .unsqueeze (0 )
169
170
170
171
171
- def export_processor ():
172
- model = WhisperAudioProcessor ()
172
+ def export_processor (model = None , output_file = "whisper_preprocess.pte" ):
173
+ if model is None :
174
+ model = WhisperAudioProcessor ()
173
175
audio_tensor = torch .randn (480000 )
174
176
chunk_tensor = audio_tensor [:93680 ]
175
177
with torch .no_grad ():
@@ -191,15 +193,61 @@ def export_processor():
191
193
192
194
# to executorch
193
195
exec_prog = edge .to_executorch ()
194
- output_file = "whisper_preprocess.pte"
195
196
with open (output_file , "wb" ) as file :
196
197
exec_prog .write_to_file (file )
197
198
198
199
logging .debug ("Done" )
199
200
200
201
201
202
def main ():
202
- export_processor ()
203
+ parser = argparse .ArgumentParser (
204
+ description = "Export WhisperAudioProcessor to ExecutorTorch"
205
+ )
206
+ parser .add_argument (
207
+ "--feature_size" ,
208
+ type = int ,
209
+ default = 80 ,
210
+ help = "The feature dimension of the extracted features" ,
211
+ )
212
+ parser .add_argument (
213
+ "--sampling_rate" ,
214
+ type = int ,
215
+ default = 16000 ,
216
+ help = "The sampling rate at which audio files should be digitalized (Hz)" ,
217
+ )
218
+ parser .add_argument (
219
+ "--hop_length" ,
220
+ type = int ,
221
+ default = 160 ,
222
+ help = "Length of overlapping windows for STFT" ,
223
+ )
224
+ parser .add_argument (
225
+ "--chunk_length" ,
226
+ type = int ,
227
+ default = 30 ,
228
+ help = "Maximum number of chunks of sampling_rate samples" ,
229
+ )
230
+ parser .add_argument (
231
+ "--n_fft" , type = int , default = 400 , help = "Size of the Fourier transform"
232
+ )
233
+ parser .add_argument (
234
+ "--output_file" ,
235
+ type = str ,
236
+ default = "whisper_preprocess.pte" ,
237
+ help = "Output file path for the exported model" ,
238
+ )
239
+
240
+ args = parser .parse_args ()
241
+
242
+ model = WhisperAudioProcessor (
243
+ feature_size = args .feature_size ,
244
+ sampling_rate = args .sampling_rate ,
245
+ hop_length = args .hop_length ,
246
+ chunk_length = args .chunk_length ,
247
+ n_fft = args .n_fft ,
248
+ )
249
+
250
+ export_processor (model , args .output_file )
203
251
204
252
205
253
if __name__ == "__main__" :
0 commit comments