Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
129 changes: 116 additions & 13 deletions mlx_audio/tts/models/kokoro/istftnet.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,8 @@
import time
from typing import List, Optional, Tuple, Union

import librosa
import mlx.core as mx
import mlx.nn as nn
import numpy as np
from scipy.signal import get_window

from ..base import check_array_shape
from ..interpolate import interpolate
Expand Down Expand Up @@ -398,6 +395,119 @@ def __call__(self, x: mx.array, s: mx.array) -> mx.array:
return x


def mlx_stft(
    x,
    n_fft=800,
    hop_length=None,
    win_length=None,
    window="hann",
    center=True,
    pad_mode="reflect",
):
    """Compute a one-sided short-time Fourier transform with MLX.

    Args:
        x: Input signal; converted with ``mx.array``. The framing logic
            indexes only ``shape[0]``, so a 1-D signal is assumed.
        n_fft: Frame length fed to the real FFT.
        hop_length: Step between consecutive frames. Defaults to
            ``n_fft // 4``.
        win_length: Length used to build the Hann window when ``window`` is
            a string. Defaults to ``n_fft``.
        window: Either the string ``"hann"`` (case-insensitive) or an
            array-like window exposing ``.shape`` that is used as given.
        center: When true, pad ``n_fft // 2`` samples on both ends before
            framing.
        pad_mode: ``"constant"`` (``mx.pad`` defaults) or ``"reflect"``
            (mirror without repeating the edge sample). Only consulted when
            ``center`` is true.

    Returns:
        The ``mx.fft.rfft`` of every windowed frame, transposed so frames
        run along the second axis.

    Raises:
        ValueError: for an unsupported window name, an unknown ``pad_mode``
            (only when ``center`` is true), or an input too short to yield
            a single frame.
    """
    hop_length = n_fft // 4 if hop_length is None else hop_length
    win_length = n_fft if win_length is None else win_length

    # Resolve the analysis window; array-like windows pass through untouched.
    if not isinstance(window, str):
        taper = window
    elif window.lower() == "hann":
        # hanning(N + 1) minus its last sample is the periodic Hann window.
        taper = mx.array(np.hanning(win_length + 1)[:-1])
    else:
        raise ValueError(
            f"Only 'hann' (string) is supported for window, not {window!r}"
        )

    # A window shorter than the frame is extended with trailing zeros
    # (right-padded, not centred inside the frame).
    missing = n_fft - taper.shape[0]
    if missing > 0:
        taper = mx.concatenate([taper, mx.zeros((missing,))], axis=0)

    signal = mx.array(x)

    if center:
        half = n_fft // 2
        if pad_mode == "constant":
            signal = mx.pad(signal, [(half, half)])
        elif pad_mode == "reflect":
            # Mirror around the first / last sample, excluding that sample.
            head = signal[1 : half + 1][::-1]
            tail = signal[-(half + 1) : -1][::-1]
            signal = mx.concatenate([head, signal, tail])
        else:
            raise ValueError(f"Invalid pad_mode {pad_mode}")

    total = signal.shape[0]
    num_frames = 1 + (total - n_fft) // hop_length
    if num_frames <= 0:
        raise ValueError(
            f"Input is too short (length={total}) for n_fft={n_fft} with "
            f"hop_length={hop_length} and center={center}."
        )

    # Overlapping frames as a strided view: row i starts at i * hop_length.
    frames = mx.as_strided(
        signal,
        shape=(num_frames, n_fft),
        strides=(hop_length, 1),
    )
    return mx.fft.rfft(frames * taper).transpose(1, 0)


