Skip to content

Commit 233063c

Browse files
authored
Mel spectrogram output stacking along batch dim (#14275)
Allows processing of audio longer than 30 seconds for Voxtral. Differential Revision: D81798729
1 parent 1c4f2e4 commit 233063c

File tree

1 file changed

+35
-10
lines changed

1 file changed

+35
-10
lines changed

extension/audio/mel_spectrogram.py

Lines changed: 35 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -12,14 +12,12 @@
1212
import torch.nn.functional as F
1313

1414
from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
15-
1615
from executorch.exir import (
1716
EdgeCompileConfig,
1817
EdgeProgramManager,
1918
to_edge_transform_and_lower,
2019
)
21-
22-
from torch.export import Dim, export, ExportedProgram
20+
from torch.export import Dim
2321

2422

2523
class WhisperAudioProcessor(nn.Module):
@@ -51,6 +49,8 @@ def __init__(
5149
chunk_length: int = 30,
5250
n_fft: int = 400,
5351
padding_value: float = 0.0,
52+
max_audio_len: int = 600,
53+
stack_output: bool = False,
5454
) -> None:
5555
super().__init__()
5656
self.feature_size = feature_size
@@ -66,6 +66,8 @@ def __init__(
6666
self.mel_filters = self.get_mel_filters(
6767
sampling_rate, n_fft, n_mels=feature_size
6868
)
69+
self.max_audio_len = max_audio_len
70+
self.stack_output = stack_output
6971

7072
def get_mel_filters(
7173
self, sr: int, n_fft: int, n_mels: int = 128, dtype: torch.dtype = torch.float32
@@ -137,6 +139,7 @@ def forward(self, waveform: torch.Tensor) -> torch.Tensor:
137139
mode="constant",
138140
value=self.padding_value,
139141
)
142+
140143
# Ideally we should do:
141144
# window = torch.hann_window(self.n_fft)
142145
# but this is not currently supported when lowering.
@@ -166,18 +169,27 @@ def forward(self, waveform: torch.Tensor) -> torch.Tensor:
166169
log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)
167170
log_spec = (log_spec + 4.0) / 4.0
168171

169-
return log_spec.unsqueeze(0)
172+
if self.stack_output:
173+
log_spec = log_spec.reshape(self.feature_size, -1, self.nb_max_frames)
174+
log_spec = log_spec.transpose(0, 1)
175+
return log_spec
176+
else:
177+
return log_spec.unsqueeze(0)
170178

171179

172180
def export_processor(model=None, output_file="whisper_preprocess.pte"):
173181
if model is None:
174182
model = WhisperAudioProcessor()
175-
audio_tensor = torch.randn(480000)
176-
chunk_tensor = audio_tensor[:93680]
177-
with torch.no_grad():
178-
dim = Dim("waveform", min=1600, max=audio_tensor.size(0) * 10) # 10 chunks max
179-
ep: ExportedProgram = export(
180-
model, (chunk_tensor,), dynamic_shapes={"waveform": {0: dim}}, strict=True
183+
184+
audio_tensor = torch.randn(93680)
185+
shapes_collection = torch.export.ShapesCollection()
186+
max_n_chunks = int(model.max_audio_len * model.n_samples)
187+
shapes_collection[audio_tensor] = {0: Dim.DYNAMIC(max=max_n_chunks)}
188+
with torch.no_grad(), torch.fx.experimental._config.patch(
189+
backed_size_oblivious=True
190+
):
191+
ep = torch.export.export(
192+
model, (audio_tensor,), dynamic_shapes=shapes_collection, strict=True
181193
)
182194
logging.debug(ep)
183195

@@ -236,6 +248,17 @@ def main():
236248
default="whisper_preprocess.pte",
237249
help="Output file path for the exported model",
238250
)
251+
parser.add_argument(
252+
"--max_audio_len",
253+
type=int,
254+
default=600,
255+
help="Max audio length that can be processed, in seconds.",
256+
)
257+
parser.add_argument(
258+
"--stack_output",
259+
action="store_true",
260+
help="Whether to stack output along the batch dimension, one per chunk. Used by models such as Voxtral, see https://github.com/huggingface/transformers/blob/main/src/transformers/models/voxtral/processing_voxtral.py#L94 for more information.",
261+
)
239262

240263
args = parser.parse_args()
241264

@@ -245,6 +268,8 @@ def main():
245268
hop_length=args.hop_length,
246269
chunk_length=args.chunk_length,
247270
n_fft=args.n_fft,
271+
max_audio_len=args.max_audio_len,
272+
stack_output=args.stack_output,
248273
)
249274

250275
export_processor(model, args.output_file)

0 commit comments

Comments
 (0)