Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions mlx_audio/tts/models/kokoro/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from .kokoro import Model
from .kokoro import Model, ModelConfig
from .pipeline import KokoroPipeline

__all__ = ["KokoroPipeline", "Model"]
__all__ = ["KokoroPipeline", "Model", "ModelConfig"]
66 changes: 37 additions & 29 deletions mlx_audio/tts/models/kokoro/kokoro.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,19 +7,11 @@

import mlx.core as mx
import mlx.nn as nn
import numpy as np
from loguru import logger

from ...utils import get_class_predicate
from ..base import GenerationResult, check_array_shape
from ..base import BaseModelArgs, GenerationResult, check_array_shape
from .istftnet import Decoder
from .modules import (
AdaLayerNorm,
AlbertModelArgs,
CustomAlbert,
ProsodyPredictor,
TextEncoder,
)
from .modules import AlbertModelArgs, CustomAlbert, ProsodyPredictor, TextEncoder
from .pipeline import KokoroPipeline

# Force reset logger configuration at the top of your file
Expand Down Expand Up @@ -52,6 +44,24 @@ def sanitize_lstm_weights(key: str, state_dict: mx.array) -> dict:
return {key: state_dict}


@dataclass
class ModelConfig(BaseModelArgs):
    """Typed configuration for the Kokoro TTS model.

    Field names mirror the keys of the model's JSON config; instances are
    built via ``ModelConfig.from_dict(config)`` (inherited from
    ``BaseModelArgs`` — presumably it filters unknown keys; confirm there).
    ``Model.__init__`` reads these attributes directly.
    """

    # Keyword arguments expanded into the istftnet ``Decoder`` constructor.
    istftnet: dict
    # NOTE(review): dim_in, max_conv_dim, and multispeaker are not read in the
    # visible portion of Model.__init__ — confirm their consumers elsewhere.
    dim_in: int
    # Dropout rate passed to ProsodyPredictor.
    dropout: float
    # Shared hidden size: bert_encoder output, predictor d_hid,
    # text_encoder channels, and decoder dim_in.
    hidden_dim: int
    max_conv_dim: int
    # Maximum duration value for ProsodyPredictor (max_dur).
    max_dur: int
    multispeaker: bool
    # Layer count: ProsodyPredictor nlayers and TextEncoder depth.
    n_layer: int
    # Decoder output dimension (dim_out), i.e. number of mel bins.
    n_mels: int
    # Token vocabulary size: ALBERT vocab_size and TextEncoder n_symbols.
    n_token: int
    # Style-embedding size for ProsodyPredictor and Decoder.
    style_dim: int
    # Convolution kernel size for TextEncoder.
    text_encoder_kernel_size: int
    # Keyword arguments expanded into AlbertModelArgs (PL-BERT config).
    plbert: dict
    # Token string -> integer id mapping (stored as Model.vocab).
    vocab: Dict[str, int]


class Model(nn.Module):
"""
KokoroModel is a torch.nn.Module with 2 main responsibilities:
Expand All @@ -69,37 +79,35 @@ class Model(nn.Module):

REPO_ID = "prince-canuma/Kokoro-82M"

def __init__(self, config: dict, repo_id: str = None):
def __init__(self, config: ModelConfig, repo_id: str = None):
super().__init__()
self.repo_id = repo_id
self.config = config
self.vocab = config["vocab"]
self.vocab = config.vocab
self.bert = CustomAlbert(
AlbertModelArgs(vocab_size=config["n_token"], **config["plbert"])
AlbertModelArgs(vocab_size=config.n_token, **config.plbert)
)

self.bert_encoder = nn.Linear(
self.bert.config.hidden_size, config["hidden_dim"]
)
self.bert_encoder = nn.Linear(self.bert.config.hidden_size, config.hidden_dim)
self.context_length = self.bert.config.max_position_embeddings
self.predictor = ProsodyPredictor(
style_dim=config["style_dim"],
d_hid=config["hidden_dim"],
nlayers=config["n_layer"],
max_dur=config["max_dur"],
dropout=config["dropout"],
style_dim=config.style_dim,
d_hid=config.hidden_dim,
nlayers=config.n_layer,
max_dur=config.max_dur,
dropout=config.dropout,
)
self.text_encoder = TextEncoder(
channels=config["hidden_dim"],
kernel_size=config["text_encoder_kernel_size"],
depth=config["n_layer"],
n_symbols=config["n_token"],
channels=config.hidden_dim,
kernel_size=config.text_encoder_kernel_size,
depth=config.n_layer,
n_symbols=config.n_token,
)
self.decoder = Decoder(
dim_in=config["hidden_dim"],
style_dim=config["style_dim"],
dim_out=config["n_mels"],
**config["istftnet"],
dim_in=config.hidden_dim,
style_dim=config.style_dim,
dim_out=config.n_mels,
**config.istftnet,
)

@dataclass
Expand Down
4 changes: 2 additions & 2 deletions mlx_audio/tts/tests/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ class TestKokoroModel(unittest.TestCase):
def test_init(self, mock_load_weights, mock_mx_load, mock_open, mock_json_load):
"""Test KokoroModel initialization."""
# Import inside the test method
from mlx_audio.tts.models.kokoro.kokoro import Model
from mlx_audio.tts.models.kokoro.kokoro import Model, ModelConfig

# Mock the config loading
config = {
Expand Down Expand Up @@ -129,7 +129,7 @@ def test_init(self, mock_load_weights, mock_mx_load, mock_open, mock_json_load):
mock_load_weights.return_value = None

# Initialize the model with the config parameter
model = Model(config)
model = Model(ModelConfig.from_dict(config))

# Check that the model was initialized correctly
self.assertIsInstance(model, nn.Module)
Expand Down