def mlx_istft(
    x,
    hop_length=None,
    win_length=None,
    window="hann",
    center=True,
    length=None,
):
    """Reconstruct a time-domain signal from a one-sided STFT by overlap-add.

    Args:
        x: Complex spectrogram with frequency bins on axis 0 and frames on
            axis 1 (it is transposed internally so each row is one frame).
            Must expose ``.shape`` and be convertible with ``mx.array``.
        hop_length: Step between consecutive frames. Defaults to
            ``win_length // 4``.
        win_length: Window / frame length. Defaults to
            ``2 * (x.shape[0] - 1)``, the frame size implied by the number
            of one-sided frequency bins.
        window: Either the string ``"hann"`` (case-insensitive) or an
            array-like window exposing ``.shape``.
        center: When true and ``length`` is ``None``, drop
            ``win_length // 2`` samples from both ends of the result.
        length: When given, the result is cut to its first ``length``
            samples.

    Returns:
        The reconstructed 1-D ``mx.array``.

    Raises:
        ValueError: if ``window`` is a string other than ``"hann"``.
    """
    # Resolve win_length first: the hop default is derived from it, so the
    # reverse order fails with a TypeError when both are left as None.
    if win_length is None:
        # x is still (bins, frames) here, so the frame size comes from the
        # bin count on axis 0, not from the number of frames.
        win_length = (x.shape[0] - 1) * 2
    if hop_length is None:
        hop_length = win_length // 4

    if isinstance(window, str):
        if window.lower() == "hann":
            # hanning(N + 1) minus its last sample is the periodic Hann window.
            w = mx.array(np.hanning(win_length + 1)[:-1])
        else:
            raise ValueError(
                f"Only 'hann' (string) is supported for window, not {window!r}"
            )
    else:
        w = window

    # A short user-supplied window is extended with trailing zeros.
    if w.shape[0] < win_length:
        w = mx.concatenate([w, mx.zeros((win_length - w.shape[0],))], axis=0)

    # (bins, frames) -> (frames, bins): one spectrum per row.
    spec = mx.array(x).transpose(1, 0)
    num_frames = spec.shape[0]
    t = (num_frames - 1) * hop_length + win_length
    reconstructed = mx.zeros(t)
    window_sum = mx.zeros(t)

    # Inverse-transform every frame in one call and apply the synthesis
    # window; both this and the squared window are loop-invariant.
    frames_time = mx.fft.irfft(spec) * w
    w_squared = w**2

    for i in range(num_frames):
        # Position of this frame in the output signal.
        start = i * hop_length
        end = start + win_length

        # Overlap-add the windowed frame and track the window energy.
        reconstructed[start:end] += frames_time[i]
        window_sum[start:end] += w_squared

    # Normalize by the accumulated squared window wherever it is non-zero.
    reconstructed = mx.where(window_sum != 0, reconstructed / window_sum, reconstructed)

    if center and length is None:
        # Explicit end index: `-win_length // 2` parses as `(-win_length) // 2`.
        half = win_length // 2
        reconstructed = reconstructed[half : t - half]

    # NOTE(review): when `length` is given the centre padding is NOT removed
    # first, so the output still begins with the leading half-window —
    # confirm callers expect this.
    if length is not None:
        reconstructed = reconstructed[:length]

    return reconstructed


class MLXSTFT:
def __init__(
self, filter_length=800, hop_length=200, win_length=800, window="hann"
Expand All @@ -406,14 +516,7 @@ def __init__(
self.hop_length = hop_length
self.win_length = win_length

# Get the window exactly as PyTorch does
self.window = get_window(window, win_length, fftbins=True).astype(np.float32)

# Pad window if needed
if win_length < filter_length:
pad_start = (filter_length - win_length) // 2
pad_end = filter_length - win_length - pad_start
self.window = np.pad(self.window, (pad_start, pad_end))
self.window = window

def transform(self, input_data):
# Convert to numpy and ensure 2D
Expand All @@ -426,7 +529,7 @@ def transform(self, input_data):

for batch_idx in range(audio_np.shape[0]):
# Compute STFT using librosa
stft = librosa.stft(
stft = mlx_stft(
audio_np[batch_idx],
n_fft=self.filter_length,
hop_length=self.hop_length,
Expand Down Expand Up @@ -464,7 +567,7 @@ def inverse(self, magnitude, phase):
stft = magnitude_np[batch_idx] * np.exp(1j * phase_cont)

# Inverse STFT using librosa
audio = librosa.istft(
audio = mlx_istft(
stft,
hop_length=self.hop_length,
win_length=self.win_length,
Expand Down
1 change: 0 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -12,5 +12,4 @@ torch>=2.5.1
transformers>=4.49.0
sentencepiece>=0.2.0
huggingface_hub>=0.27.0
librosa>=0.10.2.post1
soundfile>=0.13.1