Skip to content

Commit 5f3cf93

Browse files
authored
Add Vocos neural audio codec (#48)
1 parent b378ea8 commit 5f3cf93

File tree

6 files changed

+598
-1
lines changed

6 files changed

+598
-1
lines changed

mlx_audio/codec/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
from .models import Mimi
1+
from .models import Encodec, Mimi, Vocos

mlx_audio/codec/models/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,3 @@
11
from .encodec import Encodec
22
from .mimi import Mimi
3+
from .vocos import Vocos
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
from .vocos import Vocos
Lines changed: 169 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,169 @@
1+
from __future__ import annotations
2+
3+
import math
4+
from functools import lru_cache
5+
from typing import Optional
6+
7+
import mlx.core as mx
8+
9+
10+
@lru_cache(maxsize=None)
11+
def mel_filters(
12+
sample_rate: int,
13+
n_fft: int,
14+
n_mels: int,
15+
f_min: float = 0,
16+
f_max: Optional[float] = None,
17+
norm: Optional[str] = None,
18+
mel_scale: str = "htk",
19+
) -> mx.array:
20+
def hz_to_mel(freq, mel_scale="htk"):
21+
if mel_scale == "htk":
22+
return 2595.0 * math.log10(1.0 + freq / 700.0)
23+
24+
# slaney scale
25+
f_min, f_sp = 0.0, 200.0 / 3
26+
mels = (freq - f_min) / f_sp
27+
min_log_hz = 1000.0
28+
min_log_mel = (min_log_hz - f_min) / f_sp
29+
logstep = math.log(6.4) / 27.0
30+
if freq >= min_log_hz:
31+
mels = min_log_mel + math.log(freq / min_log_hz) / logstep
32+
return mels
33+
34+
def mel_to_hz(mels, mel_scale="htk"):
35+
if mel_scale == "htk":
36+
return 700.0 * (10.0 ** (mels / 2595.0) - 1.0)
37+
38+
# slaney scale
39+
f_min, f_sp = 0.0, 200.0 / 3
40+
freqs = f_min + f_sp * mels
41+
min_log_hz = 1000.0
42+
min_log_mel = (min_log_hz - f_min) / f_sp
43+
logstep = math.log(6.4) / 27.0
44+
log_t = mels >= min_log_mel
45+
freqs[log_t] = min_log_hz * mx.exp(logstep * (mels[log_t] - min_log_mel))
46+
return freqs
47+
48+
f_max = f_max or sample_rate / 2
49+
50+
# generate frequency points
51+
52+
n_freqs = n_fft // 2 + 1
53+
all_freqs = mx.linspace(0, sample_rate // 2, n_freqs)
54+
55+
# convert frequencies to mel and back to hz
56+
57+
m_min = hz_to_mel(f_min, mel_scale)
58+
m_max = hz_to_mel(f_max, mel_scale)
59+
m_pts = mx.linspace(m_min, m_max, n_mels + 2)
60+
f_pts = mel_to_hz(m_pts, mel_scale)
61+
62+
# compute slopes for filterbank
63+
64+
f_diff = f_pts[1:] - f_pts[:-1]
65+
slopes = mx.expand_dims(f_pts, 0) - mx.expand_dims(all_freqs, 1)
66+
67+
# calculate overlapping triangular filters
68+
69+
down_slopes = (-slopes[:, :-2]) / f_diff[:-1]
70+
up_slopes = slopes[:, 2:] / f_diff[1:]
71+
filterbank = mx.maximum(
72+
mx.zeros_like(down_slopes), mx.minimum(down_slopes, up_slopes)
73+
)
74+
75+
if norm == "slaney":
76+
enorm = 2.0 / (f_pts[2 : n_mels + 2] - f_pts[:n_mels])
77+
filterbank *= mx.expand_dims(enorm, 0)
78+
79+
filterbank = filterbank.moveaxis(0, 1)
80+
return filterbank
81+
82+
83+
@lru_cache(maxsize=None)
84+
def hanning(size):
85+
return mx.array(
86+
[0.5 * (1 - math.cos(2 * math.pi * n / (size - 1))) for n in range(size)]
87+
)
88+
89+
90+
def stft(x, window, nperseg=256, noverlap=None, nfft=None, pad_mode="constant"):
91+
if nfft is None:
92+
nfft = nperseg
93+
if noverlap is None:
94+
noverlap = nfft // 4
95+
96+
def _pad(x, padding, pad_mode="constant"):
97+
if pad_mode == "constant":
98+
return mx.pad(x, [(padding, padding)])
99+
elif pad_mode == "reflect":
100+
prefix = x[1 : padding + 1][::-1]
101+
suffix = x[-(padding + 1) : -1][::-1]
102+
return mx.concatenate([prefix, x, suffix])
103+
else:
104+
raise ValueError(f"Invalid pad_mode {pad_mode}")
105+
106+
padding = nperseg // 2
107+
x = _pad(x, padding, pad_mode)
108+
109+
strides = [noverlap, 1]
110+
t = (x.size - nperseg + noverlap) // noverlap
111+
shape = [t, nfft]
112+
x = mx.as_strided(x, shape=shape, strides=strides)
113+
return mx.fft.rfft(x * window)
114+
115+
116+
def istft(x, window, nperseg=256, noverlap=None, nfft=None):
117+
if nfft is None:
118+
nfft = nperseg
119+
if noverlap is None:
120+
noverlap = nfft // 4
121+
122+
t = (x.shape[0] - 1) * noverlap + nperseg
123+
reconstructed = mx.zeros(t)
124+
window_sum = mx.zeros(t)
125+
126+
for i in range(x.shape[0]):
127+
# inverse FFT of each frame
128+
frame_time = mx.fft.irfft(x[i])
129+
130+
# get the position in the time-domain signal to add the frame
131+
start = i * noverlap
132+
end = start + nperseg
133+
134+
# overlap-add the inverse transformed frame, scaled by the window
135+
reconstructed[start:end] += frame_time * window
136+
window_sum[start:end] += window
137+
138+
# normalize by the sum of the window values
139+
reconstructed = mx.where(window_sum != 0, reconstructed / window_sum, reconstructed)
140+
141+
return reconstructed
142+
143+
144+
def log_mel_spectrogram(
145+
audio: mx.array,
146+
sample_rate: int = 24_000,
147+
n_mels: int = 100,
148+
n_fft: int = 1024,
149+
hop_length: int = 256,
150+
padding: int = 0,
151+
):
152+
if not isinstance(audio, mx.array):
153+
audio = mx.array(audio)
154+
155+
if padding > 0:
156+
audio = mx.pad(audio, (0, padding))
157+
158+
freqs = stft(audio, hanning(n_fft), nperseg=n_fft, noverlap=hop_length)
159+
magnitudes = freqs[:-1, :].abs()
160+
filters = mel_filters(
161+
sample_rate=sample_rate,
162+
n_fft=n_fft,
163+
n_mels=n_mels,
164+
norm=None,
165+
mel_scale="htk",
166+
)
167+
mel_spec = magnitudes @ filters.T
168+
log_spec = mx.maximum(mel_spec, 1e-5).log()
169+
return mx.expand_dims(log_spec, axis=0)

0 commit comments

Comments
 (0)