Skip to content

Commit 624b38e

Browse files
authored
Whisper audio processor
Differential Revision: D80215714 Pull Request resolved: #13538
1 parent 9359481 commit 624b38e

File tree

2 files changed

+208
-0
lines changed

2 files changed

+208
-0
lines changed

extension/audio/TARGETS

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
load("@fbcode_macros//build_defs:python_library.bzl", "python_library")
2+
load("@fbcode_macros//build_defs:python_binary.bzl", "python_binary")
3+
load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")
4+
5+
oncall("executorch")
6+
7+
python_library(
8+
name = "mel_spectrogram_lib",
9+
srcs = ["mel_spectrogram.py"],
10+
deps = [
11+
"//caffe2:torch",
12+
"//executorch/devtools/backend_debug:delegation_info",
13+
"//executorch/backends/xnnpack/partition:xnnpack_partitioner",
14+
"//executorch/runtime:runtime",
15+
"fbsource//third-party/pypi/datasets:datasets",
16+
"fbsource//third-party/pypi/transformers:transformers",
17+
"fbsource//third-party/pypi/librosa:librosa",
18+
"fbsource//third-party/pypi/soundfile:soundfile"
19+
]
20+
)
21+
22+
python_binary(
23+
name = "mel_spectrogram",
24+
main_module = "executorch.extension.audio.mel_spectrogram",
25+
deps = [
26+
":mel_spectrogram_lib",
27+
],
28+
)

extension/audio/mel_spectrogram.py

