import torch.nn.functional as F

from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
-
from executorch.exir import (
    EdgeCompileConfig,
    EdgeProgramManager,
    to_edge_transform_and_lower,
)
-
-from torch.export import Dim, export, ExportedProgram
+from torch.export import Dim


class WhisperAudioProcessor(nn.Module):
@@ -51,6 +49,8 @@ def __init__(
        chunk_length: int = 30,
        n_fft: int = 400,
        padding_value: float = 0.0,
+        max_audio_len: int = 600,
+        stack_output: bool = False,
    ) -> None:
        super().__init__()
        self.feature_size = feature_size
@@ -66,6 +66,8 @@ def __init__(
        self.mel_filters = self.get_mel_filters(
            sampling_rate, n_fft, n_mels=feature_size
        )
+        self.max_audio_len = max_audio_len
+        self.stack_output = stack_output

    def get_mel_filters(
        self, sr: int, n_fft: int, n_mels: int = 128, dtype: torch.dtype = torch.float32
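
Note on `get_mel_filters`: Whisper's filter bank is conventionally the same Slaney-scale mel basis that librosa produces with its defaults. A hedged parity-check sketch, assuming librosa is installed and that the buffer is stored in the `(n_mels, n_fft // 2 + 1)` layout (both assumptions, not guaranteed by this diff):

```python
import librosa
import torch

# Hypothetical parity check against librosa's Slaney-scale mel basis.
# Layout assumption: processor.mel_filters has shape (n_mels, n_fft // 2 + 1).
processor = WhisperAudioProcessor()
reference = torch.from_numpy(librosa.filters.mel(sr=16000, n_fft=400, n_mels=80))
print(torch.allclose(processor.mel_filters.float(), reference.float(), atol=1e-5))
```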
@@ -137,6 +139,7 @@ def forward(self, waveform: torch.Tensor) -> torch.Tensor:
            mode="constant",
            value=self.padding_value,
        )
+
        # Ideally we should do:
        # window = torch.hann_window(self.n_fft)
        # but this is not currently supported when lowering.
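
The comment above hints at the usual workaround: compute the window once at construction time and register it as a buffer, so export captures a constant tensor instead of an op that fails to lower. A minimal sketch under that assumption (the module and buffer names are illustrative, not part of this diff):

```python
import torch
import torch.nn as nn

class WindowedFrames(nn.Module):  # hypothetical illustration
    def __init__(self, n_fft: int = 400) -> None:
        super().__init__()
        # Precompute the Hann window once so that export sees a constant
        # buffer rather than a torch.hann_window call in the traced graph.
        self.register_buffer("window", torch.hann_window(n_fft), persistent=False)

    def forward(self, frames: torch.Tensor) -> torch.Tensor:
        # Apply the cached window to frames of shape (..., n_fft).
        return frames * self.window
```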
@@ -166,18 +169,27 @@ def forward(self, waveform: torch.Tensor) -> torch.Tensor:
        log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)
        log_spec = (log_spec + 4.0) / 4.0

-        return log_spec.unsqueeze(0)
+        if self.stack_output:
+            log_spec = log_spec.reshape(self.feature_size, -1, self.nb_max_frames)
+            log_spec = log_spec.transpose(0, 1)
+            return log_spec
+        else:
+            return log_spec.unsqueeze(0)


def export_processor(model=None, output_file="whisper_preprocess.pte"):
    if model is None:
        model = WhisperAudioProcessor()
-    audio_tensor = torch.randn(480000)
-    chunk_tensor = audio_tensor[:93680]
-    with torch.no_grad():
-        dim = Dim("waveform", min=1600, max=audio_tensor.size(0) * 10)  # 10 chunks max
-        ep: ExportedProgram = export(
-            model, (chunk_tensor,), dynamic_shapes={"waveform": {0: dim}}, strict=True
+
+    audio_tensor = torch.randn(93680)
+    shapes_collection = torch.export.ShapesCollection()
+    max_n_samples = int(model.max_audio_len * model.sampling_rate)  # max_audio_len is in seconds
+    shapes_collection[audio_tensor] = {0: Dim.DYNAMIC(max=max_n_samples)}
+    with torch.no_grad(), torch.fx.experimental._config.patch(
+        backed_size_oblivious=True
+    ):
+        ep = torch.export.export(
+            model, (audio_tensor,), dynamic_shapes=shapes_collection, strict=True
        )
    logging.debug(ep)

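To make the two output layouts concrete, a hedged shape check; the numbers assume the usual Whisper defaults (`feature_size=80`, 16 kHz audio, 30 s chunks, `hop_length=160`, hence `nb_max_frames == 3000`), none of which are spelled out in this hunk:

```python
import torch

# Assumed defaults: n_samples = 30 s * 16 kHz = 480000 per chunk,
# nb_max_frames = 480000 / 160 = 3000 mel frames per chunk.
flat = WhisperAudioProcessor()
stacked = WhisperAudioProcessor(stack_output=True)

waveform = torch.randn(2 * 480000)  # exactly two 30 s chunks

with torch.no_grad():
    print(flat(waveform).shape)     # expected: torch.Size([1, 80, 6000])
    print(stacked(waveform).shape)  # expected: torch.Size([2, 80, 3000])
```

Since the waveform dimension is exported as dynamic, calling `ep.module()(torch.randn(16000))` and `ep.module()(torch.randn(480000))` after export is a cheap way to confirm the `Dim.DYNAMIC` bound took effect.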
@@ -236,6 +248,17 @@ def main():
        default="whisper_preprocess.pte",
        help="Output file path for the exported model",
    )
+    parser.add_argument(
+        "--max_audio_len",
+        type=int,
+        default=600,
+        help="Max audio length that can be processed, in seconds.",
+    )
+    parser.add_argument(
+        "--stack_output",
+        action="store_true",
+        help="Whether to stack the output along the batch dimension, one entry per chunk. Used by models such as Voxtral; see https://github.com/huggingface/transformers/blob/main/src/transformers/models/voxtral/processing_voxtral.py#L94 for more information.",
+    )

    args = parser.parse_args()
@@ -245,6 +268,8 @@ def main():
        hop_length=args.hop_length,
        chunk_length=args.chunk_length,
        n_fft=args.n_fft,
+        max_audio_len=args.max_audio_len,
+        stack_output=args.stack_output,
    )

    export_processor(model, args.output_file)
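
Once the `.pte` is written, a hedged smoke test with ExecuTorch's Python runtime bindings (assumes the pybindings are built and that the program exposes the default "forward" method):

```python
import torch
from executorch.runtime import Runtime

# Load the lowered program and run the preprocessor on 30 s of silence.
runtime = Runtime.get()
program = runtime.load_program("whisper_preprocess.pte")
method = program.load_method("forward")
(features,) = method.execute([torch.zeros(480000)])
print(features.shape)  # expected (1, 80, 3000) for one chunk with default settings
```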