88import logging
99
1010import torch
11+ from torch .export import Dim
1112import torch .nn as nn
1213import torch .nn .functional as F
1314
1415from executorch .backends .xnnpack .partition .xnnpack_partitioner import XnnpackPartitioner
15-
1616from executorch .exir import (
1717 EdgeCompileConfig ,
1818 EdgeProgramManager ,
1919 to_edge_transform_and_lower ,
2020)
2121
22- from torch .export import Dim , export , ExportedProgram
23-
2422
2523class WhisperAudioProcessor (nn .Module ):
2624 r"""
@@ -51,6 +49,8 @@ def __init__(
5149 chunk_length : int = 30 ,
5250 n_fft : int = 400 ,
5351 padding_value : float = 0.0 ,
52+ max_audio_len : int = 600 ,
53+ stack_output : bool = False ,
5454 ) -> None :
5555 super ().__init__ ()
5656 self .feature_size = feature_size
@@ -66,6 +66,9 @@ def __init__(
6666 self .mel_filters = self .get_mel_filters (
6767 sampling_rate , n_fft , n_mels = feature_size
6868 )
69+ self .max_audio_len = max_audio_len
70+ self .max_n_chunks = int (max_audio_len / chunk_length )
71+ self .stack_output = stack_output
6972
7073 def get_mel_filters (
7174 self , sr : int , n_fft : int , n_mels : int = 128 , dtype : torch .dtype = torch .float32
@@ -131,12 +134,14 @@ def forward(self, waveform: torch.Tensor) -> torch.Tensor:
131134 [1, 80, 3000] with default options and 1 chunk
132135 """
133136 n_chunks = (waveform .shape [0 ] - 1 ) // self .n_samples + 1
137+ torch ._constrain_as_size (n_chunks , max = self .max_n_chunks ) # Explicitly sets the max bound, otherwise export complains about it being infinite.
134138 waveform = F .pad (
135139 waveform ,
136140 (0 , self .n_samples * n_chunks - waveform .shape [0 ]),
137141 mode = "constant" ,
138142 value = self .padding_value ,
139143 )
144+
140145 # Ideally we should do:
141146 # window = torch.hann_window(self.n_fft)
142147 # but this is not currently supported when lowering.
@@ -166,18 +171,24 @@ def forward(self, waveform: torch.Tensor) -> torch.Tensor:
166171 log_spec = torch .maximum (log_spec , log_spec .max () - 8.0 )
167172 log_spec = (log_spec + 4.0 ) / 4.0
168173
169- return log_spec .unsqueeze (0 )
174+ if self .stack_output :
175+ log_spec = log_spec .reshape (self .feature_size , - 1 , self .nb_max_frames )
176+ log_spec = log_spec .transpose (0 , 1 )
177+ return log_spec
178+ else :
179+ return log_spec .unsqueeze (0 )
170180
171181
172182def export_processor (model = None , output_file = "whisper_preprocess.pte" ):
173183 if model is None :
174184 model = WhisperAudioProcessor ()
175- audio_tensor = torch .randn (480000 )
176- chunk_tensor = audio_tensor [:93680 ]
177- with torch .no_grad ():
178- dim = Dim ("waveform" , min = 1600 , max = audio_tensor .size (0 ) * 10 ) # 10 chunks max
179- ep : ExportedProgram = export (
180- model , (chunk_tensor ,), dynamic_shapes = {"waveform" : {0 : dim }}, strict = True
185+
186+ audio_tensor = torch .randn (93680 )
187+ shapes_collection = torch .export .ShapesCollection ()
188+ shapes_collection [audio_tensor ] = {0 : Dim .DYNAMIC }
189+ with torch .no_grad (), torch .fx .experimental ._config .patch (backed_size_oblivious = True ):
190+ ep = torch .export .export (
191+ model , (audio_tensor ,), dynamic_shapes = shapes_collection , strict = True
181192 )
182193 logging .debug (ep )
183194
@@ -236,6 +247,17 @@ def main():
236247 default = "whisper_preprocess.pte" ,
237248 help = "Output file path for the exported model" ,
238249 )
250+ parser .add_argument (
251+ "--max_audio_len" ,
252+ type = int ,
253+ default = 600 ,
254+ help = "Max audio length that can be processed, in seconds."
255+ )
256+ parser .add_argument (
257+ "--stack_output" ,
258+ action = "store_true" ,
259+ help = "Whether to stack output along the batch dimension, one per chunk. Used by models such as Voxtral, see https://github.com/huggingface/transformers/blob/main/src/transformers/models/voxtral/processing_voxtral.py#L94 for more information."
260+ )
239261
240262 args = parser .parse_args ()
241263
@@ -245,6 +267,8 @@ def main():
245267 hop_length = args .hop_length ,
246268 chunk_length = args .chunk_length ,
247269 n_fft = args .n_fft ,
270+ max_audio_len = args .max_audio_len ,
271+ stack_output = args .stack_output ,
248272 )
249273
250274 export_processor (model , args .output_file )
0 commit comments