diff --git a/mlx_audio/tts/models/irodori_tts/README.md b/mlx_audio/tts/models/irodori_tts/README.md
new file mode 100644
index 00000000..42aac538
--- /dev/null
+++ b/mlx_audio/tts/models/irodori_tts/README.md
@@ -0,0 +1,72 @@
+# Irodori TTS
+
+Japanese text-to-speech model based on the Echo TTS architecture, ported to MLX.
+Uses Rectified Flow diffusion with a DiT (Diffusion Transformer) and a DACVAE codec (48kHz).
+
+## Model
+
+Original: [Aratako/Irodori-TTS-500M](https://huggingface.co/Aratako/Irodori-TTS-500M) (500M parameters)
+
+## Usage
+
+Python API:
+
+```python
+from mlx_audio.tts import load
+
+model = load("mlx-community/Irodori-TTS-500M-fp16")
+result = next(model.generate("こんにちは、音声合成のテストです。"))
+audio = result.audio
+```
+
+With reference audio for voice cloning:
+
+```python
+result = next(model.generate(
+    "こんにちは、音声合成のテストです。",
+    ref_audio="speaker.wav",
+))
+```
+
+CLI:
+
+```bash
+python -m mlx_audio.tts.generate \
+    --model mlx-community/Irodori-TTS-500M-fp16 \
+    --text "こんにちは、音声合成のテストです。"
+```
+
+## Memory requirements
+
+The default `sequence_length=750` requires approximately 24GB of unified memory.
+On 16GB machines, use reduced settings:
+
+```python
+result = next(model.generate(
+    "こんにちは。",
+    sequence_length=300,  # ~9GB in independent mode
+    cfg_guidance_mode="alternating",  # ~1/3 of independent mode memory
+))
+```
+
+Approximate memory usage with `cfg_guidance_mode="alternating"`:
+
+| sequence_length | Memory | Audio length |
+|---|---|---|
+| 100 | ~2GB | ~4s |
+| 300 | ~2GB | ~12s |
+| 400 | ~3GB | ~16s |
+
+With `cfg_guidance_mode="independent"` (default), multiply memory by ~3.
+
+## Notes
+
+- Input language: Japanese. Latin characters may not be pronounced correctly;
+  convert them to katakana beforehand (e.g. "MLX" → "エムエルエックス").
+- The DACVAE codec weights are automatically downloaded on first use
+  (from the repo configured by `dacvae_repo`, by default
+  `Aratako/Irodori-TTS-500M`; a local `dacvae/` directory next to the model
+  weights takes precedence).
+
+## License
+
+Irodori-TTS weights are released under the [MIT License](https://opensource.org/licenses/MIT).
+See [Aratako/Irodori-TTS-500M](https://huggingface.co/Aratako/Irodori-TTS-500M) for details.
diff --git a/mlx_audio/tts/models/irodori_tts/__init__.py b/mlx_audio/tts/models/irodori_tts/__init__.py
new file mode 100644
index 00000000..6f748d9c
--- /dev/null
+++ b/mlx_audio/tts/models/irodori_tts/__init__.py
@@ -0,0 +1 @@
+from .irodori_tts import Model, ModelConfig
diff --git a/mlx_audio/tts/models/irodori_tts/config.py b/mlx_audio/tts/models/irodori_tts/config.py
new file mode 100644
index 00000000..57ffd6d7
--- /dev/null
+++ b/mlx_audio/tts/models/irodori_tts/config.py
@@ -0,0 +1,113 @@
+from __future__ import annotations
+
+from dataclasses import dataclass, field
+from typing import Optional
+
+from mlx_audio.tts.models.base import BaseModelArgs
+
+
+@dataclass
+class IrodoriDiTConfig(BaseModelArgs):
+    """Architecture hyper-parameters for the Irodori DiT backbone and its
+    text / speaker (reference-latent) encoders."""
+
+    # Audio latent dimensions (DACVAE: 128-dim, 48kHz)
+    latent_dim: int = 128
+    latent_patch_size: int = 1
+
+    # DiT backbone
+    model_dim: int = 1280
+    num_layers: int = 12
+    num_heads: int = 20
+    mlp_ratio: float = 2.875
+    # Encoder MLP ratios; None means "fall back to mlp_ratio" (see the
+    # *_resolved properties below).
+    text_mlp_ratio: Optional[float] = 2.6
+    speaker_mlp_ratio: Optional[float] = 2.6
+
+    # Text encoder
+    text_vocab_size: int = 99574
+    text_tokenizer_repo: str = "llm-jp/llm-jp-3-150m"
+    text_add_bos: bool = True
+    text_dim: int = 512
+    text_layers: int = 10
+    text_heads: int = 8
+
+    # Speaker (reference latent) encoder
+    speaker_dim: int = 768
+    speaker_layers: int = 8
+    speaker_heads: int = 12
+    speaker_patch_size: int = 1
+
+    # Conditioning
+    timestep_embed_dim: int = 512
+    adaln_rank: int = 192
+    norm_eps: float = 1e-5
+
+    @property
+    def patched_latent_dim(self) -> int:
+        # Latent channel count after sequence patching (DiT input width).
+        return self.latent_dim * self.latent_patch_size
+
+    @property
+    def speaker_patched_latent_dim(self) -> int:
+        # Input width of the ReferenceLatentEncoder after speaker patching.
+        return self.patched_latent_dim * self.speaker_patch_size
+
+    @property
+    def text_mlp_ratio_resolved(self) -> float:
+        """Effective MLP ratio for the text encoder (defaults to mlp_ratio)."""
+        return (
+            self.mlp_ratio
+            if self.text_mlp_ratio is None
+            else float(self.text_mlp_ratio)
+        )
+
+    @property
+    def speaker_mlp_ratio_resolved(self) -> float:
+        """Effective MLP ratio for the speaker encoder (defaults to mlp_ratio)."""
+        return (
+            self.mlp_ratio
+            if self.speaker_mlp_ratio is None
+            else float(self.speaker_mlp_ratio)
+        )
+
+
+@dataclass
+class SamplerConfig(BaseModelArgs):
+    """Defaults for the Rectified Flow Euler sampler.
+
+    These fields are forwarded as keyword arguments to
+    ``sampling.sample_euler_cfg`` (see Model.generate_latents), so the names
+    here must stay in sync with that function's signature.
+    """
+
+    num_steps: int = 40
+    cfg_scale_text: float = 3.0
+    cfg_scale_speaker: float = 5.0
+    # "independent" or "alternating" — NOTE(review): values per the README;
+    # exact semantics live in sampling.py (not visible here).
+    cfg_guidance_mode: str = "independent"
+    cfg_min_t: float = 0.5
+    cfg_max_t: float = 1.0
+    truncation_factor: Optional[float] = None
+    rescale_k: Optional[float] = None
+    rescale_sigma: Optional[float] = None
+    context_kv_cache: bool = True
+    speaker_kv_scale: Optional[float] = None
+    speaker_kv_min_t: Optional[float] = 0.9
+    speaker_kv_max_layers: Optional[int] = None
+    sequence_length: int = 750
+
+
+@dataclass
+class ModelConfig(BaseModelArgs):
+    """Top-level Irodori-TTS configuration: audio/codec settings plus the
+    nested DiT and sampler configs."""
+
+    model_type: str = "irodori_tts"
+    sample_rate: int = 48000
+
+    max_text_length: int = 256
+    max_speaker_latent_length: int = 6400
+    # DACVAE hop_length = 2*8*10*12 = 1920
+    audio_downsample_factor: int = 1920
+
+    dacvae_repo: str = "Aratako/Irodori-TTS-500M"
+    model_path: Optional[str] = None
+
+    dit: IrodoriDiTConfig = field(default_factory=IrodoriDiTConfig)
+    sampler: SamplerConfig = field(default_factory=SamplerConfig)
+
+    @classmethod
+    def from_dict(cls, config: dict) -> "ModelConfig":
+        """Build a ModelConfig from a parsed config.json dict, filling
+        defaults for any missing keys (unknown keys are ignored)."""
+        return cls(
+            model_type=config.get("model_type", "irodori_tts"),
+            sample_rate=config.get("sample_rate", 48000),
+            max_text_length=config.get("max_text_length", 256),
+            max_speaker_latent_length=config.get("max_speaker_latent_length", 6400),
+            audio_downsample_factor=config.get("audio_downsample_factor", 1920),
+            dacvae_repo=config.get("dacvae_repo", "Aratako/Irodori-TTS-500M"),
+            model_path=config.get("model_path"),
+            dit=IrodoriDiTConfig.from_dict(config.get("dit", {})),
+            sampler=SamplerConfig.from_dict(config.get("sampler", {})),
+        )
diff --git a/mlx_audio/tts/models/irodori_tts/irodori_tts.py b/mlx_audio/tts/models/irodori_tts/irodori_tts.py
new file mode 100644
index 00000000..11145629
--- /dev/null
+++ b/mlx_audio/tts/models/irodori_tts/irodori_tts.py
@@ -0,0 +1,305 @@
+from __future__ import annotations
+
+import time
+from pathlib import Path
+from typing import Generator, Optional
+
+import mlx.core as mx
+import mlx.nn as nn
+import numpy as np
+
+from mlx_audio.codec.models.dacvae import DACVAE
+from mlx_audio.tts.models.base import GenerationResult
+from mlx_audio.utils import load_audio as load_audio_any
+
+from .config import ModelConfig
+from .model import IrodoriDiT
+from .sampling import sample_euler_cfg
+from .text import encode_text, normalize_text
+
+
+def _find_silence_point(
+    latent: mx.array,
+    window_size: int = 20,
+    std_threshold: float = 0.05,
+) -> int:
+    """Detect trailing silence in generated latent (same heuristic as Echo TTS).
+
+    Returns the first index i such that the window latent[i:i+window_size]
+    (zero-padded past the end) has low std and near-zero mean; returns the
+    full length if no such window exists. An all-silent latent returns 0.
+    """
+    # latent: (T, D)
+    padded = mx.concatenate(
+        [latent, mx.zeros((window_size, latent.shape[-1]), dtype=latent.dtype)], axis=0
+    )
+    for i in range(int(padded.shape[0] - window_size)):
+        window = padded[i : i + window_size]
+        if float(window.std()) < std_threshold and abs(float(window.mean())) < 0.1:
+            return i
+    return int(latent.shape[0])
+
+
+class Model(nn.Module):
+    """Irodori-TTS inference wrapper: DiT diffusion model + DACVAE codec."""
+
+    def __init__(self, config: ModelConfig):
+        super().__init__()
+        self.config = config
+        self.model = IrodoriDiT(config.dit)
+        # Codec is attached later by post_load_hook (or manually by the user).
+        self.dacvae: DACVAE | None = None
+        # Lazily created HF tokenizer (see _get_tokenizer).
+        self._tokenizer = None
+
+    # ------------------------------------------------------------------
+    # Properties
+    # ------------------------------------------------------------------
+
+    @property
+    def sample_rate(self) -> int:
+        return self.config.sample_rate
+
+    @property
+    def model_type(self) -> str:
+        return self.config.model_type
+
+    def __call__(self, *args, **kwargs):
+        # Delegate directly to the underlying DiT.
+        return self.model(*args, **kwargs)
+
+    # ------------------------------------------------------------------
+    # Weight loading hooks
+    # ------------------------------------------------------------------
+
+    def sanitize(self, weights: dict) -> dict:
+        """
+        Remap Irodori PyTorch weight keys to mlx-audio MLX conventions:
+            cond_module.0.weight → cond_module.layers.0.weight
+        and nest every key under the "model." prefix (the DiT lives at
+        self.model).
+        """
+        out = {}
+        for k, v in weights.items():
+            # PyTorch Sequential uses integer keys; MLX nn.Sequential uses "layers.N"
+            if k.startswith("cond_module."):
+                parts = k.split(".")
+                if len(parts) > 1 and parts[1].isdigit():
+                    k = ".".join(["cond_module", "layers", parts[1], *parts[2:]])
+            # Nest under self.model
+            out_key = f"model.{k}" if not k.startswith("model.") else k
+            out[out_key] = v
+        return out
+
+    @classmethod
+    def post_load_hook(cls, model: "Model", model_path: Path) -> "Model":
+        """Load DACVAE codec after model weights are loaded.
+
+        Looks for a local dacvae/ subdirectory first (produced by convert.py),
+        then falls back to downloading model.config.dacvae_repo from HuggingFace.
+        On any failure a warning is emitted and model.dacvae stays None
+        (generate() will then raise until it is set manually).
+        """
+        import json as _json
+
+        from mlx_audio.codec.models.dacvae import DACVAEConfig
+
+        local_dacvae = Path(model_path) / "dacvae"
+        try:
+            if local_dacvae.is_dir():
+                with open(local_dacvae / "config.json") as f:
+                    cfg = DACVAEConfig(**_json.load(f))
+                dac = DACVAE(cfg)
+                dac.load_weights(str(local_dacvae / "model.safetensors"))
+                import mlx.core as _mx
+
+                # Force materialisation of the codec weights up front.
+                _mx.eval(dac.parameters())
+                model.dacvae = dac
+            else:
+                model.dacvae = DACVAE.from_pretrained(model.config.dacvae_repo)
+        except Exception as e:
+            import warnings
+
+            warnings.warn(
+                f"Could not load DACVAE: {e}\n"
+                "Set model.dacvae manually before calling generate()."
+            )
+            model.dacvae = None
+        return model
+
+    # ------------------------------------------------------------------
+    # Tokenisation
+    # ------------------------------------------------------------------
+
+    def _get_tokenizer(self):
+        # Lazily instantiate and cache the HF tokenizer for text encoding.
+        if self._tokenizer is None:
+            from transformers import AutoTokenizer
+
+            self._tokenizer = AutoTokenizer.from_pretrained(
+                self.config.dit.text_tokenizer_repo
+            )
+        return self._tokenizer
+
+    def _prepare_text(
+        self, text: str, max_length: Optional[int] = None
+    ) -> tuple[mx.array, mx.array]:
+        """
+        Normalise and tokenise text, returning (input_ids, mask) as MLX arrays.
+        Matches Irodori's PretrainedTextTokenizer: BOS prepended manually,
+        right-padded to max_length.
+        """
+        if max_length is None:
+            max_length = self.config.max_text_length
+
+        text = normalize_text(text)
+        return encode_text(
+            text,
+            tokenizer=self._get_tokenizer(),
+            max_length=max_length,
+            add_bos=self.config.dit.text_add_bos,
+        )
+
+    # ------------------------------------------------------------------
+    # Reference audio encoding
+    # ------------------------------------------------------------------
+
+    def _encode_ref_audio(self, audio: mx.array) -> tuple[mx.array, mx.array]:
+        """
+        Encode reference waveform with DACVAE.
+        audio: (1, samples) at config.sample_rate
+        Returns (latent, mask): latent (1, T, 128), mask (1, T) bool
+        """
+        assert self.dacvae is not None, "DACVAE not loaded"
+
+        # Cap the reference length so the speaker latent fits the encoder.
+        max_samples = (
+            self.config.max_speaker_latent_length * self.config.audio_downsample_factor
+        )
+        audio = audio[:, :max_samples]
+
+        # DACVAE encode expects (B, L, 1)
+        audio_in = audio[:, :, None]  # (1, L, 1)
+        latent = self.dacvae.encode(audio_in)  # (1, 128, T) channels-first
+        latent = mx.transpose(latent, (0, 2, 1))  # (1, T, 128) sequence-first
+
+        # Drop any frames the codec produced beyond the exact sample count.
+        actual_t = int(audio.shape[1]) // self.config.audio_downsample_factor
+        actual_t = min(actual_t, latent.shape[1])
+        latent = latent[:, :actual_t]
+        mask = mx.ones((1, actual_t), dtype=mx.bool_)
+
+        # Align to speaker_patch_size
+        p = self.config.dit.speaker_patch_size
+        if p > 1 and actual_t % p != 0:
+            trim = (actual_t // p) * p
+            latent = latent[:, :trim]
+            mask = mask[:, :trim]
+
+        return latent, mask
+
+    # ------------------------------------------------------------------
+    # Latent generation (sampling)
+    # ------------------------------------------------------------------
+
+    def generate_latents(
+        self,
+        text: str,
+        ref_latent: Optional[mx.array] = None,
+        ref_mask: Optional[mx.array] = None,
+        rng_seed: int = 0,
+        **sampling_kwargs,
+    ) -> mx.array:
+        """Run the Rectified Flow sampler and return the latent (1, T, 128).
+
+        sampling_kwargs may override any SamplerConfig field; keys that are
+        not SamplerConfig fields are silently ignored (see the filter below).
+        """
+        text_input_ids, text_mask = self._prepare_text(text)
+
+        # Without a reference, feed a single zero latent with an all-False
+        # mask (unconditional speaker branch).
+        if ref_latent is None:
+            ref_latent = mx.zeros((1, 1, self.config.dit.latent_dim))
+        if ref_mask is None:
+            ref_mask = mx.zeros((1, ref_latent.shape[1]), dtype=mx.bool_)
+
+        sampler_cfg = dict(self.config.sampler.__dict__)
+        for k, v in sampling_kwargs.items():
+            if k in sampler_cfg:
+                sampler_cfg[k] = v
+
+        return sample_euler_cfg(
+            model=self.model,
+            text_input_ids=text_input_ids,
+            text_mask=text_mask,
+            ref_latent=ref_latent,
+            ref_mask=ref_mask,
+            rng_seed=rng_seed,
+            latent_dim=self.config.dit.patched_latent_dim,
+            **sampler_cfg,
+        )
+
+    # ------------------------------------------------------------------
+    # Main generate interface
+    # ------------------------------------------------------------------
+
+    def generate(
+        self,
+        text: str,
+        voice: str | None = None,
+        ref_audio: str | mx.array | None = None,
+        stream: bool = False,
+        **kwargs,
+    ) -> Generator[GenerationResult, None, None]:
+        """Synthesise speech for `text`, yielding a single GenerationResult.
+
+        NOTE(review): `voice` is accepted for interface compatibility but is
+        not used anywhere in this implementation; use `ref_audio` for voice
+        cloning instead.
+        """
+        if stream:
+            raise NotImplementedError("Irodori-TTS streaming is not yet implemented.")
+
+        if self.dacvae is None:
+            raise ValueError(
+                "Irodori-TTS requires DACVAE to be loaded. "
+                "Use mlx_audio.tts.load(...) or set model.dacvae manually."
+            )
+
+        start_time = time.perf_counter()
+        text_input_ids, _ = self._prepare_text(text)
+        token_count = int(text_input_ids.shape[1])
+
+        # Encode reference audio if provided
+        ref_latent = None
+        ref_mask = None
+        if ref_audio is not None:
+            audio = (
+                load_audio_any(ref_audio, sample_rate=self.sample_rate)
+                if isinstance(ref_audio, str)
+                else ref_audio
+            )
+            # Normalise to a single mono channel (1, samples).
+            if audio.ndim == 1:
+                audio = audio[None, :]
+            elif audio.ndim == 2 and audio.shape[0] > 1:
+                audio = mx.mean(audio, axis=0, keepdims=True)
+            ref_latent, ref_mask = self._encode_ref_audio(audio)
+
+        # Run diffusion sampler
+        latent_out = self.generate_latents(
+            text=text,
+            ref_latent=ref_latent,
+            ref_mask=ref_mask,
+            rng_seed=int(kwargs.get("rng_seed", 0)),
+            **{k: v for k, v in kwargs.items() if k != "rng_seed"},
+        )
+
+        # Decode latent → waveform
+        # latent_out: (1, T, 128)
+        latent_for_decode = mx.transpose(latent_out, (0, 2, 1))  # (1, 128, T)
+        audio_out = self.dacvae.decode(latent_for_decode)  # (1, L, 1)
+        audio_out = audio_out[:, :, 0]  # (1, L)
+
+        # Trim trailing silence
+        silence_t = _find_silence_point(latent_out[0])
+        trim_samples = silence_t * self.config.audio_downsample_factor
+        audio_out = audio_out[:, :trim_samples]
+
+        audio = audio_out[0]  # (L,)
+        samples = int(audio.shape[0])
+        elapsed = max(time.perf_counter() - start_time, 1e-6)
+        audio_duration_seconds = (
+            samples / self.sample_rate if self.sample_rate > 0 else 0.0
+        )
+
+        # Format duration as HH:MM:SS.mmm for the result metadata.
+        h = int(audio_duration_seconds // 3600)
+        m = int((audio_duration_seconds % 3600) // 60)
+        s = int(audio_duration_seconds % 60)
+        ms = int((audio_duration_seconds % 1) * 1000)
+        duration_str = f"{h:02d}:{m:02d}:{s:02d}.{ms:03d}"
+
+        yield GenerationResult(
+            audio=audio,
+            samples=samples,
+            sample_rate=self.sample_rate,
+            segment_idx=0,
+            token_count=token_count,
+            audio_duration=duration_str,
+            real_time_factor=audio_duration_seconds / elapsed,
+            prompt={"tokens": token_count, "tokens-per-sec": token_count / elapsed},
+            audio_samples={"samples": samples, "samples-per-sec": samples / elapsed},
+            processing_time_seconds=elapsed,
+            peak_memory_usage=float(mx.get_peak_memory() / 1e9),
+        )
diff --git a/mlx_audio/tts/models/irodori_tts/model.py b/mlx_audio/tts/models/irodori_tts/model.py
new file mode 100644
index 00000000..368cb067
--- /dev/null
+++ b/mlx_audio/tts/models/irodori_tts/model.py
@@ -0,0 +1,635 @@
+from __future__ import annotations
+
+import math
+from typing import List, Optional, Tuple
+
+import mlx.core as mx
+import mlx.nn as nn
+
+from .config import IrodoriDiTConfig
+
+# (cos, sin) tables for rotary embeddings, each shaped (seq_len, dim/2).
+RotaryCache = Tuple[mx.array, mx.array]
+# Pre-projected (keys, values), each (B, S, heads, head_dim).
+KVCache = Tuple[mx.array, mx.array]
+
+
+# ---------------------------------------------------------------------------
+# Positional encoding helpers
+# ---------------------------------------------------------------------------
+
+
+def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0) -> RotaryCache:
+    """Build RoPE (cos, sin) tables for positions [0, end) at frequency dim."""
+    freqs = 1.0 / (
+        theta ** (mx.arange(0, dim, 2, dtype=mx.float32)[: (dim // 2)] / float(dim))
+    )
+    t = mx.arange(end, dtype=mx.float32)
+    freqs = mx.outer(t, freqs)
+    return mx.cos(freqs), mx.sin(freqs)
+
+
+def apply_rotary_emb(x: mx.array, freqs_cis: RotaryCache) -> mx.array:
+    """Rotate interleaved (even, odd) feature pairs of x: (B, S, H, D)."""
+    cos, sin = freqs_cis
+    x_even = x[..., 0::2]
+    x_odd = x[..., 1::2]
+    cos = cos[None, :, None, :]
+    sin = sin[None, :, None, :]
+    x_rot_even = x_even * cos - x_odd * sin
+    x_rot_odd = x_odd * cos + x_even * sin
+    # Re-interleave the rotated pairs back to the original layout.
+    return mx.stack([x_rot_even, x_rot_odd], axis=-1).reshape(x.shape)
+
+
+def get_timestep_embedding(timestep: mx.array, embed_size: int) -> mx.array:
+    """Sinusoidal timestep embedding.
+
+    NOTE(review): the 1000.0 factor presumably rescales unit-interval
+    diffusion timesteps into the usual 0..1000 range — confirm against the
+    PyTorch reference.
+    """
+    if embed_size % 2 != 0:
+        raise ValueError("embed_size must be even")
+    half = embed_size // 2
+    base = mx.log(mx.array(10000.0, dtype=mx.float32))
+    freqs = 1000.0 * mx.exp(
+        -base * mx.arange(start=0, stop=half, dtype=mx.float32) / float(half)
+    )
+    args = timestep[..., None] * freqs[None, :]
+    return mx.concatenate([mx.cos(args), mx.sin(args)], axis=-1).astype(timestep.dtype)
+
+
+def _bool_to_additive_mask(mask: mx.array) -> mx.array:
+    """Convert boolean mask (B, Sq, Sk) to additive float mask (B, 1, Sq, Sk)."""
+    zero = mx.zeros(mask.shape, dtype=mx.float32)
+    # -1e9 rather than -inf to avoid NaNs on fully-masked rows.
+    neg_inf = mx.full(mask.shape, -1e9, dtype=mx.float32)
+    return mx.where(mask, zero, neg_inf)[:, None, :, :]
+
+
+def patch_sequence_with_mask(
+    seq: mx.array,
+    mask: mx.array,
+    patch_size: int,
+) -> Tuple[mx.array, mx.array]:
+    """
+    Patch along the sequence axis.
+    seq  : (B, S, D)   -> (B, S//patch, D*patch)
+    mask : (B, S) bool -> (B, S//patch) bool (True iff all tokens in patch are valid)
+    Trailing tokens that do not fill a whole patch are dropped.
+    """
+    if patch_size <= 1:
+        return seq, mask
+    bsz, seq_len, dim = seq.shape
+    usable = (seq_len // patch_size) * patch_size
+    seq = seq[:, :usable].reshape(bsz, usable // patch_size, dim * patch_size)
+    mask = mask[:, :usable].reshape(bsz, usable // patch_size, patch_size)
+    mask = mx.all(mask, axis=-1)
+    return seq, mask
+
+
+# ---------------------------------------------------------------------------
+# Normalisation layers
+# ---------------------------------------------------------------------------
+
+
+class RMSNorm(nn.Module):
+    """RMS normalisation computed in float32, with a learned scale.
+
+    model_size may be a tuple (e.g. (heads, head_dim)) for per-head weights.
+    """
+
+    def __init__(self, model_size: int | Tuple[int, int], eps: float):
+        super().__init__()
+        self.eps = eps
+        if isinstance(model_size, int):
+            model_size = (model_size,)
+        self.weight = mx.ones(model_size)
+
+    def __call__(self, x: mx.array) -> mx.array:
+        x_dtype = x.dtype
+        x = x.astype(mx.float32)
+        x = x * mx.rsqrt(mx.mean(mx.square(x), axis=-1, keepdims=True) + self.eps)
+        return (x * self.weight).astype(x_dtype)
+
+
+class LowRankAdaLN(nn.Module):
+    """
+    Low-rank adaptive layer normalisation with shift/scale/gate from timestep embedding.
+    Matches Irodori's LowRankAdaLN exactly (including residual connection on each branch).
+    Returns the normalised/modulated x plus a tanh-squashed gate.
+    """
+
+    def __init__(self, model_dim: int, rank: int, eps: float):
+        super().__init__()
+        self.eps = eps
+        # Clamp the rank into [1, model_dim].
+        rank = max(1, min(int(rank), int(model_dim)))
+        self.shift_down = nn.Linear(model_dim, rank, bias=False)
+        self.scale_down = nn.Linear(model_dim, rank, bias=False)
+        self.gate_down = nn.Linear(model_dim, rank, bias=False)
+        self.shift_up = nn.Linear(rank, model_dim, bias=True)
+        self.scale_up = nn.Linear(rank, model_dim, bias=True)
+        self.gate_up = nn.Linear(rank, model_dim, bias=True)
+
+    def __call__(self, x: mx.array, cond_embed: mx.array) -> Tuple[mx.array, mx.array]:
+        # cond_embed: (..., 3*model_dim) split into shift/scale/gate thirds,
+        # each refined by a low-rank SiLU bottleneck with a residual.
+        shift, scale, gate = mx.split(cond_embed, 3, axis=-1)
+        shift = self.shift_up(self.shift_down(nn.silu(shift))) + shift
+        scale = self.scale_up(self.scale_down(nn.silu(scale))) + scale
+        gate = self.gate_up(self.gate_down(nn.silu(gate))) + gate
+
+        x_dtype = x.dtype
+        x = x.astype(mx.float32)
+        x = x * mx.rsqrt(mx.mean(mx.square(x), axis=-1, keepdims=True) + self.eps)
+        x = x * (1.0 + scale) + shift
+        gate = mx.tanh(gate)
+        return x.astype(x_dtype), gate
+
+
+# ---------------------------------------------------------------------------
+# Feed-forward
+# ---------------------------------------------------------------------------
+
+
+class SwiGLU(nn.Module):
+    """SwiGLU MLP: w2(silu(w1(x)) * w3(x))."""
+
+    def __init__(self, dim: int, hidden_dim: int):
+        super().__init__()
+        self.w1 = nn.Linear(dim, hidden_dim, bias=False)
+        self.w2 = nn.Linear(hidden_dim, dim, bias=False)
+        self.w3 = nn.Linear(dim, hidden_dim, bias=False)
+
+    def __call__(self, x: mx.array) -> mx.array:
+        return self.w2(nn.silu(self.w1(x)) * self.w3(x))
+
+
+# ---------------------------------------------------------------------------
+# Attention
+# ---------------------------------------------------------------------------
+
+
+class SelfAttention(nn.Module):
+    """Non-causal self-attention with RoPE and output gate (used in encoders)."""
+
+    def __init__(self, dim: int, heads: int, norm_eps: float):
+        super().__init__()
+        self.heads = heads
+        self.head_dim = dim // heads
+        self.wq = nn.Linear(dim, dim, bias=False)
+        self.wk = nn.Linear(dim, dim, bias=False)
+        self.wv = nn.Linear(dim, dim, bias=False)
+        self.wo = nn.Linear(dim, dim, bias=False)
+        self.gate = nn.Linear(dim, dim, bias=False)
+        # Per-head QK norms (query/key RMS normalisation).
+        self.q_norm = RMSNorm((heads, self.head_dim), eps=norm_eps)
+        self.k_norm = RMSNorm((heads, self.head_dim), eps=norm_eps)
+
+    def __call__(
+        self,
+        x: mx.array,
+        key_mask: Optional[mx.array],
+        freqs_cis: RotaryCache,
+    ) -> mx.array:
+        bsz, seq_len = x.shape[:2]
+        q = self.wq(x).reshape(bsz, seq_len, self.heads, self.head_dim)
+        k = self.wk(x).reshape(bsz, seq_len, self.heads, self.head_dim)
+        v = self.wv(x).reshape(bsz, seq_len, self.heads, self.head_dim)
+        gate = self.gate(x)
+
+        q = self.q_norm(q)
+        k = self.k_norm(k)
+        q = apply_rotary_emb(q, (freqs_cis[0][:seq_len], freqs_cis[1][:seq_len]))
+        k = apply_rotary_emb(k, (freqs_cis[0][:seq_len], freqs_cis[1][:seq_len]))
+
+        attn_mask = None
+        if key_mask is not None:
+            # Broadcast the (B, Sk) key mask over all query positions.
+            m = mx.broadcast_to(key_mask[:, None, :], (bsz, seq_len, seq_len))
+            attn_mask = _bool_to_additive_mask(m)
+
+        out = mx.fast.scaled_dot_product_attention(
+            q=mx.transpose(q, (0, 2, 1, 3)),
+            k=mx.transpose(k, (0, 2, 1, 3)),
+            v=mx.transpose(v, (0, 2, 1, 3)),
+            scale=1.0 / math.sqrt(self.head_dim),
+            mask=attn_mask,
+        )
+        out = mx.transpose(out, (0, 2, 1, 3)).reshape(bsz, seq_len, -1)
+        # Sigmoid output gate before the final projection.
+        return self.wo(out * mx.sigmoid(gate))
+
+
+class JointAttention(nn.Module):
+    """
+    Joint attention over latent self-tokens, text context, and speaker context.
+    Uses half-RoPE: RoPE applied to the first half of the attention heads
+    (see _apply_rotary_half).
+    No latent KV cache (blockwise generation not needed for inference).
+    """
+
+    def __init__(
+        self,
+        dim: int,
+        heads: int,
+        text_ctx_dim: int,
+        speaker_ctx_dim: int,
+        norm_eps: float,
+    ):
+        super().__init__()
+        self.heads = heads
+        self.head_dim = dim // heads
+        self.wq = nn.Linear(dim, dim, bias=False)
+        self.wk = nn.Linear(dim, dim, bias=False)
+        self.wv = nn.Linear(dim, dim, bias=False)
+        # Separate K/V projections for the text and speaker context streams.
+        self.wk_text = nn.Linear(text_ctx_dim, dim, bias=False)
+        self.wv_text = nn.Linear(text_ctx_dim, dim, bias=False)
+        self.wk_speaker = nn.Linear(speaker_ctx_dim, dim, bias=False)
+        self.wv_speaker = nn.Linear(speaker_ctx_dim, dim, bias=False)
+        self.gate = nn.Linear(dim, dim, bias=False)
+        self.wo = nn.Linear(dim, dim, bias=False)
+        self.q_norm = RMSNorm((heads, self.head_dim), eps=norm_eps)
+        self.k_norm = RMSNorm((heads, self.head_dim), eps=norm_eps)
+
+    def _apply_rotary_half(self, y: mx.array, freqs_cis: RotaryCache) -> mx.array:
+        """Apply RoPE to the first half of the heads (axis -2), each over the
+        full head_dim.
+
+        NOTE(review): an earlier comment said "first half of head dimensions",
+        but the slice below is over the heads axis, and the freq tables built
+        by the callers cover the full head_dim — confirm against the PyTorch
+        reference if this port is ever revisited.
+        """
+        half = y.shape[-2] // 2
+        y1 = apply_rotary_emb(y[..., :half, :], freqs_cis)
+        return mx.concatenate([y1, y[..., half:, :]], axis=-2)
+
+    def get_kv_cache_text(self, text_state: mx.array) -> KVCache:
+        """Project (and K-normalise) the text context once for reuse."""
+        bsz = text_state.shape[0]
+        k = self.wk_text(text_state).reshape(
+            bsz, text_state.shape[1], self.heads, self.head_dim
+        )
+        v = self.wv_text(text_state).reshape(
+            bsz, text_state.shape[1], self.heads, self.head_dim
+        )
+        k = self.k_norm(k)
+        return k, v
+
+    def get_kv_cache_speaker(self, speaker_state: mx.array) -> KVCache:
+        """Project (and K-normalise) the speaker context once for reuse."""
+        bsz = speaker_state.shape[0]
+        k = self.wk_speaker(speaker_state).reshape(
+            bsz, speaker_state.shape[1], self.heads, self.head_dim
+        )
+        v = self.wv_speaker(speaker_state).reshape(
+            bsz, speaker_state.shape[1], self.heads, self.head_dim
+        )
+        k = self.k_norm(k)
+        return k, v
+
+    def __call__(
+        self,
+        x: mx.array,
+        text_mask: mx.array,
+        speaker_mask: mx.array,
+        freqs_cis: RotaryCache,
+        kv_cache_text: KVCache,
+        kv_cache_speaker: KVCache,
+        start_pos: int = 0,
+    ) -> mx.array:
+        bsz, seq_len = x.shape[:2]
+        q = self.wq(x).reshape(bsz, seq_len, self.heads, self.head_dim)
+        k_self = self.wk(x).reshape(bsz, seq_len, self.heads, self.head_dim)
+        v_self = self.wv(x).reshape(bsz, seq_len, self.heads, self.head_dim)
+        gate = self.gate(x)
+
+        q = self.q_norm(q)
+        k_self = self.k_norm(k_self)
+
+        # RoPE only on the latent self tokens; context K/V are position-free.
+        q_cos = freqs_cis[0][start_pos : start_pos + seq_len]
+        q_sin = freqs_cis[1][start_pos : start_pos + seq_len]
+        q = self._apply_rotary_half(q, (q_cos, q_sin))
+        k_self = self._apply_rotary_half(k_self, (q_cos, q_sin))
+
+        k_text, v_text = kv_cache_text
+        k_speaker, v_speaker = kv_cache_speaker
+
+        # Keys/values: [self | text | speaker] along the sequence axis.
+        k = mx.concatenate([k_self, k_text, k_speaker], axis=1)
+        v = mx.concatenate([v_self, v_text, v_speaker], axis=1)
+
+        # Self tokens are always attendable; contexts use their own masks.
+        self_mask = mx.ones((bsz, seq_len), dtype=mx.bool_)
+        full_mask = mx.concatenate([self_mask, text_mask, speaker_mask], axis=1)
+        full_mask = mx.broadcast_to(
+            full_mask[:, None, :], (bsz, seq_len, full_mask.shape[1])
+        )
+        attn_mask = _bool_to_additive_mask(full_mask)
+
+        out = mx.fast.scaled_dot_product_attention(
+            q=mx.transpose(q, (0, 2, 1, 3)),
+            k=mx.transpose(k, (0, 2, 1, 3)),
+            v=mx.transpose(v, (0, 2, 1, 3)),
+            scale=1.0 / math.sqrt(self.head_dim),
+            mask=attn_mask,
+        )
+        out = mx.transpose(out, (0, 2, 1, 3)).reshape(bsz, seq_len, -1)
+        return self.wo(out * mx.sigmoid(gate))
+
+
+# ---------------------------------------------------------------------------
+# Encoder blocks
+# ---------------------------------------------------------------------------
+
+
+class TextBlock(nn.Module):
+    """Transformer block used in both TextEncoder and ReferenceLatentEncoder."""
+
+    def __init__(self, dim: int, heads: int, mlp_hidden_dim: int, norm_eps: float):
+        super().__init__()
+        self.attention_norm = RMSNorm(dim, eps=norm_eps)
+        self.attention = SelfAttention(dim, heads, norm_eps=norm_eps)
+        self.mlp_norm = RMSNorm(dim, eps=norm_eps)
+        self.mlp = SwiGLU(dim, mlp_hidden_dim)
+
+    def __call__(
+        self, x: mx.array, mask: Optional[mx.array], freqs_cis: RotaryCache
+    ) -> mx.array:
+        # Pre-norm residual attention + MLP.
+        x = x + self.attention(
+            self.attention_norm(x), key_mask=mask, freqs_cis=freqs_cis
+        )
+        x = x + self.mlp(self.mlp_norm(x))
+        return x
+
+
+class TextEncoder(nn.Module):
+    """
+    Text encoder: embedding + non-causal Transformer blocks.
+    Applies mask zeroing after each block so fully-masked positions stay zero.
+    """
+
+    def __init__(
+        self,
+        vocab_size: int,
+        dim: int,
+        heads: int,
+        num_layers: int,
+        mlp_ratio: float,
+        norm_eps: float,
+    ):
+        super().__init__()
+        self.head_dim = dim // heads
+        self.text_embedding = nn.Embedding(vocab_size, dim)
+        mlp_hidden = int(dim * mlp_ratio)
+        self.blocks = [
+            TextBlock(dim, heads, mlp_hidden, norm_eps) for _ in range(num_layers)
+        ]
+
+    def __call__(
+        self, input_ids: mx.array, mask: Optional[mx.array] = None
+    ) -> mx.array:
+        x = self.text_embedding(input_ids)
+        # Freq tables are rebuilt per call; conditions are cached upstream,
+        # so this runs once per generation.
+        freqs_cis = precompute_freqs_cis(self.head_dim, input_ids.shape[1])
+        if mask is not None:
+            mask_f = mask[..., None].astype(x.dtype)
+            x = x * mask_f
+            for block in self.blocks:
+                x = block(x, mask=mask, freqs_cis=freqs_cis)
+                x = x * mask_f
+            return x * mask_f
+        else:
+            for block in self.blocks:
+                x = block(x, mask=None, freqs_cis=freqs_cis)
+            return x
+
+
+class ReferenceLatentEncoder(nn.Module):
+    """
+    Encoder for reference (speaker) audio latents.
+    Receives already-patched DACVAE latents: (B, S, latent_dim * speaker_patch_size).
+    Uses non-causal attention (unlike Echo TTS which uses causal).
+    """
+
+    def __init__(
+        self,
+        in_dim: int,
+        dim: int,
+        heads: int,
+        num_layers: int,
+        mlp_ratio: float,
+        norm_eps: float,
+    ):
+        super().__init__()
+        self.head_dim = dim // heads
+        self.in_proj = nn.Linear(in_dim, dim, bias=True)
+        mlp_hidden = int(dim * mlp_ratio)
+        self.blocks = [
+            TextBlock(dim, heads, mlp_hidden, norm_eps) for _ in range(num_layers)
+        ]
+
+    def __call__(self, latent: mx.array, mask: Optional[mx.array] = None) -> mx.array:
+        # NOTE(review): the fixed / 6.0 input scaling presumably matches the
+        # latent normalisation of the original implementation — confirm
+        # against the PyTorch reference.
+        x = self.in_proj(latent) / 6.0
+        freqs_cis = precompute_freqs_cis(self.head_dim, x.shape[1])
+        if mask is not None:
+            mask_f = mask[..., None].astype(x.dtype)
+            x = x * mask_f
+            for block in self.blocks:
+                x = block(x, mask=mask, freqs_cis=freqs_cis)
+                x = x * mask_f
+            return x * mask_f
+        else:
+            for block in self.blocks:
+                x = block(x, mask=None, freqs_cis=freqs_cis)
+            return x
+
+
+# ---------------------------------------------------------------------------
+# Diffusion block
+# ---------------------------------------------------------------------------
+
+
+class DiffusionBlock(nn.Module):
+    """
+    Single DiT block: JointAttention + SwiGLU, both conditioned via LowRankAdaLN.
+    """
+
+    def __init__(
+        self,
+        dim: int,
+        heads: int,
+        mlp_hidden_dim: int,
+        text_ctx_dim: int,
+        speaker_ctx_dim: int,
+        adaln_rank: int,
+        norm_eps: float,
+    ):
+        super().__init__()
+        self.attention = JointAttention(
+            dim, heads, text_ctx_dim, speaker_ctx_dim, norm_eps
+        )
+        self.mlp = SwiGLU(dim, mlp_hidden_dim)
+        self.attention_adaln = LowRankAdaLN(dim, adaln_rank, norm_eps)
+        self.mlp_adaln = LowRankAdaLN(dim, adaln_rank, norm_eps)
+
+    def __call__(
+        self,
+        x: mx.array,
+        cond_embed: mx.array,
+        text_mask: mx.array,
+        speaker_mask: mx.array,
+        freqs_cis: RotaryCache,
+        kv_cache_text: KVCache,
+        kv_cache_speaker: KVCache,
+        start_pos: int = 0,
+    ) -> mx.array:
+        # Gated residual attention branch (gate comes from the AdaLN cond).
+        x_norm, attn_gate = self.attention_adaln(x, cond_embed)
+        x = x + attn_gate * self.attention(
+            x_norm,
+            text_mask,
+            speaker_mask,
+            freqs_cis,
+            kv_cache_text,
+            kv_cache_speaker,
+            start_pos,
+        )
+        # Gated residual MLP branch.
+        x_norm, mlp_gate = self.mlp_adaln(x, cond_embed)
+        x = x + mlp_gate * self.mlp(x_norm)
+        return x
+
+
+# ---------------------------------------------------------------------------
+# Main DiT model
+# ---------------------------------------------------------------------------
+
+
+class IrodoriDiT(nn.Module):
+    """
+    Irodori-TTS DiT model (MLX port of TextToLatentRFDiT).
+
+    Input  x_t : (B, S, latent_dim * latent_patch_size)
+    Output v_t : same shape — velocity prediction for Rectified Flow ODE.
+    """
+
+    def __init__(self, cfg: IrodoriDiTConfig):
+        super().__init__()
+        self.cfg = cfg
+        self.head_dim = cfg.model_dim // cfg.num_heads
+
+        self.text_encoder = TextEncoder(
+            vocab_size=cfg.text_vocab_size,
+            dim=cfg.text_dim,
+            heads=cfg.text_heads,
+            num_layers=cfg.text_layers,
+            mlp_ratio=cfg.text_mlp_ratio_resolved,
+            norm_eps=cfg.norm_eps,
+        )
+        # ReferenceLatentEncoder receives patched speaker latents
+        self.speaker_encoder = ReferenceLatentEncoder(
+            in_dim=cfg.speaker_patched_latent_dim,
+            dim=cfg.speaker_dim,
+            heads=cfg.speaker_heads,
+            num_layers=cfg.speaker_layers,
+            mlp_ratio=cfg.speaker_mlp_ratio_resolved,
+            norm_eps=cfg.norm_eps,
+        )
+        self.text_norm = RMSNorm(cfg.text_dim, eps=cfg.norm_eps)
+        self.speaker_norm = RMSNorm(cfg.speaker_dim, eps=cfg.norm_eps)
+
+        # Timestep → conditioning embedding (3 × model_dim for shift/scale/gate)
+        self.cond_module = nn.Sequential(
+            nn.Linear(cfg.timestep_embed_dim, cfg.model_dim, bias=False),
+            nn.SiLU(),
+            nn.Linear(cfg.model_dim, cfg.model_dim, bias=False),
+            nn.SiLU(),
+            nn.Linear(cfg.model_dim, cfg.model_dim * 3, bias=False),
+        )
+
+        self.in_proj = nn.Linear(cfg.patched_latent_dim, cfg.model_dim, bias=True)
+        mlp_hidden = int(cfg.model_dim * cfg.mlp_ratio)
+        self.blocks = [
+            DiffusionBlock(
+                dim=cfg.model_dim,
+                heads=cfg.num_heads,
+                mlp_hidden_dim=mlp_hidden,
+                text_ctx_dim=cfg.text_dim,
+                speaker_ctx_dim=cfg.speaker_dim,
+                adaln_rank=cfg.adaln_rank,
+                norm_eps=cfg.norm_eps,
+            )
+            for _ in range(cfg.num_layers)
+        ]
+        self.out_norm = RMSNorm(cfg.model_dim, eps=cfg.norm_eps)
+        self.out_proj = nn.Linear(cfg.model_dim, cfg.patched_latent_dim, bias=True)
+
+    # ------------------------------------------------------------------
+    # Condition encoding (text + speaker) — can be cached across steps
+    # ------------------------------------------------------------------
+
+    def encode_conditions(
+        self,
+        text_input_ids: mx.array,
+        text_mask: mx.array,
+        ref_latent: mx.array,
+        ref_mask: mx.array,
+    ) -> Tuple[mx.array, mx.array, mx.array, mx.array]:
+        """
+        Encode text and reference latent into conditioning states.
+        Patches the reference latent before encoding.
+        Returns (text_state, text_mask, speaker_state, speaker_mask).
+        """
+        ref_latent, ref_mask = patch_sequence_with_mask(
+            ref_latent, ref_mask, self.cfg.speaker_patch_size
+        )
+        text_state = self.text_norm(self.text_encoder(text_input_ids, text_mask))
+        speaker_state = self.speaker_norm(self.speaker_encoder(ref_latent, ref_mask))
+        return text_state, text_mask, speaker_state, ref_mask
+
+    def build_kv_cache(
+        self,
+        text_state: mx.array,
+        speaker_state: mx.array,
+    ) -> Tuple[List[KVCache], List[KVCache]]:
+        """Pre-compute per-layer text/speaker KV projections for fast sampling."""
+        kv_text = [
+            block.attention.get_kv_cache_text(text_state) for block in self.blocks
+        ]
+        kv_speaker = [
+            block.attention.get_kv_cache_speaker(speaker_state) for block in self.blocks
+        ]
+        return kv_text, kv_speaker
+
+    # ------------------------------------------------------------------
+    # Forward (with pre-encoded conditions)
+    # ------------------------------------------------------------------
+
+    def forward_with_conditions(
+        self,
+        x_t: mx.array,
+        t: mx.array,
+        text_state: mx.array,
+        text_mask: mx.array,
+        speaker_state: mx.array,
+        speaker_mask: mx.array,
+        kv_text: Optional[List[KVCache]] = None,
+        kv_speaker: Optional[List[KVCache]] = None,
+        start_pos: int = 0,
+    ) -> mx.array:
+        t_embed = get_timestep_embedding(t, self.cfg.timestep_embed_dim).astype(
+            x_t.dtype
+        )
+        cond_embed = self.cond_module(t_embed)[:, None, :]  # (B, 1, 3*model_dim)
+
+        x = self.in_proj(x_t)
+        freqs_cis = precompute_freqs_cis(self.head_dim, start_pos + x.shape[1])
+
+        for i, block in enumerate(self.blocks):
+            # Fall back to on-the-fly KV projection when no cache was built.
+            kv_t = (
+                kv_text[i]
+                if kv_text is not None
+                else block.attention.get_kv_cache_text(text_state)
+            )
+            kv_s = (
+                kv_speaker[i]
+                if kv_speaker is not None
+                else block.attention.get_kv_cache_speaker(speaker_state)
+            )
+            x
= block( + x, + cond_embed, + text_mask, + speaker_mask, + freqs_cis, + kv_t, + kv_s, + start_pos, + ) + + x = self.out_norm(x) + return self.out_proj(x).astype(mx.float32) + + # ------------------------------------------------------------------ + # Full forward (encode conditions + denoise) + # ------------------------------------------------------------------ + + def __call__( + self, + x_t: mx.array, + t: mx.array, + text_input_ids: mx.array, + text_mask: mx.array, + ref_latent: mx.array, + ref_mask: mx.array, + ) -> mx.array: + text_state, text_mask, speaker_state, speaker_mask = self.encode_conditions( + text_input_ids, text_mask, ref_latent, ref_mask + ) + return self.forward_with_conditions( + x_t, t, text_state, text_mask, speaker_state, speaker_mask + ) diff --git a/mlx_audio/tts/models/irodori_tts/sampling.py b/mlx_audio/tts/models/irodori_tts/sampling.py new file mode 100644 index 00000000..7834581e --- /dev/null +++ b/mlx_audio/tts/models/irodori_tts/sampling.py @@ -0,0 +1,427 @@ +from __future__ import annotations + +from typing import List, Optional, Tuple + +import mlx.core as mx +import numpy as np + +from .model import IrodoriDiT + +KVCache = List[Tuple[mx.array, mx.array]] + + +# --------------------------------------------------------------------------- +# KV cache helpers +# --------------------------------------------------------------------------- + + +def _concat_kv_caches(*caches: KVCache) -> KVCache: + """Concatenate KV caches from multiple conditions along the batch axis.""" + result: KVCache = [] + for i in range(len(caches[0])): + k = mx.concatenate([c[i][0] for c in caches], axis=0) + v = mx.concatenate([c[i][1] for c in caches], axis=0) + result.append((k, v)) + return result + + +def _scale_kv_cache( + cache: KVCache, + scale: float, + speaker_only: bool = True, + max_layers: Optional[int] = None, +) -> KVCache: + """Return a new KV cache with speaker KVs scaled (immutable, MLX-friendly).""" + n = len(cache) if max_layers is None else 
min(max_layers, len(cache)) + result: KVCache = [] + for i, (k, v) in enumerate(cache): + if i < n: + result.append((k * scale, v * scale)) + else: + result.append((k, v)) + return result + + +# --------------------------------------------------------------------------- +# Score rescaling (optional post-processing of velocity prediction) +# --------------------------------------------------------------------------- + + +def _temporal_score_rescale( + v_pred: mx.array, + x_t: mx.array, + t: float, + rescale_k: float, + rescale_sigma: float, +) -> mx.array: + """Temporal score rescaling from https://arxiv.org/pdf/2510.01184.""" + if t >= 1.0: + return v_pred + one_minus_t = 1.0 - t + snr = (one_minus_t**2) / (t**2) + sigma_sq = rescale_sigma**2 + ratio = (snr * sigma_sq + 1.0) / (snr * sigma_sq / rescale_k + 1.0) + return (ratio * (one_minus_t * v_pred + x_t) - x_t) / one_minus_t + + +# --------------------------------------------------------------------------- +# Main sampler +# --------------------------------------------------------------------------- + + +def sample_euler_cfg( + model: IrodoriDiT, + text_input_ids: mx.array, + text_mask: mx.array, + ref_latent: mx.array, + ref_mask: mx.array, + latent_dim: int, + rng_seed: int = 0, + sequence_length: int = 750, + num_steps: int = 40, + cfg_scale_text: float = 3.0, + cfg_scale_speaker: float = 5.0, + cfg_guidance_mode: str = "independent", + cfg_scale: Optional[float] = None, + cfg_min_t: float = 0.5, + cfg_max_t: float = 1.0, + truncation_factor: Optional[float] = None, + rescale_k: Optional[float] = None, + rescale_sigma: Optional[float] = None, + context_kv_cache: bool = True, + speaker_kv_scale: Optional[float] = None, + speaker_kv_min_t: Optional[float] = None, + speaker_kv_max_layers: Optional[int] = None, + **_ignored, +) -> mx.array: + """ + Euler sampler for Rectified Flow ODE with Classifier-Free Guidance. 
+ + Supports three CFG modes: + independent : text and speaker guidance computed in a single 3x-batch forward pass. + joint : single combined unconditional (both text and speaker zeroed). + alternating : text-uncond and speaker-uncond alternate each step. + + Returns latent of shape (batch, sequence_length, latent_dim). + """ + # Backward-compat: single cfg_scale overrides both + if cfg_scale is not None: + cfg_scale_text = float(cfg_scale) + cfg_scale_speaker = float(cfg_scale) + + cfg_guidance_mode = cfg_guidance_mode.strip().lower() + if cfg_guidance_mode not in {"independent", "joint", "alternating"}: + raise ValueError( + f"Unknown cfg_guidance_mode={cfg_guidance_mode!r}. " + "Expected: independent | joint | alternating" + ) + + batch_size = text_input_ids.shape[0] + has_text_cfg = cfg_scale_text > 0 + has_speaker_cfg = cfg_scale_speaker > 0 + + # ---- encode conditions once ---- + text_state_cond, text_mask_cond, speaker_state_cond, speaker_mask_cond = ( + model.encode_conditions( + text_input_ids=text_input_ids, + text_mask=text_mask, + ref_latent=ref_latent, + ref_mask=ref_mask, + ) + ) + mx.eval(text_state_cond, speaker_state_cond) + + # unconditioned states: zero arrays of same shape + # (TextEncoder/SpeakerEncoder zero-masks any position, so feeding zero mask + # gives zero state; using explicit zeros is equivalent and avoids recomputation) + text_state_uncond = mx.zeros_like(text_state_cond) + text_mask_uncond = mx.zeros_like(text_mask_cond) + speaker_state_uncond = mx.zeros_like(speaker_state_cond) + speaker_mask_uncond = mx.zeros_like(speaker_mask_cond) + + # ---- build KV caches ---- + use_kv_cache = context_kv_cache or (speaker_kv_scale is not None) + + kv_text_cond: Optional[KVCache] = None + kv_speaker_cond: Optional[KVCache] = None + kv_text_cfg: Optional[KVCache] = None + kv_speaker_cfg: Optional[KVCache] = None + # extra caches for joint/alternating + kv_text_uncond_joint: Optional[KVCache] = None + kv_speaker_uncond_joint: Optional[KVCache] = 
None
+    kv_text_uncond_alt: Optional[KVCache] = None
+    kv_speaker_uncond_alt: Optional[KVCache] = None
+
+    if use_kv_cache:
+        kv_text_cond, kv_speaker_cond = model.build_kv_cache(
+            text_state_cond, speaker_state_cond
+        )
+        if speaker_kv_scale is not None:
+            kv_speaker_cond = _scale_kv_cache(
+                kv_speaker_cond, speaker_kv_scale, max_layers=speaker_kv_max_layers
+            )
+
+        if cfg_guidance_mode == "independent":
+            if has_text_cfg and has_speaker_cfg:
+                # batch order: [cond, text-uncond, speaker-uncond]
+                kv_text_cfg = _concat_kv_caches(
+                    kv_text_cond, kv_text_cond, kv_text_cond
+                )
+                kv_speaker_cfg = _concat_kv_caches(
+                    kv_speaker_cond, kv_speaker_cond, kv_speaker_cond
+                )
+            elif has_text_cfg or has_speaker_cfg:
+                # Single-guidance case: 2x batch [cond, uncond]. The uncond
+                # half reuses the cond KVs — its zero mask (built below at the
+                # per-step forward) blanks the corresponding attention anyway,
+                # so the cache contents for that half are irrelevant.
+                kv_text_cfg = _concat_kv_caches(kv_text_cond, kv_text_cond)
+                kv_speaker_cfg = _concat_kv_caches(kv_speaker_cond, kv_speaker_cond)
+
+        elif cfg_guidance_mode == "joint":
+            if has_text_cfg or has_speaker_cfg:
+                kv_text_uncond_joint, kv_speaker_uncond_joint = model.build_kv_cache(
+                    text_state_uncond, speaker_state_uncond
+                )
+
+        elif cfg_guidance_mode == "alternating":
+            if has_text_cfg:
+                # Only the text-side KV differs for the text-uncond pass; the
+                # speaker KV is reused from the conditional cache, so a single
+                # build_kv_cache call suffices here.
+                kv_text_uncond_alt, _ = model.build_kv_cache(
+                    text_state_uncond, speaker_state_cond
+                )
+            if has_speaker_cfg:
+                _, kv_speaker_uncond_alt = model.build_kv_cache(
+                    text_state_cond, speaker_state_uncond
+                )
+                if speaker_kv_scale is not None:
+                    kv_speaker_uncond_alt = _scale_kv_cache(
+                        kv_speaker_uncond_alt,
+                        speaker_kv_scale,
+                        max_layers=speaker_kv_max_layers,
+                    )
+
+    mx.eval(kv_text_cond, kv_speaker_cond)
+
+    # ---- initial noise ----
+    mx.random.seed(rng_seed)
+    init_scale = 0.999
+    x_t = mx.random.normal((batch_size, sequence_length, latent_dim))
+    if truncation_factor is not None:
+        x_t = x_t * 
float(truncation_factor) + + t_schedule = np.linspace(1.0 * init_scale, 0.0, num_steps + 1, dtype=np.float32) + + speaker_kv_active = speaker_kv_scale is not None + + # ---- Euler steps ---- + for i in range(num_steps): + t = float(t_schedule[i]) + t_next = float(t_schedule[i + 1]) + t_arr = mx.full((batch_size,), t, dtype=mx.float32) + use_cfg = (has_text_cfg or has_speaker_cfg) and (cfg_min_t <= t <= cfg_max_t) + + if use_cfg: + if cfg_guidance_mode == "independent": + if has_text_cfg and has_speaker_cfg: + # 3x batch: [cond, text-uncond, speaker-uncond] + x_cfg = mx.concatenate([x_t, x_t, x_t], axis=0) + t_cfg = mx.full((batch_size * 3,), t, dtype=mx.float32) + text_mask_cfg = mx.concatenate( + [text_mask_cond, text_mask_uncond, text_mask_cond], axis=0 + ) + speaker_mask_cfg = mx.concatenate( + [speaker_mask_cond, speaker_mask_cond, speaker_mask_uncond], + axis=0, + ) + v_out = model.forward_with_conditions( + x_t=x_cfg, + t=t_cfg, + text_state=mx.concatenate( + [text_state_cond, text_state_uncond, text_state_cond], + axis=0, + ), + text_mask=text_mask_cfg, + speaker_state=mx.concatenate( + [ + speaker_state_cond, + speaker_state_cond, + speaker_state_uncond, + ], + axis=0, + ), + speaker_mask=speaker_mask_cfg, + kv_text=kv_text_cfg, + kv_speaker=kv_speaker_cfg, + ) + v_cond, v_uncond_text, v_uncond_speaker = mx.split(v_out, 3, axis=0) + v_pred = ( + v_cond + + cfg_scale_text * (v_cond - v_uncond_text) + + cfg_scale_speaker * (v_cond - v_uncond_speaker) + ) + + elif has_text_cfg: + x_cfg = mx.concatenate([x_t, x_t], axis=0) + t_cfg = mx.full((batch_size * 2,), t, dtype=mx.float32) + v_out = model.forward_with_conditions( + x_t=x_cfg, + t=t_cfg, + text_state=mx.concatenate( + [text_state_cond, text_state_uncond], axis=0 + ), + text_mask=mx.concatenate( + [text_mask_cond, text_mask_uncond], axis=0 + ), + speaker_state=mx.concatenate( + [speaker_state_cond, speaker_state_cond], axis=0 + ), + speaker_mask=mx.concatenate( + [speaker_mask_cond, speaker_mask_cond], 
axis=0 + ), + kv_text=kv_text_cfg, + kv_speaker=kv_speaker_cfg, + ) + v_cond, v_uncond_text = mx.split(v_out, 2, axis=0) + v_pred = v_cond + cfg_scale_text * (v_cond - v_uncond_text) + + else: # has_speaker_cfg only + x_cfg = mx.concatenate([x_t, x_t], axis=0) + t_cfg = mx.full((batch_size * 2,), t, dtype=mx.float32) + v_out = model.forward_with_conditions( + x_t=x_cfg, + t=t_cfg, + text_state=mx.concatenate( + [text_state_cond, text_state_cond], axis=0 + ), + text_mask=mx.concatenate( + [text_mask_cond, text_mask_cond], axis=0 + ), + speaker_state=mx.concatenate( + [speaker_state_cond, speaker_state_uncond], axis=0 + ), + speaker_mask=mx.concatenate( + [speaker_mask_cond, speaker_mask_uncond], axis=0 + ), + kv_text=kv_text_cfg, + kv_speaker=kv_speaker_cfg, + ) + v_cond, v_uncond_speaker = mx.split(v_out, 2, axis=0) + v_pred = v_cond + cfg_scale_speaker * (v_cond - v_uncond_speaker) + + elif cfg_guidance_mode == "joint": + if has_text_cfg and has_speaker_cfg: + if abs(cfg_scale_text - cfg_scale_speaker) > 1e-6: + raise ValueError( + "cfg_guidance_mode='joint' requires equal text/speaker scales. " + "Use cfg_scale or set both to the same value." 
+ ) + joint_scale = cfg_scale_text + else: + joint_scale = cfg_scale_text if has_text_cfg else cfg_scale_speaker + + v_cond = model.forward_with_conditions( + x_t=x_t, + t=t_arr, + text_state=text_state_cond, + text_mask=text_mask_cond, + speaker_state=speaker_state_cond, + speaker_mask=speaker_mask_cond, + kv_text=kv_text_cond, + kv_speaker=kv_speaker_cond, + ) + v_uncond = model.forward_with_conditions( + x_t=x_t, + t=t_arr, + text_state=text_state_uncond, + text_mask=text_mask_uncond, + speaker_state=speaker_state_uncond, + speaker_mask=speaker_mask_uncond, + kv_text=kv_text_uncond_joint, + kv_speaker=kv_speaker_uncond_joint, + ) + v_pred = v_cond + joint_scale * (v_cond - v_uncond) + + else: # alternating + v_cond = model.forward_with_conditions( + x_t=x_t, + t=t_arr, + text_state=text_state_cond, + text_mask=text_mask_cond, + speaker_state=speaker_state_cond, + speaker_mask=speaker_mask_cond, + kv_text=kv_text_cond, + kv_speaker=kv_speaker_cond, + ) + use_text_uncond = (has_text_cfg and has_speaker_cfg and i % 2 == 0) or ( + has_text_cfg and not has_speaker_cfg + ) + if use_text_uncond: + v_uncond = model.forward_with_conditions( + x_t=x_t, + t=t_arr, + text_state=text_state_uncond, + text_mask=text_mask_uncond, + speaker_state=speaker_state_cond, + speaker_mask=speaker_mask_cond, + kv_text=kv_text_uncond_alt, + kv_speaker=kv_speaker_cond, + ) + v_pred = v_cond + cfg_scale_text * (v_cond - v_uncond) + else: + v_uncond = model.forward_with_conditions( + x_t=x_t, + t=t_arr, + text_state=text_state_cond, + text_mask=text_mask_cond, + speaker_state=speaker_state_uncond, + speaker_mask=speaker_mask_uncond, + kv_text=kv_text_cond, + kv_speaker=kv_speaker_uncond_alt, + ) + v_pred = v_cond + cfg_scale_speaker * (v_cond - v_uncond) + + else: + # no CFG this step + v_pred = model.forward_with_conditions( + x_t=x_t, + t=t_arr, + text_state=text_state_cond, + text_mask=text_mask_cond, + speaker_state=speaker_state_cond, + speaker_mask=speaker_mask_cond, + 
kv_text=kv_text_cond,
+                kv_speaker=kv_speaker_cond,
+            )
+
+        # optional temporal score rescaling
+        if rescale_k is not None and rescale_sigma is not None:
+            v_pred = _temporal_score_rescale(v_pred, x_t, t, rescale_k, rescale_sigma)
+
+        # speaker KV scale rollback at threshold
+        if (
+            speaker_kv_active
+            and speaker_kv_min_t is not None
+            and t_next < speaker_kv_min_t <= t
+        ):
+            inv = 1.0 / speaker_kv_scale
+            kv_speaker_cond = _scale_kv_cache(
+                kv_speaker_cond, inv, max_layers=speaker_kv_max_layers
+            )
+            if kv_speaker_cfg is not None:
+                # Un-scale the concatenated CFG cache directly instead of
+                # rebuilding it with a fixed 3-way concat: in independent mode
+                # with only one guidance term active the cache holds a 2x
+                # batch, so a hard-coded 3-way rebuild would produce a batch
+                # size mismatch on the following steps. Per-layer scaling by
+                # the inverse is equivalent in the 3x case and correct in 2x.
+                kv_speaker_cfg = _scale_kv_cache(
+                    kv_speaker_cfg, inv, max_layers=speaker_kv_max_layers
+                )
+            if kv_speaker_uncond_alt is not None:
+                kv_speaker_uncond_alt = _scale_kv_cache(
+                    kv_speaker_uncond_alt, inv, max_layers=speaker_kv_max_layers
+                )
+            speaker_kv_active = False
+
+        # Euler update: x_{t-dt} = x_t + v * (t_next - t)
+        x_t = x_t + v_pred * (t_next - t)
+        mx.eval(x_t)
+
+    return x_t
diff --git a/mlx_audio/tts/models/irodori_tts/text.py b/mlx_audio/tts/models/irodori_tts/text.py
new file mode 100644
index 00000000..a82c3453
--- /dev/null
+++ b/mlx_audio/tts/models/irodori_tts/text.py
@@ -0,0 +1,142 @@
+from __future__ import annotations
+
+import re
+from typing import Optional
+
+import mlx.core as mx
+import numpy as np
+
+# ---------------------------------------------------------------------------
+# Japanese text normalisation
+# Ported from Irodori-TTS/irodori_tts/text_normalization.py (pure Python).
+# --------------------------------------------------------------------------- + +_REPLACE_MAP: dict[str, str] = { + r"\t": "", + r"\[n\]": "", + r" ": "", # narrow no-break space (U+202F) / ideographic space handled below + r" ": "", # ideographic space + r"[;▼♀♂《》≪≫①②③④⑤⑥]": "", + r"[\u02d7\u2010-\u2015\u2043\u2212\u23af\u23e4\u2500\u2501\u2e3a\u2e3b]": "", + r"[\uff5e\u301C]": "ー", + r"?": "?", + r"!": "!", + r"[●◯〇]": "○", + r"♥": "♡", +} + +# Fullwidth A-Z a-z → halfwidth +_FULLWIDTH_ALPHA_TO_HALFWIDTH = str.maketrans( + { + chr(full): chr(half) + for full, half in zip( + list(range(0xFF21, 0xFF3B)) + list(range(0xFF41, 0xFF5B)), + list(range(0x41, 0x5B)) + list(range(0x61, 0x7B)), + ) + } +) + +# Fullwidth 0-9 → halfwidth +_FULLWIDTH_DIGITS_TO_HALFWIDTH = str.maketrans( + { + chr(full): chr(half) + for full, half in zip(range(0xFF10, 0xFF1A), range(0x30, 0x3A)) + } +) + +# Halfwidth katakana → fullwidth katakana +_HW_KANA = "ヲァィゥェォャュョッーアイウエオカキクケコサシスセソタチツテトナニヌネノハヒフヘホマミムメモヤユヨラリルレロワン" +_FW_KANA = "ヲァィゥェォャュョッーアイウエオカキクケコサシスセソタチツテトナニヌネノハヒフヘホマミムメモヤユヨラリルレロワン" +_HALFWIDTH_KANA_TO_FULLWIDTH = str.maketrans(_HW_KANA, _FW_KANA) + + +def normalize_text(text: str) -> str: + """ + Normalise Japanese text for TTS input. + - Removes noise characters (tabs, special symbols, etc.) 
+ - Converts fullwidth alphanumerics and digits to halfwidth + - Converts halfwidth katakana to fullwidth + - Strips surrounding brackets and trailing punctuation + """ + for pattern, replacement in _REPLACE_MAP.items(): + text = re.sub(pattern, replacement, text) + + text = text.translate(_FULLWIDTH_ALPHA_TO_HALFWIDTH) + text = text.translate(_FULLWIDTH_DIGITS_TO_HALFWIDTH) + text = text.translate(_HALFWIDTH_KANA_TO_FULLWIDTH) + + # Collapse runs of 3+ ellipses to double + text = re.sub(r"…{3,}", "……", text) + + # Strip surrounding bracket pairs + for open_br, close_br in [ + ("「", "」"), + ("『", "』"), + ("(", ")"), + ("【", "】"), + ("(", ")"), + ]: + if text.startswith(open_br) and text.endswith(close_br): + text = text[1:-1] + + # Strip trailing Japanese sentence-ending punctuation + if text.endswith(("。", "、")): + text = text.rstrip("。、") + + return text + + +# --------------------------------------------------------------------------- +# Tokenisation +# --------------------------------------------------------------------------- + + +def encode_text( + text: str, + tokenizer, + max_length: int, + add_bos: bool = True, +) -> tuple[mx.array, mx.array]: + """ + Tokenise a single text string using a HuggingFace tokenizer. + + Matches Irodori's PretrainedTextTokenizer behaviour: + - special tokens are NOT added by the HF tokenizer + - BOS is prepended manually when add_bos=True + - right-padding to max_length with pad_token_id + + Returns + ------- + input_ids : mx.array shape (1, max_length) int32 + mask : mx.array shape (1, max_length) bool + """ + # Ensure right-padding (tokenizer default may differ) + tokenizer.padding_side = "right" + if tokenizer.pad_token_id is None: + if tokenizer.eos_token_id is not None: + tokenizer.pad_token = tokenizer.eos_token + else: + raise ValueError( + "Tokenizer has no pad_token_id. Set a pad token before inference." 
+ ) + + token_ids: list[int] = tokenizer.encode(text, add_special_tokens=False) + + if add_bos: + if tokenizer.bos_token_id is None: + raise ValueError("Tokenizer has no bos_token_id but add_bos=True.") + token_ids.insert(0, int(tokenizer.bos_token_id)) + + # Truncate + token_ids = token_ids[:max_length] + n = len(token_ids) + + # Pad + pad_id = int(tokenizer.pad_token_id) + padded = token_ids + [pad_id] * (max_length - n) + + ids_np = np.array([padded], dtype=np.int32) + mask_np = np.zeros((1, max_length), dtype=bool) + mask_np[0, :n] = True + + return mx.array(ids_np), mx.array(mask_np) diff --git a/mlx_audio/tts/tests/test_models.py b/mlx_audio/tts/tests/test_models.py index 1302e1d1..da07b076 100644 --- a/mlx_audio/tts/tests/test_models.py +++ b/mlx_audio/tts/tests/test_models.py @@ -3562,5 +3562,332 @@ def test_config_from_dict_handles_upstream_nested_shape(self): self.assertEqual(config.semantic_start_token_id, 1000) +# --------------------------------------------------------------------------- +# Irodori-TTS helpers +# --------------------------------------------------------------------------- + + +class _MockTokenizer: + """Minimal HuggingFace-style tokenizer stub for Irodori-TTS tests.""" + + bos_token_id = 1 + eos_token_id = 2 + pad_token_id = 0 + padding_side = "right" + + def encode(self, text, add_special_tokens=False): + return [3, 4, 5] + + +class _FakeDACVAE: + """DACVAE stub that matches the real API shapes.""" + + def __init__(self, latent_dim: int = 8, downsample_factor: int = 1920): + self.latent_dim = latent_dim + self.downsample_factor = downsample_factor + + def encode(self, audio_in: mx.array) -> mx.array: + B = audio_in.shape[0] + T = max(1, int(audio_in.shape[1]) // self.downsample_factor) + return mx.zeros((B, self.latent_dim, T), dtype=mx.float32) + + def decode(self, latent: mx.array) -> mx.array: + B, _D, T = latent.shape + return mx.zeros((B, T * self.downsample_factor, 1), dtype=mx.float32) + + +def 
_small_irodori_dit_config(**overrides): + from mlx_audio.tts.models.irodori_tts.config import IrodoriDiTConfig + + defaults = dict( + latent_dim=8, + latent_patch_size=1, + model_dim=32, + num_layers=2, + num_heads=4, + mlp_ratio=2.0, + text_mlp_ratio=2.0, + speaker_mlp_ratio=2.0, + text_vocab_size=64, + text_dim=32, + text_layers=1, + text_heads=4, + speaker_dim=32, + speaker_layers=1, + speaker_heads=4, + speaker_patch_size=1, + timestep_embed_dim=16, + adaln_rank=8, + norm_eps=1e-5, + ) + defaults.update(overrides) + return IrodoriDiTConfig(**defaults) + + +def _small_irodori_model_config(**sampler_overrides): + from mlx_audio.tts.models.irodori_tts.config import ModelConfig, SamplerConfig + + sampler_defaults = dict( + num_steps=1, + cfg_scale_text=1.0, + cfg_scale_speaker=1.0, + sequence_length=4, + ) + sampler_defaults.update(sampler_overrides) + return ModelConfig( + dit=_small_irodori_dit_config(), + sampler=SamplerConfig(**sampler_defaults), + ) + + +# --------------------------------------------------------------------------- +# Irodori-TTS test classes +# --------------------------------------------------------------------------- + + +class TestIrodoriNormalizeText(unittest.TestCase): + def test_fullwidth_alpha_to_halfwidth(self): + from mlx_audio.tts.models.irodori_tts.text import normalize_text + + self.assertEqual(normalize_text("Ab"), "Ab") + + def test_fullwidth_digits_to_halfwidth(self): + from mlx_audio.tts.models.irodori_tts.text import normalize_text + + self.assertEqual(normalize_text("123"), "123") + + def test_halfwidth_kana_to_fullwidth(self): + from mlx_audio.tts.models.irodori_tts.text import normalize_text + + self.assertEqual(normalize_text("アイ"), "アイ") + + def test_wave_dash_to_katakana_dash(self): + from mlx_audio.tts.models.irodori_tts.text import normalize_text + + self.assertEqual(normalize_text("ー〜ー"), "ーーー") + + def test_trailing_kuten_stripped(self): + from mlx_audio.tts.models.irodori_tts.text import normalize_text + + result = 
normalize_text("こんにちは。") + self.assertFalse(result.endswith("。")) + self.assertEqual(result, "こんにちは") + + def test_surrounding_brackets_stripped(self): + from mlx_audio.tts.models.irodori_tts.text import normalize_text + + self.assertEqual(normalize_text("「こんにちは」"), "こんにちは") + + def test_no_change_for_plain_text(self): + from mlx_audio.tts.models.irodori_tts.text import normalize_text + + text = "こんにちは" + self.assertEqual(normalize_text(text), text) + + +class TestIrodoriEncodeText(unittest.TestCase): + def setUp(self): + self.tok = _MockTokenizer() + + def test_output_shapes(self): + from mlx_audio.tts.models.irodori_tts.text import encode_text + + ids, mask = encode_text("hello", self.tok, max_length=10, add_bos=True) + self.assertEqual(tuple(ids.shape), (1, 10)) + self.assertEqual(tuple(mask.shape), (1, 10)) + + def test_bos_prepended(self): + from mlx_audio.tts.models.irodori_tts.text import encode_text + + ids, mask = encode_text("hello", self.tok, max_length=10, add_bos=True) + self.assertEqual(int(ids[0, 0]), self.tok.bos_token_id) + + def test_no_bos(self): + from mlx_audio.tts.models.irodori_tts.text import encode_text + + ids, _ = encode_text("hello", self.tok, max_length=10, add_bos=False) + self.assertEqual(int(ids[0, 0]), 3) + + def test_padding(self): + from mlx_audio.tts.models.irodori_tts.text import encode_text + + ids, mask = encode_text("hello", self.tok, max_length=10, add_bos=True) + for i in range(4, 10): + self.assertEqual(int(ids[0, i]), self.tok.pad_token_id) + self.assertFalse(bool(mask[0, i])) + + def test_mask_true_for_real_tokens(self): + from mlx_audio.tts.models.irodori_tts.text import encode_text + + ids, mask = encode_text("hello", self.tok, max_length=10, add_bos=True) + for i in range(4): + self.assertTrue(bool(mask[0, i])) + + def test_truncation(self): + from mlx_audio.tts.models.irodori_tts.text import encode_text + + ids, mask = encode_text("hello", self.tok, max_length=2, add_bos=True) + self.assertEqual(tuple(ids.shape), (1, 
2)) + + +class TestIrodoriDiTShapes(unittest.TestCase): + def setUp(self): + from mlx_audio.tts.models.irodori_tts.model import IrodoriDiT + + self.cfg = _small_irodori_dit_config() + self.model = IrodoriDiT(self.cfg) + + def test_full_forward_shape(self): + B, S = 1, 6 + x_t = mx.random.normal((B, S, self.cfg.patched_latent_dim)) + t = mx.array([0.5], dtype=mx.float32) + text_ids = mx.zeros((B, 5), dtype=mx.int32) + text_mask = mx.ones((B, 5), dtype=mx.bool_) + ref_latent = mx.random.normal((B, 8, self.cfg.latent_dim)) + ref_mask = mx.ones((B, 8), dtype=mx.bool_) + + out = self.model(x_t, t, text_ids, text_mask, ref_latent, ref_mask) + mx.eval(out) + self.assertEqual(tuple(out.shape), (B, S, self.cfg.patched_latent_dim)) + + def test_encode_conditions_shapes(self): + B = 1 + text_ids = mx.zeros((B, 5), dtype=mx.int32) + text_mask = mx.ones((B, 5), dtype=mx.bool_) + ref_latent = mx.random.normal((B, 8, self.cfg.latent_dim)) + ref_mask = mx.ones((B, 8), dtype=mx.bool_) + + t_state, t_mask, s_state, s_mask = self.model.encode_conditions( + text_ids, text_mask, ref_latent, ref_mask + ) + mx.eval(t_state, s_state) + self.assertEqual(tuple(t_state.shape), (B, 5, self.cfg.text_dim)) + self.assertEqual(int(s_state.shape[0]), B) + self.assertEqual(int(s_state.shape[-1]), self.cfg.speaker_dim) + + def test_kv_cache_and_forward_with_conditions(self): + B, S = 1, 4 + text_ids = mx.zeros((B, 5), dtype=mx.int32) + text_mask = mx.ones((B, 5), dtype=mx.bool_) + ref_latent = mx.zeros((B, 8, self.cfg.latent_dim)) + ref_mask = mx.ones((B, 8), dtype=mx.bool_) + + t_state, t_mask, s_state, s_mask = self.model.encode_conditions( + text_ids, text_mask, ref_latent, ref_mask + ) + kv_text, kv_speaker = self.model.build_kv_cache(t_state, s_state) + self.assertEqual(len(kv_text), self.cfg.num_layers) + self.assertEqual(len(kv_speaker), self.cfg.num_layers) + + x_t = mx.random.normal((B, S, self.cfg.patched_latent_dim)) + t = mx.array([0.3], dtype=mx.float32) + out = 
self.model.forward_with_conditions( + x_t, t, t_state, t_mask, s_state, s_mask, kv_text, kv_speaker + ) + mx.eval(out) + self.assertEqual(tuple(out.shape), (B, S, self.cfg.patched_latent_dim)) + + def test_zero_speaker_latent(self): + B, S = 1, 4 + x_t = mx.random.normal((B, S, self.cfg.patched_latent_dim)) + t = mx.array([1.0], dtype=mx.float32) + text_ids = mx.zeros((B, 5), dtype=mx.int32) + text_mask = mx.ones((B, 5), dtype=mx.bool_) + ref_latent = mx.zeros((B, 1, self.cfg.latent_dim)) + ref_mask = mx.zeros((B, 1), dtype=mx.bool_) + + out = self.model(x_t, t, text_ids, text_mask, ref_latent, ref_mask) + mx.eval(out) + self.assertEqual(tuple(out.shape), (B, S, self.cfg.patched_latent_dim)) + + +class TestIrodoriModelSanitize(unittest.TestCase): + def setUp(self): + from mlx_audio.tts.models.irodori_tts.irodori_tts import Model + + self.model = Model(_small_irodori_model_config()) + + def test_cond_module_key_remapped(self): + weights = {"cond_module.0.weight": mx.zeros((1, 1), dtype=mx.float32)} + sanitized = self.model.sanitize(weights) + self.assertIn("model.cond_module.layers.0.weight", sanitized) + self.assertNotIn("cond_module.0.weight", sanitized) + + def test_model_prefix_added(self): + weights = {"blocks.0.mlp.w1.weight": mx.zeros((1, 1), dtype=mx.float32)} + sanitized = self.model.sanitize(weights) + self.assertIn("model.blocks.0.mlp.w1.weight", sanitized) + + def test_model_prefix_not_doubled(self): + weights = {"model.out_proj.weight": mx.zeros((1, 1), dtype=mx.float32)} + sanitized = self.model.sanitize(weights) + self.assertIn("model.out_proj.weight", sanitized) + self.assertNotIn("model.model.out_proj.weight", sanitized) + + def test_deep_cond_module_key(self): + weights = {"cond_module.2.bias": mx.zeros((1,), dtype=mx.float32)} + sanitized = self.model.sanitize(weights) + self.assertIn("model.cond_module.layers.2.bias", sanitized) + + +class TestIrodoriGenerateSmoke(unittest.TestCase): + def _make_model(self): + from 
mlx_audio.tts.models.irodori_tts.irodori_tts import Model + + cfg = _small_irodori_model_config() + model = Model(cfg) + model.dacvae = _FakeDACVAE( + latent_dim=cfg.dit.latent_dim, + downsample_factor=cfg.audio_downsample_factor, + ) + model._tokenizer = _MockTokenizer() + return model + + def test_generate_returns_result(self): + model = self._make_model() + results = list(model.generate("こんにちは", rng_seed=0)) + self.assertEqual(len(results), 1) + self.assertEqual(results[0].sample_rate, 48000) + self.assertGreater(results[0].samples, 0) + + def test_generate_with_ref_audio(self): + from mlx_audio.tts.models.irodori_tts.irodori_tts import Model + + cfg = _small_irodori_model_config() + model = Model(cfg) + model.dacvae = _FakeDACVAE( + latent_dim=cfg.dit.latent_dim, + downsample_factor=cfg.audio_downsample_factor, + ) + model._tokenizer = _MockTokenizer() + ref = mx.zeros((1, cfg.audio_downsample_factor * 4), dtype=mx.float32) + results = list(model.generate("テスト", ref_audio=ref, rng_seed=1)) + self.assertEqual(len(results), 1) + self.assertGreater(results[0].samples, 0) + + def test_generate_stream_raises(self): + model = self._make_model() + with self.assertRaises(NotImplementedError): + next(model.generate("hi", stream=True)) + + def test_generate_without_dacvae_raises(self): + from mlx_audio.tts.models.irodori_tts.irodori_tts import Model + + cfg = _small_irodori_model_config() + model = Model(cfg) + model._tokenizer = _MockTokenizer() + with self.assertRaises(ValueError): + next(model.generate("hi")) + + def test_result_fields(self): + model = self._make_model() + result = next(model.generate("テスト", rng_seed=0)) + self.assertIsNotNone(result.audio) + self.assertIsInstance(result.token_count, int) + self.assertGreater(result.token_count, 0) + self.assertIsNotNone(result.audio_duration) + self.assertGreater(result.real_time_factor, 0.0) + + if __name__ == "__main__": unittest.main() diff --git a/mlx_audio/tts/utils.py b/mlx_audio/tts/utils.py index 
e302c65f..6e9eae41 100644 --- a/mlx_audio/tts/utils.py +++ b/mlx_audio/tts/utils.py @@ -31,6 +31,7 @@ "kitten": "kitten_tts", "echo_tts": "echo_tts", "fish_qwen3_omni": "fish_qwen3_omni", + "irodori_tts": "irodori_tts", } MAX_FILE_SIZE_GB = 5 MODEL_CONVERSION_DTYPES = ["float16", "bfloat16", "float32"]