|
| 1 | +# Copyright (c) Meta Platforms, Inc. and affiliates. |
| 2 | +# All rights reserved. |
| 3 | +# |
| 4 | +# This source code is licensed under the BSD-style license found in the |
| 5 | +# LICENSE file in the root directory of this source tree. |
| 6 | + |
| 7 | +import logging |
| 8 | + |
| 9 | +import torch |
| 10 | +import torch.nn as nn |
| 11 | +import torch.nn.functional as F |
| 12 | + |
| 13 | +from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner |
| 14 | + |
| 15 | +from executorch.exir import ( |
| 16 | + EdgeCompileConfig, |
| 17 | + EdgeProgramManager, |
| 18 | + to_edge_transform_and_lower, |
| 19 | +) |
| 20 | + |
| 21 | +from torch.export import Dim, export, ExportedProgram |
| 22 | + |
| 23 | + |
class WhisperAudioProcessor(nn.Module):
    """
    Computes log-Mel spectrograms from mono audio input.
    Same as HuggingFace WhisperFeatureExtractor, but implemented in PyTorch

    Args:
        feature_size: Number of Mel bands in the output.
        sampling_rate: Sampling rate of the input waveform, in Hz.
        hop_length: Hop (stride) between successive STFT frames, in samples.
        chunk_length: Length of the padded analysis chunk, in seconds.
        n_fft: FFT size; also the length of the analysis window.
        padding_value: Value used to pad waveforms shorter than one chunk.
    """

    def __init__(
        self,
        feature_size: int = 80,
        sampling_rate: int = 16000,
        hop_length: int = 160,
        chunk_length: int = 30,
        n_fft: int = 400,
        padding_value: float = 0.0,
    ) -> None:
        super().__init__()
        self.feature_size = feature_size
        self.sampling_rate = sampling_rate
        self.padding_value = padding_value

        self.n_fft = n_fft
        self.hop_length = hop_length
        self.chunk_length = chunk_length
        # Number of samples in one full chunk, and the frame count derived from it.
        self.n_samples = chunk_length * sampling_rate
        self.nb_max_frames = self.n_samples // hop_length
        # The filter bank is computed once, at construction time.
        self.mel_filters = self.get_mel_filters(
            sampling_rate, n_fft, n_mels=feature_size
        )

    def get_mel_filters(
        self,
        sr: int,
        n_fft: int,
        n_mels: int = 128,
        dtype: torch.dtype = torch.float32,
    ) -> torch.Tensor:
        """Build a Slaney-style triangular Mel filter bank.

        The bands span 0 Hz to the Nyquist frequency ``sr / 2``.

        Args:
            sr: Sampling rate in Hz.
            n_fft: FFT size; the bank has ``1 + n_fft // 2`` frequency bins.
            n_mels: Number of Mel bands.
            dtype: dtype of the returned tensor.

        Returns:
            Tensor of shape ``(n_mels, 1 + n_fft // 2)``.
        """
        import math

        n_mels = int(n_mels)

        # Center freqs of each FFT bin
        fftfreqs = torch.fft.rfftfreq(n=n_fft, d=1.0 / sr, dtype=dtype)

        # Slaney mel scale: linear below 1 kHz, logarithmic above.
        f_min = 0.0
        f_sp = 200.0 / 3
        min_log_hz = 1000.0  # beginning of log region (Hz)
        min_log_mel = (min_log_hz - f_min) / f_sp  # same (Mels)

        # Upper band edge in mels, derived from the Nyquist frequency so the
        # bank matches the requested sampling rate (for sr=16000 this is the
        # 8000 Hz value, ~45.2456 mel).
        f_max = sr / 2.0
        if f_max >= min_log_hz:
            max_mel = min_log_mel + math.log(f_max / min_log_hz) / (
                math.log(6.4) / 27.0
            )
        else:
            max_mel = (f_max - f_min) / f_sp
        min_mel = 0.0

        # 'Center freqs' of mel bands - uniformly spaced between limits
        mels = torch.linspace(min_mel, max_mel, n_mels + 2, dtype=dtype)

        # Fill in the linear scale
        freqs = f_min + f_sp * mels

        # And now the nonlinear scale
        logstep = (
            torch.log(torch.tensor(6.4, dtype=dtype)) / 27.0
        )  # step size for log region
        log_t = mels >= min_log_mel
        freqs[log_t] = min_log_hz * torch.exp(logstep * (mels[log_t] - min_log_mel))

        mel_f = freqs

        fdiff = torch.diff(mel_f)
        # ramps[j, k] = mel_f[j] - fftfreqs[k]
        ramps = torch.subtract(mel_f.unsqueeze(1), fftfreqs.unsqueeze(0))

        # Lower and upper slopes for all bands and all bins at once: band i
        # rises from mel_f[i] to mel_f[i + 1] and falls to mel_f[i + 2].
        lower = -ramps[:-2] / fdiff[:-1].unsqueeze(1)
        upper = ramps[2:] / fdiff[1:].unsqueeze(1)

        # .. then intersect them with each other and zero
        weights = torch.maximum(
            torch.tensor(0.0, dtype=dtype), torch.minimum(lower, upper)
        )

        # Slaney-style mel is scaled to be approx constant energy per channel
        enorm = 2.0 / (mel_f[2 : n_mels + 2] - mel_f[:n_mels])
        weights = weights * enorm[:, None]

        return weights

    def forward(self, waveform: torch.Tensor) -> torch.Tensor:
        """Convert a waveform into a normalized log-Mel spectrogram.

        Args:
            waveform: 1-D tensor of audio samples (its length is read from
                ``waveform.shape[0]``).

        Returns:
            Tensor of shape ``(1, feature_size, frames)``.
        """
        # Pad to n_samples - 1 samples: with center=True the STFT then emits
        # 1 + (n_samples - 1) // hop_length frames (3000 with the defaults).
        # NOTE(review): for inputs of n_samples or more the pad amount is
        # negative, which F.pad treats as cropping - confirm this is intended.
        waveform = F.pad(
            waveform,
            (0, self.n_samples - waveform.shape[0] - 1),
            mode="constant",
            value=self.padding_value,
        )
        window = 0.5 * (
            1
            - torch.cos(
                2
                * torch.pi
                * torch.linspace(0, self.n_fft - 1, self.n_fft, dtype=torch.float32)
                / self.n_fft
            )
        )
        # Ideally we should do instead
        # window = torch.hann_window(self.n_fft)
        # but this is not currently supported when lowering
        # torch.hann_window has slightly better numerics (worst discrepancy is <1e-5 instead of 1e-4)
        stft = torch.stft(
            waveform,
            n_fft=self.n_fft,
            hop_length=self.hop_length,
            window=window,
            center=True,
            return_complex=True,
        )
        # Power spectrum
        magnitudes = torch.abs(stft) ** 2

        mel_spec = self.mel_filters @ magnitudes

        # Log compression; the clamp avoids log10(0).
        log_spec = torch.log10(torch.clamp(mel_spec, min=1e-10))
        # Limit the dynamic range to 8 (log10 units) below the global peak,
        # then shift and scale.
        log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)
        log_spec = (log_spec + 4.0) / 4.0

        return log_spec.unsqueeze(0)
