
Commit 18151e4

Added documentation and typing to WhisperAudioProcessor (#13661)
Differential Revision: D80907444
1 parent d760b77 commit 18151e4

2 files changed (+52, -14 lines)


extension/audio/README.md

Lines changed: 11 additions & 0 deletions
@@ -0,0 +1,11 @@
# Audio Processing with ExecuTorch

The file `mel_spectrogram.py` contains the class `WhisperAudioProcessor`, a module that converts a mono waveform audio input (a 1D tensor) into Mel spectrograms. It applies a Short-Time Fourier Transform (via `torch.stft`) followed by a Mel filterbank. It is equivalent to the `WhisperFeatureExtractor` class in HuggingFace Transformers, but is implemented in PyTorch instead of NumPy. `WhisperFeatureExtractor` is used for Whisper, Voxtral, Qwen2 Audio, and Qwen2.5 Omni. For example, the output Mel spectrograms can be fed directly into the Whisper model (encoder + decoder) exported from HF Transformers.

Since `WhisperAudioProcessor` is written in PyTorch, we can export it with ExecuTorch and run it on device. The defaults for `WhisperAudioProcessor` are 16 kHz audio, 80 Mel spectrogram bins, and 30-second audio chunks.

Run it as a script

```bash
python mel_spectrogram.py
```

to export `WhisperAudioProcessor` (with default constructor arguments) as `whisper_preprocess.pte`, which can run on device (on CPU).
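To make the README's description concrete, here is a minimal usage sketch (not part of this commit) that runs the eager PyTorch module on a dummy waveform. The import path is an assumption based on the file's location in the repo; the expected output shape follows from the defaults described above.

```python
# Hypothetical usage sketch; the import path below is assumed, not taken from this commit.
import torch

from executorch.extension.audio.mel_spectrogram import WhisperAudioProcessor  # assumed path

processor = WhisperAudioProcessor()  # defaults: 16 kHz, 80 Mel bins, 30 s chunks

# One second of silence; anything shorter than 30 s (480000 samples) is padded internally.
waveform = torch.zeros(16000)

features = processor(waveform)
print(features.shape)  # expected: torch.Size([1, 80, 3000]) with default options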

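The README states that running the file as a script exports the module to `whisper_preprocess.pte`. A hedged sketch of what such an export typically looks like with the standard ExecuTorch flow is below; the actual script in this commit may declare dynamic input shapes or lowering options differently.

```python
# Rough export sketch, assuming the standard torch.export + ExecuTorch to_edge flow.
# The import path of WhisperAudioProcessor is assumed, not taken from this commit.
import torch
from torch.export import export

from executorch.exir import to_edge
from executorch.extension.audio.mel_spectrogram import WhisperAudioProcessor  # assumed path

model = WhisperAudioProcessor().eval()
example_input = (torch.zeros(16000 * 30 - 1),)  # just under n_samples, as forward() expects

exported = export(model, example_input)
program = to_edge(exported).to_executorch()

# Write the serialized program so it can be loaded by the ExecuTorch runtime on device.
with open("whisper_preprocess.pte", "wb") as f:
    f.write(program.buffer)
```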
extension/audio/mel_spectrogram.py

Lines changed: 41 additions & 14 deletions
```diff
@@ -22,20 +22,35 @@


 class WhisperAudioProcessor(nn.Module):
-    """
+    r"""
     Computes Mel spectrograms from mono audio input.
     Same as HuggingFace WhisperFeatureExtractor, but implemented in PyTorch
+
+    Args:
+        feature_size (`int`, defaults to 80):
+            The feature dimension of the extracted features.
+        sampling_rate (`int`, defaults to 16000):
+            The sampling rate at which the audio files should be digitized, expressed in hertz (Hz).
+        hop_length (`int`, defaults to 160):
+            Length of the overlapping windows for the STFT used to obtain the Mel Frequency coefficients.
+        chunk_length (`int`, defaults to 30):
+            The maximum number of chunks of `sampling_rate` samples used to trim and pad longer or shorter audio
+            sequences.
+        n_fft (`int`, defaults to 400):
+            Size of the Fourier transform.
+        padding_value (`float`, *optional*, defaults to 0.0):
+            Padding value used to pad the audio. Should correspond to silences.
     """

     def __init__(
         self,
-        feature_size=80,
-        sampling_rate=16000,
-        hop_length=160,
-        chunk_length=30,
-        n_fft=400,
-        padding_value=0.0,
-    ):
+        feature_size: int = 80,
+        sampling_rate: int = 16000,
+        hop_length: int = 160,
+        chunk_length: int = 30,
+        n_fft: int = 400,
+        padding_value: float = 0.0,
+    ) -> None:
         super().__init__()
         self.feature_size = feature_size
         self.sampling_rate = sampling_rate
```
```diff
@@ -51,7 +66,9 @@ def __init__(
             sampling_rate, n_fft, n_mels=feature_size
         )

-    def get_mel_filters(self, sr, n_fft, n_mels=128, dtype=torch.float32):
+    def get_mel_filters(
+        self, sr: int, n_fft: int, n_mels: int = 128, dtype: torch.dtype = torch.float32
+    ) -> torch.Tensor:
         # Initialize the weights
         n_mels = int(n_mels)
         weights = torch.zeros((n_mels, int(1 + n_fft // 2)), dtype=dtype)
```
```diff
@@ -97,17 +114,27 @@ def get_mel_filters(self, sr, n_fft, n_mels=128, dtype=torch.float32):
         )

         # Slaney-style mel is scaled to be approx constant energy per channel
-        enorm = 2.0 / (mel_f[2 : n_mels + 2] - mel_f[:n_mels])
-        weights *= enorm[:, None]
+        enorm = 2.0 / (mel_f[2 : n_mels + 2] - mel_f[:n_mels])  # pyre-ignore[58]
+        weights *= enorm[:, None]  # pyre-ignore[16]

         return weights

-    def forward(self, waveform):
+    def forward(self, waveform: torch.Tensor) -> torch.Tensor:
+        r"""
+        Args:
+            waveform (`torch.Tensor`): Mono waveform input, tensor of (dynamic) shape [num_samples],
+                where num_samples < n_samples. n_samples is 480000 for 16kHz and chunk length 30
+
+        Returns:
+            torch.Tensor: Output of fixed shape [1, feature_size, nb_max_frames]
+                ([1, 80, 3000] with default options)
+        """
+        # TODO: pad up to multiples of chunk_length (currently 1 chunk of 30 sec)
         waveform = F.pad(
             waveform,
             (0, self.n_samples - waveform.shape[0] - 1),
             mode="constant",
-            value=0,
+            value=self.padding_value,
         )
         window = 0.5 * (
             1
```
```diff
@@ -130,7 +157,7 @@ def forward(self, waveform):
             center=True,
             return_complex=True,
         )
-        magnitudes = torch.abs(stft) ** 2
+        magnitudes = torch.abs(stft) ** 2  # pyre-ignore[58]

         mel_spec = self.mel_filters @ magnitudes

```
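The `get_mel_filters` hunk above implements a librosa-style (Slaney) Mel filterbank, which is what the "constant energy per channel" comment refers to. As an approximate cross-check outside this commit, torchaudio can build a comparable Slaney-normalized filterbank; exact bit-for-bit agreement is not established by this diff, so treat the sketch below as a reference point only.

```python
# Approximate cross-check (assumption, not from this commit): build a Slaney-normalized
# Mel filterbank with torchaudio and compare its layout to get_mel_filters' output.
import torchaudio.functional as TAF

n_fft, sample_rate, n_mels = 400, 16000, 80

# torchaudio returns (n_freqs, n_mels); get_mel_filters returns (n_mels, 1 + n_fft // 2).
fbank = TAF.melscale_fbanks(
    n_freqs=n_fft // 2 + 1,
    f_min=0.0,
    f_max=float(sample_rate) / 2,
    n_mels=n_mels,
    sample_rate=sample_rate,
    norm="slaney",       # constant energy per channel, as in the comment above
    mel_scale="slaney",  # librosa-style Mel scale
)
print(fbank.T.shape)  # transpose to (80, 201) to match the module's layout
```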

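Since the docstring claims parity with HuggingFace's `WhisperFeatureExtractor`, a quick cross-check sketch is shown below. It assumes `transformers` is installed and only compares output shapes; exact numerical agreement depends on details such as padding length and the final log-Mel scaling, which are outside this diff. The import path of `WhisperAudioProcessor` is again an assumption.

```python
# Hypothetical parity check (not part of this commit): compare output shapes against
# HuggingFace's WhisperFeatureExtractor, which this module is documented to mirror.
import numpy as np
import torch
from transformers import WhisperFeatureExtractor

from executorch.extension.audio.mel_spectrogram import WhisperAudioProcessor  # assumed path

audio = np.random.randn(16000 * 5).astype(np.float32)  # 5 seconds of noise at 16 kHz

hf_out = WhisperFeatureExtractor()(audio, sampling_rate=16000, return_tensors="pt").input_features
et_out = WhisperAudioProcessor()(torch.from_numpy(audio))

print(hf_out.shape, et_out.shape)  # both expected to be [1, 80, 3000]
```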
0 commit comments