Lines changed: 180 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,180 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
import logging
8+
9+
import torch
10+
import torch.nn as nn
11+
import torch.nn.functional as F
12+
13+
from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
14+
15+
from executorch.exir import (
16+
EdgeCompileConfig,
17+
EdgeProgramManager,
18+
to_edge_transform_and_lower,
19+
)
20+
21+
from torch.export import Dim, export, ExportedProgram
22+
23+
24+
class WhisperAudioProcessor(nn.Module):
25+
"""
26+
Computes Mel spectrograms from mono audio input.
27+
Same as HuggingFace WhisperFeatureExtractor, but implemented in PyTorch
28+
"""
29+
30+
def __init__(
31+
self,
32+
feature_size=80,
33+
sampling_rate=16000,
34+
hop_length=160,
35+
chunk_length=30,
36+
n_fft=400,
37+
padding_value=0.0,
38+
):
39+
super().__init__()
40+
self.feature_size = feature_size
41+
self.sampling_rate = sampling_rate
42+
self.padding_value = padding_value
43+
44+
self.n_fft = n_fft
45+
self.hop_length = hop_length
46+
self.chunk_length = chunk_length
47+
self.n_samples = chunk_length * sampling_rate
48+
self.nb_max_frames = self.n_samples // hop_length
49+
self.sampling_rate = sampling_rate
50+
self.mel_filters = self.get_mel_filters(
51+
sampling_rate, n_fft, n_mels=feature_size
52+
)
53+
54+
def get_mel_filters(self, sr, n_fft, n_mels=128, dtype=torch.float32):
55+
# Initialize the weights
56+
n_mels = int(n_mels)
57+
weights = torch.zeros((n_mels, int(1 + n_fft // 2)), dtype=dtype)
58+
59+
# Center freqs of each FFT bin
60+
fftfreqs = torch.fft.rfftfreq(n=n_fft, d=1.0 / sr, dtype=dtype)
61+
62+
# 'Center freqs' of mel bands - uniformly spaced between limits
63+
min_mel = 0.0
64+
max_mel = 45.245640471924965
65+
66+
mels = torch.linspace(min_mel, max_mel, n_mels + 2, dtype=dtype)
67+
68+
# Fill in the linear scale
69+
f_min = 0.0
70+
f_sp = 200.0 / 3
71+
freqs = f_min + f_sp * mels
72+
73+
# And now the nonlinear scale
74+
min_log_hz = 1000.0 # beginning of log region (Hz)
75+
min_log_mel = (min_log_hz - f_min) / f_sp # same (Mels)
76+
logstep = (
77+
torch.log(torch.tensor(6.4, dtype=dtype)) / 27.0
78+
) # step size for log region
79+
80+
# If we have vector data, vectorize
81+
log_t = mels >= min_log_mel
82+
freqs[log_t] = min_log_hz * torch.exp(logstep * (mels[log_t] - min_log_mel))
83+
84+
mel_f = freqs
85+
86+
fdiff = torch.diff(mel_f)
87+
ramps = torch.subtract(mel_f.unsqueeze(1), fftfreqs.unsqueeze(0))
88+
89+
for i in range(n_mels):
90+
# lower and upper slopes for all bins
91+
lower = -ramps[i] / fdiff[i]
92+
upper = ramps[i + 2] / fdiff[i + 1]
93+
94+
# .. then intersect them with each other and zero
95+
weights[i] = torch.maximum(
96+
torch.tensor(0.0, dtype=dtype), torch.minimum(lower, upper)
97+
)
98+
99+
# Slaney-style mel is scaled to be approx constant energy per channel
100+
enorm = 2.0 / (mel_f[2 : n_mels + 2] - mel_f[:n_mels])
101+
weights *= enorm[:, None]
102+
103+
return weights
104+
105+
def forward(self, waveform):
106+
waveform = F.pad(
107+
waveform,
108+
(0, self.n_samples - waveform.shape[0] - 1),
109+
mode="constant",
110+
value=0,
111+
)
112+
window = 0.5 * (
113+
1
114+
- torch.cos(
115+
2
116+
* torch.pi
117+
* torch.linspace(0, self.n_fft - 1, self.n_fft, dtype=torch.float32)
118+
/ self.n_fft
119+
)
120+
)
121+
# Ideally we should do instead
122+
# window = torch.hann_window(self.n_fft)
123+
# but this is not currently supported when lowering
124+
# torch.hann_window has slightly better numerics (worst discrepancy is <1e-5 instead of 1e-4)
125+
stft = torch.stft(
126+
waveform,
127+
n_fft=self.n_fft,
128+
hop_length=self.hop_length,
129+
window=window,
130+
center=True,
131+
return_complex=True,
132+
)
133+
magnitudes = torch.abs(stft) ** 2
134+
135+
mel_spec = self.mel_filters @ magnitudes
136+
137+
log_spec = torch.log10(torch.clamp(mel_spec, min=1e-10))
138+
log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)
139+
log_spec = (log_spec + 4.0) / 4.0
140+
141+
return log_spec.unsqueeze(0)
142+
143+
144+
def export_processor():
145+
model = WhisperAudioProcessor()
146+
audio_tensor = torch.randn(480000)
147+
chunk_tensor = audio_tensor[:93680]
148+
with torch.no_grad():
149+
# export. What is the min of waveforms?
150+
dim = Dim("waveform", min=1600, max=audio_tensor.size(0))
151+
ep: ExportedProgram = export(
152+
model, (chunk_tensor,), dynamic_shapes={"waveform": {0: dim}}, strict=True
153+
)
154+
logging.debug(ep)
155+
156+
# to edge
157+
edge: EdgeProgramManager = to_edge_transform_and_lower(
158+
ep,
159+
partitioner=[XnnpackPartitioner()],
160+
compile_config=EdgeCompileConfig(
161+
_check_ir_validity=False,
162+
),
163+
)
164+
logging.debug(edge.exported_program())
165+
166+
# to executorch
167+
exec_prog = edge.to_executorch()
168+
output_file = "whisper_preprocess.pte"
169+
with open(output_file, "wb") as file:
170+
exec_prog.write_to_file(file)
171+
172+
logging.debug("Done")
173+
174+
175+
def main():
176+
export_processor()
177+
178+
179+
if __name__ == "__main__":
180+
main()

0 commit comments

Comments
 (0)