| 142 | + |
| 143 | + |
def export_processor(output_file: str = "whisper_preprocess.pte") -> None:
    """Export WhisperAudioProcessor to an ExecuTorch program file.

    The module is exported with a dynamic waveform length, lowered to the
    edge dialect with the XNNPACK partitioner, converted to an ExecuTorch
    program and written to disk.

    Args:
        output_file: Path of the ``.pte`` file to write.
    """
    model = WhisperAudioProcessor()
    # Take the longest supported waveform from the model instead of repeating
    # the number, so the dynamic-shape bound stays in sync with the processor.
    max_samples = model.n_samples
    audio_tensor = torch.randn(max_samples)
    # Example input used for tracing; its length lies within the dynamic range.
    chunk_tensor = audio_tensor[:93680]
    with torch.no_grad():
        # export. What is the min of waveforms?
        dim = Dim("waveform", min=1600, max=max_samples)
        ep: ExportedProgram = export(
            model, (chunk_tensor,), dynamic_shapes={"waveform": {0: dim}}, strict=True
        )
        logging.debug(ep)

        # to edge
        edge: EdgeProgramManager = to_edge_transform_and_lower(
            ep,
            partitioner=[XnnpackPartitioner()],
            compile_config=EdgeCompileConfig(
                _check_ir_validity=False,
            ),
        )
        logging.debug(edge.exported_program())

        # to executorch
        exec_prog = edge.to_executorch()
        with open(output_file, "wb") as file:
            exec_prog.write_to_file(file)

        logging.debug("Done")
| 173 | + |
| 174 | + |
def main():
    """Script entry point: run the audio-processor export."""
    export_processor()


# Only run the export when this file is executed as a script.
if __name__ == "__main__":
    main()
0 commit comments