|
| 1 | +from __future__ import annotations |
| 2 | + |
| 3 | +import math |
| 4 | +from functools import lru_cache |
| 5 | +from typing import Optional |
| 6 | + |
| 7 | +import mlx.core as mx |
| 8 | + |
| 9 | + |
@lru_cache(maxsize=None)
def mel_filters(
    sample_rate: int,
    n_fft: int,
    n_mels: int,
    f_min: float = 0,
    f_max: Optional[float] = None,
    norm: Optional[str] = None,
    mel_scale: str = "htk",
) -> mx.array:
    """Build a bank of overlapping triangular mel filters.

    Results are memoized on the argument tuple, so the same array object is
    handed back for repeated calls with equal arguments; treat it as
    read-only.

    Args:
        sample_rate: Sampling rate in Hz. FFT bin frequencies are spaced
            linearly over ``0 .. sample_rate // 2``.
        n_fft: FFT size; the bank covers ``n_fft // 2 + 1`` frequency bins.
        n_mels: Number of mel bands (triangular filters).
        f_min: Lowest filter edge in Hz.
        f_max: Highest filter edge in Hz. ``None`` (or any falsy value)
            selects ``sample_rate / 2``.
        norm: ``"slaney"`` scales every filter by ``2 / (upper_edge -
            lower_edge)``; any other value leaves the filters unscaled.
        mel_scale: ``"htk"`` selects the HTK formula; any other value selects
            the Slaney scale.

    Returns:
        Array of shape ``(n_mels, n_fft // 2 + 1)``.
    """

    def hz_to_mel(freq, mel_scale="htk"):
        # Scalar Hz -> mel conversion (only applied to f_min / f_max).
        if mel_scale == "htk":
            return 2595.0 * math.log10(1.0 + freq / 700.0)

        # slaney scale: linear below 1 kHz, logarithmic above
        f_min, f_sp = 0.0, 200.0 / 3
        mels = (freq - f_min) / f_sp
        min_log_hz = 1000.0
        min_log_mel = (min_log_hz - f_min) / f_sp
        logstep = math.log(6.4) / 27.0
        if freq >= min_log_hz:
            mels = min_log_mel + math.log(freq / min_log_hz) / logstep
        return mels

    def mel_to_hz(mels, mel_scale="htk"):
        # Array mel -> Hz conversion (applied to the mel-spaced edge points).
        if mel_scale == "htk":
            return 700.0 * (10.0 ** (mels / 2595.0) - 1.0)

        # slaney scale: linear below 1 kHz, logarithmic above
        f_min, f_sp = 0.0, 200.0 / 3
        freqs = f_min + f_sp * mels
        min_log_hz = 1000.0
        min_log_mel = (min_log_hz - f_min) / f_sp
        logstep = math.log(6.4) / 27.0
        # Select the log-region values with mx.where instead of boolean-mask
        # assignment, which MLX's indexing documentation lists as unsupported.
        return mx.where(
            mels >= min_log_mel,
            min_log_hz * mx.exp(logstep * (mels - min_log_mel)),
            freqs,
        )

    f_max = f_max or sample_rate / 2

    # generate frequency points

    n_freqs = n_fft // 2 + 1
    all_freqs = mx.linspace(0, sample_rate // 2, n_freqs)

    # convert frequencies to mel and back to hz

    m_min = hz_to_mel(f_min, mel_scale)
    m_max = hz_to_mel(f_max, mel_scale)
    m_pts = mx.linspace(m_min, m_max, n_mels + 2)
    f_pts = mel_to_hz(m_pts, mel_scale)

    # compute slopes for filterbank

    f_diff = f_pts[1:] - f_pts[:-1]
    # slopes[i, j] = f_pts[j] - all_freqs[i], shape (n_freqs, n_mels + 2)
    slopes = mx.expand_dims(f_pts, 0) - mx.expand_dims(all_freqs, 1)

    # calculate overlapping triangular filters

    down_slopes = (-slopes[:, :-2]) / f_diff[:-1]
    up_slopes = slopes[:, 2:] / f_diff[1:]
    filterbank = mx.maximum(
        mx.zeros_like(down_slopes), mx.minimum(down_slopes, up_slopes)
    )

    if norm == "slaney":
        enorm = 2.0 / (f_pts[2 : n_mels + 2] - f_pts[:n_mels])
        filterbank *= mx.expand_dims(enorm, 0)

    # (n_freqs, n_mels) -> (n_mels, n_freqs)
    filterbank = filterbank.moveaxis(0, 1)
    return filterbank
| 81 | + |
| 82 | + |
@lru_cache(maxsize=None)
def hanning(size):
    """Return a symmetric Hann window of ``size`` samples as an ``mx.array``.

    Sample ``n`` is ``0.5 * (1 - cos(2 * pi * n / (size - 1)))``, so the first
    and last samples are zero. Results are memoized per ``size``; the returned
    array is shared between callers and should be treated as read-only.

    Args:
        size: Number of samples in the window.

    Returns:
        The window values. A ``size`` of 1 yields ``[1.0]``; a ``size`` below
        1 yields an empty array.
    """
    if size == 1:
        # The general formula divides by ``size - 1``; a one-sample window is
        # defined as [1.0] (the same convention numpy.hanning uses).
        return mx.array([1.0])
    return mx.array(
        [0.5 * (1 - math.cos(2 * math.pi * n / (size - 1))) for n in range(size)]
    )
| 88 | + |
| 89 | + |
def stft(x, window, nperseg=256, noverlap=None, nfft=None, pad_mode="constant"):
    """Compute a short-time Fourier transform of a signal.

    The signal is padded by ``nperseg // 2`` samples on both sides, cut into
    frames through a strided view, multiplied by ``window`` and passed through
    a real FFT.

    Args:
        x: Input signal (presumably 1-D: a single pad pair is supplied and the
            framing strides assume a flat buffer -- confirm against callers).
        window: Window multiplied into every frame; it must broadcast against
            rows of length ``nfft``.
        nperseg: Segment length used for the padding amount and frame count.
        noverlap: Despite the name, this is used as the stride between the
            starts of consecutive frames (the hop). Defaults to ``nfft // 4``.
        nfft: Width of each frame row. Defaults to ``nperseg``.
        pad_mode: ``"constant"`` (padding via ``mx.pad``) or ``"reflect"``
            (mirrored samples, edge sample not repeated).

    Returns:
        The real FFT of the windowed frames, one row per frame.

    Raises:
        ValueError: If ``pad_mode`` is neither ``"constant"`` nor ``"reflect"``.
    """
    nfft = nperseg if nfft is None else nfft
    noverlap = nfft // 4 if noverlap is None else noverlap

    half = nperseg // 2
    if pad_mode == "constant":
        padded = mx.pad(x, [(half, half)])
    elif pad_mode == "reflect":
        # Mirror around the first / last sample without duplicating it.
        head = x[1 : half + 1][::-1]
        tail = x[-(half + 1) : -1][::-1]
        padded = mx.concatenate([head, x, tail])
    else:
        raise ValueError(f"Invalid pad_mode {pad_mode}")

    # NOTE(review): the frame count is derived from ``nperseg`` while each row
    # is ``nfft`` wide; with ``nfft > nperseg`` the strided view can extend
    # past the padded buffer. Callers appear to rely on nfft == nperseg --
    # confirm before using other combinations.
    n_frames = (padded.size - nperseg + noverlap) // noverlap
    frames = mx.as_strided(padded, shape=[n_frames, nfft], strides=[noverlap, 1])
    return mx.fft.rfft(frames * window)
| 114 | + |
| 115 | + |
def istft(x, window, nperseg=256, noverlap=None, nfft=None):
    """Reconstruct a time-domain signal from STFT frames by overlap-add.

    Each row of ``x`` is inverse-transformed, multiplied by ``window`` and
    accumulated into the output at an offset of ``frame_index * noverlap``.
    The accumulated signal is then divided by the accumulated window wherever
    that sum is non-zero.

    Args:
        x: Spectral frames, indexed by frame along the first axis.
        window: Synthesis window applied to every inverse-transformed frame.
        nperseg: Number of output samples each frame contributes.
        noverlap: Despite the name, this is used as the offset between the
            starts of consecutive frames (the hop). Defaults to ``nfft // 4``.
        nfft: Only used to derive the default ``noverlap``. Defaults to
            ``nperseg``.

    Returns:
        A 1-D array of ``(num_frames - 1) * noverlap + nperseg`` samples.
    """
    if nfft is None:
        nfft = nperseg
    if noverlap is None:
        noverlap = nfft // 4

    n_frames = x.shape[0]
    total = (n_frames - 1) * noverlap + nperseg
    signal = mx.zeros(total)
    envelope = mx.zeros(total)

    for idx in range(n_frames):
        lo = idx * noverlap
        hi = lo + nperseg
        # Overlap-add the windowed inverse FFT of this frame and track how
        # much window weight has landed on every output sample.
        signal[lo:hi] += mx.fft.irfft(x[idx]) * window
        envelope[lo:hi] += window

    # NOTE(review): the normaliser is the summed window, not the summed
    # squared window -- confirm this matches how the spectra were produced.
    # Samples with zero accumulated weight are passed through unscaled.
    return mx.where(envelope != 0, signal / envelope, signal)
| 142 | + |
| 143 | + |
def log_mel_spectrogram(
    audio: mx.array,
    sample_rate: int = 24_000,
    n_mels: int = 100,
    n_fft: int = 1024,
    hop_length: int = 256,
    padding: int = 0,
):
    """Compute a natural-log mel spectrogram of an audio signal.

    Args:
        audio: Audio samples; anything that is not already an ``mx.array`` is
            converted with ``mx.array``.
        sample_rate: Sampling rate in Hz, forwarded to ``mel_filters``.
        n_mels: Number of mel bands.
        n_fft: FFT size, also used as the Hann window length.
        hop_length: Offset between consecutive STFT frames.
        padding: Number of zero samples appended to the end of ``audio`` when
            greater than zero.

    Returns:
        Array of shape ``(1, frames, n_mels)`` where ``frames`` is one less
        than the number of frames produced by ``stft``.
    """
    waveform = audio if isinstance(audio, mx.array) else mx.array(audio)
    if padding > 0:
        waveform = mx.pad(waveform, (0, padding))

    spectrum = stft(waveform, hanning(n_fft), nperseg=n_fft, noverlap=hop_length)
    # The final STFT frame is discarded before taking magnitudes.
    magnitudes = mx.abs(spectrum[:-1, :])

    mel_basis = mel_filters(
        sample_rate=sample_rate,
        n_fft=n_fft,
        n_mels=n_mels,
        norm=None,
        mel_scale="htk",
    )
    mel_energy = magnitudes @ mel_basis.T

    # Floor at 1e-5 so the logarithm stays finite for silent bins.
    log_mel = mx.log(mx.maximum(mel_energy, 1e-5))
    return mx.expand_dims(log_mel, axis=0)
0 commit comments