Commit 5d0ce34

jackzhxng authored and facebook-github-bot committed
Mel spectrogram output stacking along batch dim
Differential Revision: D81798729
1 parent 4c90a53 commit 5d0ce34

File tree

1 file changed: +34 additions, -10 deletions


extension/audio/mel_spectrogram.py

Lines changed: 34 additions & 10 deletions
@@ -8,19 +8,17 @@
 import logging
 
 import torch
+from torch.export import Dim
 import torch.nn as nn
 import torch.nn.functional as F
 
 from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
-
 from executorch.exir import (
     EdgeCompileConfig,
     EdgeProgramManager,
     to_edge_transform_and_lower,
 )
 
-from torch.export import Dim, export, ExportedProgram
-
 
 class WhisperAudioProcessor(nn.Module):
     r"""
@@ -51,6 +49,8 @@ def __init__(
         chunk_length: int = 30,
         n_fft: int = 400,
         padding_value: float = 0.0,
+        max_audio_len: int = 600,
+        stack_output: bool = False,
     ) -> None:
         super().__init__()
         self.feature_size = feature_size
@@ -66,6 +66,9 @@ def __init__(
         self.mel_filters = self.get_mel_filters(
             sampling_rate, n_fft, n_mels=feature_size
         )
+        self.max_audio_len = max_audio_len
+        self.max_n_chunks = int(max_audio_len / chunk_length)
+        self.stack_output = stack_output
 
     def get_mel_filters(
         self, sr: int, n_fft: int, n_mels: int = 128, dtype: torch.dtype = torch.float32
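
For reference, with the defaults above the new bound works out to int(600 / 30) = 20 chunks. A minimal sketch of constructing the processor with the new options (illustrative values, not from this commit):

    processor = WhisperAudioProcessor(max_audio_len=600, stack_output=True)
    assert processor.max_n_chunks == 20  # 600 s of audio / 30 s per chunk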
@@ -131,12 +134,14 @@ def forward(self, waveform: torch.Tensor) -> torch.Tensor:
         [1, 80, 3000] with default options and 1 chunk
         """
         n_chunks = (waveform.shape[0] - 1) // self.n_samples + 1
+        torch._constrain_as_size(n_chunks, max=self.max_n_chunks)  # Explicitly sets the max bound, otherwise export complains about it being infinite.
         waveform = F.pad(
             waveform,
             (0, self.n_samples * n_chunks - waveform.shape[0]),
             mode="constant",
             value=self.padding_value,
         )
+
         # Ideally we should do:
         # window = torch.hann_window(self.n_fft)
         # but this is not currently supported when lowering.
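
torch._constrain_as_size is what gives the exporter a finite range for the data-dependent chunk count. A minimal standalone sketch of the pattern (module and names are hypothetical, not from this file):

    import torch
    import torch.nn.functional as F

    class PadToChunks(torch.nn.Module):
        def __init__(self, n_samples: int = 480000, max_n_chunks: int = 20):
            super().__init__()
            self.n_samples = n_samples
            self.max_n_chunks = max_n_chunks

        def forward(self, waveform: torch.Tensor) -> torch.Tensor:
            # Derived from the dynamic input length, so symbolic at export time.
            n_chunks = (waveform.shape[0] - 1) // self.n_samples + 1
            # Without an explicit max, export treats this size as unbounded.
            torch._constrain_as_size(n_chunks, max=self.max_n_chunks)
            pad = self.n_samples * n_chunks - waveform.shape[0]
            return F.pad(waveform, (0, pad), mode="constant", value=0.0)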
@@ -166,18 +171,24 @@ def forward(self, waveform: torch.Tensor) -> torch.Tensor:
         log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)
         log_spec = (log_spec + 4.0) / 4.0
 
-        return log_spec.unsqueeze(0)
+        if self.stack_output:
+            log_spec = log_spec.reshape(self.feature_size, -1, self.nb_max_frames)
+            log_spec = log_spec.transpose(0, 1)
+            return log_spec
+        else:
+            return log_spec.unsqueeze(0)
 
 
 def export_processor(model=None, output_file="whisper_preprocess.pte"):
     if model is None:
         model = WhisperAudioProcessor()
-    audio_tensor = torch.randn(480000)
-    chunk_tensor = audio_tensor[:93680]
-    with torch.no_grad():
-        dim = Dim("waveform", min=1600, max=audio_tensor.size(0) * 10)  # 10 chunks max
-        ep: ExportedProgram = export(
-            model, (chunk_tensor,), dynamic_shapes={"waveform": {0: dim}}, strict=True
+
+    audio_tensor = torch.randn(93680)
+    shapes_collection = torch.export.ShapesCollection()
+    shapes_collection[audio_tensor] = {0: Dim.DYNAMIC}
+    with torch.no_grad(), torch.fx.experimental._config.patch(backed_size_oblivious=True):
+        ep = torch.export.export(
+            model, (audio_tensor,), dynamic_shapes=shapes_collection, strict=True
         )
     logging.debug(ep)
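
The stack_output branch turns the [feature_size, n_chunks * nb_max_frames] spectrogram into one batch entry per chunk. A minimal sketch of the shape arithmetic, assuming the Whisper defaults of 80 mel bins and 3000 frames per chunk:

    import torch

    feature_size, nb_max_frames, n_chunks = 80, 3000, 2
    log_spec = torch.randn(feature_size, n_chunks * nb_max_frames)  # [80, 6000]

    stacked = log_spec.reshape(feature_size, -1, nb_max_frames)  # [80, 2, 3000]
    stacked = stacked.transpose(0, 1)  # [2, 80, 3000], one row per 30 s chunk
    assert stacked.shape == (n_chunks, feature_size, nb_max_frames)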

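The export path now uses ShapesCollection, which maps example input tensors to their dynamic dimensions, with Dim.DYNAMIC letting the exporter infer the range. A minimal sketch on a toy module (the backed_size_oblivious patch in the diff is a private torch.fx config and is omitted here):

    import torch
    from torch.export import Dim

    class Scale(torch.nn.Module):
        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return x * 2.0

    example = torch.randn(93680)
    shapes = torch.export.ShapesCollection()
    shapes[example] = {0: Dim.DYNAMIC}  # dim 0 of this exact tensor is dynamic

    ep = torch.export.export(Scale(), (example,), dynamic_shapes=shapes, strict=True)
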
@@ -236,6 +247,17 @@ def main():
         default="whisper_preprocess.pte",
         help="Output file path for the exported model",
     )
+    parser.add_argument(
+        "--max_audio_len",
+        type=int,
+        default=600,
+        help="Max audio length that can be processed, in seconds."
+    )
+    parser.add_argument(
+        "--stack_output",
+        action="store_true",
+        help="Whether to stack output along the batch dimension, one per chunk. Used by models such as Voxtral, see https://github.com/huggingface/transformers/blob/main/src/transformers/models/voxtral/processing_voxtral.py#L94 for more information."
+    )
 
     args = parser.parse_args()

@@ -245,6 +267,8 @@ def main():
         hop_length=args.hop_length,
         chunk_length=args.chunk_length,
         n_fft=args.n_fft,
+        max_audio_len=args.max_audio_len,
+        stack_output=args.stack_output,
     )
 
     export_processor(model, args.output_file)
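
Assuming the script's existing CLI entry point, the new flags would be passed like so (a sketch; only --max_audio_len and --stack_output are introduced by this commit):

    python extension/audio/mel_spectrogram.py --max_audio_len 600 --stack_output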
