77
88import mlx .core as mx
99import mlx .nn as nn
10- import numpy as np
1110from loguru import logger
1211
13- from ...utils import get_class_predicate
14- from ..base import GenerationResult , check_array_shape
12+ from ..base import BaseModelArgs , GenerationResult , check_array_shape
1513from .istftnet import Decoder
16- from .modules import (
17- AdaLayerNorm ,
18- AlbertModelArgs ,
19- CustomAlbert ,
20- ProsodyPredictor ,
21- TextEncoder ,
22- )
14+ from .modules import AlbertModelArgs , CustomAlbert , ProsodyPredictor , TextEncoder
2315from .pipeline import KokoroPipeline
2416
# Force-reset the logger configuration at the top of this file.
@@ -52,6 +44,24 @@ def sanitize_lstm_weights(key: str, state_dict: mx.array) -> dict:
5244 return {key : state_dict }
5345
5446
@dataclass
class ModelConfig(BaseModelArgs):
    """Hyperparameters for the Kokoro model, parsed from its JSON config.

    Field order matters: it defines the positional order of the generated
    ``__init__``, so do not reorder existing fields.
    """

    istftnet: dict  # forwarded as **kwargs to the istftnet ``Decoder``
    dim_in: int
    dropout: float
    hidden_dim: int
    max_conv_dim: int
    max_dur: int
    multispeaker: bool
    n_layer: int
    n_mels: int
    n_token: int  # vocabulary size, also used as the Albert ``vocab_size``
    style_dim: int
    text_encoder_kernel_size: int
    plbert: dict  # forwarded as **kwargs to ``AlbertModelArgs``
    vocab: Dict[str, int]  # token string -> integer id
5565class Model (nn .Module ):
5666 """
5767 KokoroModel is a torch.nn.Module with 2 main responsibilities:
@@ -69,37 +79,35 @@ class Model(nn.Module):
6979
7080 REPO_ID = "prince-canuma/Kokoro-82M"
7181
72- def __init__ (self , config : dict , repo_id : str = None ):
82+ def __init__ (self , config : ModelConfig , repo_id : str = None ):
7383 super ().__init__ ()
7484 self .repo_id = repo_id
7585 self .config = config
76- self .vocab = config [ " vocab" ]
86+ self .vocab = config . vocab
7787 self .bert = CustomAlbert (
78- AlbertModelArgs (vocab_size = config [ " n_token" ] , ** config [ " plbert" ] )
88+ AlbertModelArgs (vocab_size = config . n_token , ** config . plbert )
7989 )
8090
81- self .bert_encoder = nn .Linear (
82- self .bert .config .hidden_size , config ["hidden_dim" ]
83- )
91+ self .bert_encoder = nn .Linear (self .bert .config .hidden_size , config .hidden_dim )
8492 self .context_length = self .bert .config .max_position_embeddings
8593 self .predictor = ProsodyPredictor (
86- style_dim = config [ " style_dim" ] ,
87- d_hid = config [ " hidden_dim" ] ,
88- nlayers = config [ " n_layer" ] ,
89- max_dur = config [ " max_dur" ] ,
90- dropout = config [ " dropout" ] ,
94+ style_dim = config . style_dim ,
95+ d_hid = config . hidden_dim ,
96+ nlayers = config . n_layer ,
97+ max_dur = config . max_dur ,
98+ dropout = config . dropout ,
9199 )
92100 self .text_encoder = TextEncoder (
93- channels = config [ " hidden_dim" ] ,
94- kernel_size = config [ " text_encoder_kernel_size" ] ,
95- depth = config [ " n_layer" ] ,
96- n_symbols = config [ " n_token" ] ,
101+ channels = config . hidden_dim ,
102+ kernel_size = config . text_encoder_kernel_size ,
103+ depth = config . n_layer ,
104+ n_symbols = config . n_token ,
97105 )
98106 self .decoder = Decoder (
99- dim_in = config [ " hidden_dim" ] ,
100- style_dim = config [ " style_dim" ] ,
101- dim_out = config [ " n_mels" ] ,
102- ** config [ " istftnet" ] ,
107+ dim_in = config . hidden_dim ,
108+ style_dim = config . style_dim ,
109+ dim_out = config . n_mels ,
110+ ** config . istftnet ,
103111 )
104112
105113 @dataclass
0 commit comments