Skip to content

Commit f4ec01a

Browse files
authored
Scriptify mel spectrogram processor (#13961)
Make mel spectrogram processor a runnable script.
1 parent 14d0745 commit f4ec01a

File tree

1 file changed

+52
-4
lines changed

1 file changed

+52
-4
lines changed

extension/audio/mel_spectrogram.py

Lines changed: 52 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
66

7+
import argparse
78
import logging
89

910
import torch
@@ -168,8 +169,9 @@ def forward(self, waveform: torch.Tensor) -> torch.Tensor:
168169
return log_spec.unsqueeze(0)
169170

170171

171-
def export_processor():
172-
model = WhisperAudioProcessor()
172+
def export_processor(model=None, output_file="whisper_preprocess.pte"):
173+
if model is None:
174+
model = WhisperAudioProcessor()
173175
audio_tensor = torch.randn(480000)
174176
chunk_tensor = audio_tensor[:93680]
175177
with torch.no_grad():
@@ -191,15 +193,61 @@ def export_processor():
191193

192194
# to executorch
193195
exec_prog = edge.to_executorch()
194-
output_file = "whisper_preprocess.pte"
195196
with open(output_file, "wb") as file:
196197
exec_prog.write_to_file(file)
197198

198199
logging.debug("Done")
199200

200201

201202
def main():
202-
export_processor()
203+
parser = argparse.ArgumentParser(
204+
description="Export WhisperAudioProcessor to ExecutorTorch"
205+
)
206+
parser.add_argument(
207+
"--feature_size",
208+
type=int,
209+
default=80,
210+
help="The feature dimension of the extracted features",
211+
)
212+
parser.add_argument(
213+
"--sampling_rate",
214+
type=int,
215+
default=16000,
216+
help="The sampling rate at which audio files should be digitalized (Hz)",
217+
)
218+
parser.add_argument(
219+
"--hop_length",
220+
type=int,
221+
default=160,
222+
help="Length of overlapping windows for STFT",
223+
)
224+
parser.add_argument(
225+
"--chunk_length",
226+
type=int,
227+
default=30,
228+
help="Maximum number of chunks of sampling_rate samples",
229+
)
230+
parser.add_argument(
231+
"--n_fft", type=int, default=400, help="Size of the Fourier transform"
232+
)
233+
parser.add_argument(
234+
"--output_file",
235+
type=str,
236+
default="whisper_preprocess.pte",
237+
help="Output file path for the exported model",
238+
)
239+
240+
args = parser.parse_args()
241+
242+
model = WhisperAudioProcessor(
243+
feature_size=args.feature_size,
244+
sampling_rate=args.sampling_rate,
245+
hop_length=args.hop_length,
246+
chunk_length=args.chunk_length,
247+
n_fft=args.n_fft,
248+
)
249+
250+
export_processor(model, args.output_file)
203251

204252

205253
if __name__ == "__main__":

0 commit comments

Comments
 (0)