Skip to content

Commit 6c513de

Browse files
authored
Merge pull request #591 from yoshphys/feature/irodori-tts
Add Irodori-TTS: Japanese TTS model port to MLX
2 parents 0cbb6d8 + 3c95188 commit 6c513de

File tree

9 files changed

+2023
-0
lines changed

9 files changed

+2023
-0
lines changed
Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
# Irodori TTS
2+
3+
Japanese text-to-speech model based on Echo TTS architecture, ported to MLX.
4+
Uses Rectified Flow diffusion with a DiT (Diffusion Transformer) and DACVAE codec (48kHz).
5+
6+
## Model
7+
8+
Original: [Aratako/Irodori-TTS-500M](https://huggingface.co/Aratako/Irodori-TTS-500M) (500M parameters)
9+
10+
## Usage
11+
12+
Python API:
13+
14+
```python
15+
from mlx_audio.tts import load
16+
17+
model = load("mlx-community/Irodori-TTS-500M-fp16")
18+
result = next(model.generate("こんにちは、音声合成のテストです。"))
19+
audio = result.audio
20+
```
21+
22+
With reference audio for voice cloning:
23+
24+
```python
25+
result = next(model.generate(
26+
"こんにちは、音声合成のテストです。",
27+
ref_audio="speaker.wav",
28+
))
29+
```
30+
31+
CLI:
32+
33+
```bash
34+
python -m mlx_audio.tts.generate \
35+
--model mlx-community/Irodori-TTS-500M-fp16 \
36+
--text "こんにちは、音声合成のテストです。"
37+
```
38+
39+
## Memory requirements
40+
41+
The default `sequence_length=750` requires approximately 24GB of unified memory.
42+
On 16GB machines, use reduced settings:
43+
44+
```python
45+
result = next(model.generate(
46+
"こんにちは。",
47+
sequence_length=300, # ~9GB
48+
cfg_guidance_mode="alternating", # ~1/3 of independent mode memory
49+
))
50+
```
51+
52+
Approximate memory usage with `cfg_guidance_mode="alternating"`:
53+
54+
| sequence_length | Memory | Audio length |
55+
|---|---|---|
56+
| 100 | ~2GB | ~4s |
57+
| 300 | ~2GB | ~12s |
58+
| 400 | ~3GB | ~16s |
59+
60+
With `cfg_guidance_mode="independent"` (default), multiply memory by ~3.
61+
62+
## Notes
63+
64+
- Input language: Japanese. Latin characters may not be pronounced correctly;
65+
convert them to katakana beforehand (e.g. "MLX" → "エムエルエックス").
66+
- The DACVAE codec weights (`facebook/dacvae-watermarked`) are automatically
67+
downloaded on first use.
68+
69+
## License
70+
71+
Irodori-TTS weights are released under the [MIT License](https://opensource.org/licenses/MIT).
72+
See [Aratako/Irodori-TTS-500M](https://huggingface.co/Aratako/Irodori-TTS-500M) for details.
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
from .irodori_tts import Model, ModelConfig
Lines changed: 113 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,113 @@
1+
from __future__ import annotations
2+
3+
from dataclasses import dataclass, field
4+
from typing import Optional
5+
6+
from mlx_audio.tts.models.base import BaseModelArgs
7+
8+
9+
@dataclass
class IrodoriDiTConfig(BaseModelArgs):
    """Architecture hyperparameters for the Irodori diffusion transformer (DiT)."""

    # Audio latent dimensions (DACVAE: 128-dim, 48kHz)
    latent_dim: int = 128
    latent_patch_size: int = 1

    # DiT backbone
    model_dim: int = 1280
    num_layers: int = 12
    num_heads: int = 20
    mlp_ratio: float = 2.875
    text_mlp_ratio: Optional[float] = 2.6
    speaker_mlp_ratio: Optional[float] = 2.6

    # Text encoder
    text_vocab_size: int = 99574
    text_tokenizer_repo: str = "llm-jp/llm-jp-3-150m"
    text_add_bos: bool = True
    text_dim: int = 512
    text_layers: int = 10
    text_heads: int = 8

    # Speaker (reference latent) encoder
    speaker_dim: int = 768
    speaker_layers: int = 8
    speaker_heads: int = 12
    speaker_patch_size: int = 1

    # Conditioning
    timestep_embed_dim: int = 512
    adaln_rank: int = 192
    norm_eps: float = 1e-5

    @property
    def patched_latent_dim(self) -> int:
        # Latent width after folding latent_patch_size frames into one token.
        return self.latent_patch_size * self.latent_dim

    @property
    def speaker_patched_latent_dim(self) -> int:
        # Speaker-encoder input width: patched latents folded once more
        # by speaker_patch_size.
        return self.speaker_patch_size * self.patched_latent_dim

    @property
    def text_mlp_ratio_resolved(self) -> float:
        # Use the text-specific ratio when configured; otherwise inherit
        # the backbone's mlp_ratio.
        if self.text_mlp_ratio is not None:
            return float(self.text_mlp_ratio)
        return self.mlp_ratio

    @property
    def speaker_mlp_ratio_resolved(self) -> float:
        # Use the speaker-specific ratio when configured; otherwise inherit
        # the backbone's mlp_ratio.
        if self.speaker_mlp_ratio is not None:
            return float(self.speaker_mlp_ratio)
        return self.mlp_ratio
65+
66+
67+
@dataclass
class SamplerConfig(BaseModelArgs):
    """Settings for the rectified-flow sampling loop."""

    # Number of integration steps taken by the sampler.
    num_steps: int = 40
    # Classifier-free-guidance scales for the text and speaker conditions.
    cfg_scale_text: float = 3.0
    cfg_scale_speaker: float = 5.0
    # "independent" evaluates the guidance branches separately; the README in
    # this commit notes "alternating" uses ~1/3 of independent-mode memory.
    cfg_guidance_mode: str = "independent"
    # Timestep window for applying CFG — presumably guidance is only active
    # for t in [cfg_min_t, cfg_max_t]; TODO confirm against the sampler code.
    cfg_min_t: float = 0.5
    cfg_max_t: float = 1.0
    # Optional noise-truncation / rescaling knobs; disabled when None.
    # NOTE(review): exact semantics not visible here — verify in the sampler.
    truncation_factor: Optional[float] = None
    rescale_k: Optional[float] = None
    rescale_sigma: Optional[float] = None
    # Cache conditioning (context) keys/values across diffusion steps.
    context_kv_cache: bool = True
    # Speaker KV-cache tuning; None leaves the corresponding behavior at its
    # built-in default. NOTE(review): semantics assumed from names — confirm.
    speaker_kv_scale: Optional[float] = None
    speaker_kv_min_t: Optional[float] = 0.9
    speaker_kv_max_layers: Optional[int] = None
    # Latent sequence length to generate; the README in this commit says the
    # default of 750 needs roughly 24GB of unified memory.
    sequence_length: int = 750
83+
84+
85+
@dataclass
class ModelConfig(BaseModelArgs):
    """Top-level configuration for the Irodori-TTS model."""

    model_type: str = "irodori_tts"
    # Output sample rate of the DACVAE codec (48kHz).
    sample_rate: int = 48000

    max_text_length: int = 256
    max_speaker_latent_length: int = 6400
    # DACVAE hop_length = 2*8*10*12 = 1920
    audio_downsample_factor: int = 1920

    dacvae_repo: str = "Aratako/Irodori-TTS-500M"
    model_path: Optional[str] = None

    dit: IrodoriDiTConfig = field(default_factory=IrodoriDiTConfig)
    sampler: SamplerConfig = field(default_factory=SamplerConfig)

    @classmethod
    def from_dict(cls, config: dict) -> "ModelConfig":
        """Build a :class:`ModelConfig` from a plain ``dict``.

        Scalar fields absent from ``config`` keep the dataclass defaults.
        Previously each default was duplicated literally in this method,
        which could silently drift from the field definitions above; the
        defaults are now taken from the dataclass fields themselves.
        Nested ``dit`` / ``sampler`` sub-dicts are delegated to the
        corresponding config's ``from_dict``. Unknown keys are ignored.
        """
        # Local import keeps this fix self-contained within the method.
        from dataclasses import fields

        nested = {"dit": IrodoriDiTConfig, "sampler": SamplerConfig}
        kwargs = {}
        for f in fields(cls):
            if f.name in nested:
                kwargs[f.name] = nested[f.name].from_dict(config.get(f.name, {}))
            elif f.name in config:
                kwargs[f.name] = config[f.name]
        return cls(**kwargs)

0 commit comments

Comments
 (0)