 import torch.nn.functional as F

 from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
-
 from executorch.exir import (
     EdgeCompileConfig,
     EdgeProgramManager,
     to_edge_transform_and_lower,
 )
-
-from torch.export import Dim, export, ExportedProgram
+from torch.export import Dim


 class WhisperAudioProcessor(nn.Module):
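These imports cover the whole export-and-lower flow, even though only the export half appears in this diff. A minimal sketch of how they typically fit together, assuming the standard ExecuTorch lowering sequence (the partitioner list and compile config below are illustrative, not taken from this file):

```python
import torch

# Export the eager module, then lower the resulting EdgeProgramManager
# to an XNNPACK-delegated ExecuTorch program and serialize it.
model = WhisperAudioProcessor().eval()
ep = torch.export.export(model, (torch.randn(93680),), strict=True)
edge: EdgeProgramManager = to_edge_transform_and_lower(
    ep,
    partitioner=[XnnpackPartitioner()],
    compile_config=EdgeCompileConfig(_check_ir_validity=False),
)
with open("whisper_preprocess.pte", "wb") as f:
    f.write(edge.to_executorch().buffer)
```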
@@ -51,6 +49,8 @@ def __init__(
         chunk_length: int = 30,
         n_fft: int = 400,
         padding_value: float = 0.0,
+        max_audio_len: int = 600,
+        stack_output: bool = False,
     ) -> None:
         super().__init__()
         self.feature_size = feature_size
@@ -66,6 +66,8 @@ def __init__(
         self.mel_filters = self.get_mel_filters(
             sampling_rate, n_fft, n_mels=feature_size
         )
+        self.max_audio_len = max_audio_len
+        self.stack_output = stack_output

     def get_mel_filters(
         self, sr: int, n_fft: int, n_mels: int = 128, dtype: torch.dtype = torch.float32
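The two new constructor knobs make the chunked output layout and the exportable length range configurable up front. A short usage sketch (the feature size and sample rate are assumed to match the usual Whisper defaults of 80 mel bins at 16 kHz):

```python
# 600 s ceiling for the dynamic waveform dim; stacked Voxtral-style output.
processor = WhisperAudioProcessor(max_audio_len=600, stack_output=True)

# 45 s of 16 kHz audio is padded out and split into 30 s chunks,
# producing one mel spectrogram per chunk along dim 0.
features = processor(torch.randn(16_000 * 45))
```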
@@ -137,6 +139,7 @@ def forward(self, waveform: torch.Tensor) -> torch.Tensor:
             mode="constant",
             value=self.padding_value,
         )
+
         # Ideally we should do:
         # window = torch.hann_window(self.n_fft)
         # but this is not currently supported when lowering.
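As the comment notes, constructing the window inside `forward()` is what blocks lowering. One workaround consistent with that comment is to bake the window in as a constant at init time, so no window-construction op is traced into the exported graph. A sketch of the general pattern, not necessarily what this file does (the complex STFT output may itself need decomposition before lowering):

```python
# In __init__: precompute the Hann window once and store it as a
# non-persistent buffer, so export sees a plain constant tensor.
self.register_buffer(
    "window", torch.hann_window(self.n_fft, periodic=True), persistent=False
)

# In forward: pass the stored buffer instead of calling hann_window.
stft = torch.stft(
    waveform,
    self.n_fft,
    hop_length=self.hop_length,
    window=self.window,
    return_complex=True,
)
```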
@@ -166,18 +169,27 @@ def forward(self, waveform: torch.Tensor) -> torch.Tensor:
         log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)
         log_spec = (log_spec + 4.0) / 4.0

-        return log_spec.unsqueeze(0)
+        if self.stack_output:
+            log_spec = log_spec.reshape(self.feature_size, -1, self.nb_max_frames)
+            log_spec = log_spec.transpose(0, 1)
+            return log_spec
+        else:
+            return log_spec.unsqueeze(0)


 def export_processor(model=None, output_file="whisper_preprocess.pte"):
     if model is None:
         model = WhisperAudioProcessor()
-    audio_tensor = torch.randn(480000)
-    chunk_tensor = audio_tensor[:93680]
-    with torch.no_grad():
-        dim = Dim("waveform", min=1600, max=audio_tensor.size(0) * 10)  # 10 chunks max
-        ep: ExportedProgram = export(
-            model, (chunk_tensor,), dynamic_shapes={"waveform": {0: dim}}, strict=True
+
+    audio_tensor = torch.randn(93680)
+    shapes_collection = torch.export.ShapesCollection()
+    max_n_chunks = int(model.max_audio_len * model.n_samples)
+    shapes_collection[audio_tensor] = {0: Dim.DYNAMIC(max=max_n_chunks)}
+    with torch.no_grad(), torch.fx.experimental._config.patch(
+        backed_size_oblivious=True
+    ):
+        ep = torch.export.export(
+            model, (audio_tensor,), dynamic_shapes=shapes_collection, strict=True
         )
     logging.debug(ep)
@@ -236,6 +248,17 @@ def main():
         default="whisper_preprocess.pte",
         help="Output file path for the exported model",
     )
+    parser.add_argument(
+        "--max_audio_len",
+        type=int,
+        default=600,
+        help="Maximum audio length that can be processed, in seconds.",
+    )
+    parser.add_argument(
+        "--stack_output",
+        action="store_true",
+        help="Stack outputs along the batch dimension, one per chunk. Used by models such as Voxtral; see https://github.com/huggingface/transformers/blob/main/src/transformers/models/voxtral/processing_voxtral.py#L94 for more information.",
+    )

     args = parser.parse_args()
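Together these flags make the Voxtral-style layout opt-in from the command line: passing `--stack_output --max_audio_len 300`, for example, would cap the exportable waveform at five minutes and emit one stacked mel spectrogram per 30 s chunk.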
@@ -245,6 +268,8 @@ def main():
         hop_length=args.hop_length,
         chunk_length=args.chunk_length,
         n_fft=args.n_fft,
+        max_audio_len=args.max_audio_len,
+        stack_output=args.stack_output,
     )

     export_processor(model, args.output_file)
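Once the `.pte` is written, it can be smoke-tested from Python. A hedged sketch using the ExecuTorch pybind runtime (the `executorch.runtime` API is assumed to be available in the installed build):

```python
from executorch.runtime import Runtime

# Load the exported program and run the preprocessor on 30 s of audio.
runtime = Runtime.get()
program = runtime.load_program("whisper_preprocess.pte")
method = program.load_method("forward")
(features,) = method.execute([torch.randn(16_000 * 30)])
logging.debug("features: %s", tuple(features.shape))
```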