Skip to content

Commit 9d03df4

Browse files
authored
Fix Kokoro audio generation (#52)
* Fix Kokoro audio generation.
* Formatting.
* Try to fix tests.
* More test fixes.
1 parent 5f3cf93 commit 9d03df4

File tree

3 files changed

+41
-33
lines changed

3 files changed

+41
-33
lines changed
Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from .kokoro import Model
1+
from .kokoro import Model, ModelConfig
22
from .pipeline import KokoroPipeline
33

4-
__all__ = ["KokoroPipeline", "Model"]
4+
__all__ = ["KokoroPipeline", "Model", "ModelConfig"]

mlx_audio/tts/models/kokoro/kokoro.py

Lines changed: 37 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -7,19 +7,11 @@
77

88
import mlx.core as mx
99
import mlx.nn as nn
10-
import numpy as np
1110
from loguru import logger
1211

13-
from ...utils import get_class_predicate
14-
from ..base import GenerationResult, check_array_shape
12+
from ..base import BaseModelArgs, GenerationResult, check_array_shape
1513
from .istftnet import Decoder
16-
from .modules import (
17-
AdaLayerNorm,
18-
AlbertModelArgs,
19-
CustomAlbert,
20-
ProsodyPredictor,
21-
TextEncoder,
22-
)
14+
from .modules import AlbertModelArgs, CustomAlbert, ProsodyPredictor, TextEncoder
2315
from .pipeline import KokoroPipeline
2416

2517
# Force reset logger configuration at the top of your file
@@ -52,6 +44,24 @@ def sanitize_lstm_weights(key: str, state_dict: mx.array) -> dict:
5244
return {key: state_dict}
5345

5446

47+
@dataclass
48+
class ModelConfig(BaseModelArgs):
49+
istftnet: dict
50+
dim_in: int
51+
dropout: float
52+
hidden_dim: int
53+
max_conv_dim: int
54+
max_dur: int
55+
multispeaker: bool
56+
n_layer: int
57+
n_mels: int
58+
n_token: int
59+
style_dim: int
60+
text_encoder_kernel_size: int
61+
plbert: dict
62+
vocab: Dict[str, int]
63+
64+
5565
class Model(nn.Module):
5666
"""
5767
KokoroModel is a torch.nn.Module with 2 main responsibilities:
@@ -69,37 +79,35 @@ class Model(nn.Module):
6979

7080
REPO_ID = "prince-canuma/Kokoro-82M"
7181

72-
def __init__(self, config: dict, repo_id: str = None):
82+
def __init__(self, config: ModelConfig, repo_id: str = None):
7383
super().__init__()
7484
self.repo_id = repo_id
7585
self.config = config
76-
self.vocab = config["vocab"]
86+
self.vocab = config.vocab
7787
self.bert = CustomAlbert(
78-
AlbertModelArgs(vocab_size=config["n_token"], **config["plbert"])
88+
AlbertModelArgs(vocab_size=config.n_token, **config.plbert)
7989
)
8090

81-
self.bert_encoder = nn.Linear(
82-
self.bert.config.hidden_size, config["hidden_dim"]
83-
)
91+
self.bert_encoder = nn.Linear(self.bert.config.hidden_size, config.hidden_dim)
8492
self.context_length = self.bert.config.max_position_embeddings
8593
self.predictor = ProsodyPredictor(
86-
style_dim=config["style_dim"],
87-
d_hid=config["hidden_dim"],
88-
nlayers=config["n_layer"],
89-
max_dur=config["max_dur"],
90-
dropout=config["dropout"],
94+
style_dim=config.style_dim,
95+
d_hid=config.hidden_dim,
96+
nlayers=config.n_layer,
97+
max_dur=config.max_dur,
98+
dropout=config.dropout,
9199
)
92100
self.text_encoder = TextEncoder(
93-
channels=config["hidden_dim"],
94-
kernel_size=config["text_encoder_kernel_size"],
95-
depth=config["n_layer"],
96-
n_symbols=config["n_token"],
101+
channels=config.hidden_dim,
102+
kernel_size=config.text_encoder_kernel_size,
103+
depth=config.n_layer,
104+
n_symbols=config.n_token,
97105
)
98106
self.decoder = Decoder(
99-
dim_in=config["hidden_dim"],
100-
style_dim=config["style_dim"],
101-
dim_out=config["n_mels"],
102-
**config["istftnet"],
107+
dim_in=config.hidden_dim,
108+
style_dim=config.style_dim,
109+
dim_out=config.n_mels,
110+
**config.istftnet,
103111
)
104112

105113
@dataclass

mlx_audio/tts/tests/test_models.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@ class TestKokoroModel(unittest.TestCase):
8686
def test_init(self, mock_load_weights, mock_mx_load, mock_open, mock_json_load):
8787
"""Test KokoroModel initialization."""
8888
# Import inside the test method
89-
from mlx_audio.tts.models.kokoro.kokoro import Model
89+
from mlx_audio.tts.models.kokoro.kokoro import Model, ModelConfig
9090

9191
# Mock the config loading
9292
config = {
@@ -129,7 +129,7 @@ def test_init(self, mock_load_weights, mock_mx_load, mock_open, mock_json_load):
129129
mock_load_weights.return_value = None
130130

131131
# Initialize the model with the config parameter
132-
model = Model(config)
132+
model = Model(ModelConfig.from_dict(config))
133133

134134
# Check that the model was initialized correctly
135135
self.assertIsInstance(model, nn.Module)

0 commit comments

Comments (0)