diff --git a/mlx_audio/codec/models/encodec/__init__.py b/mlx_audio/codec/models/encodec/__init__.py index 3e6ccc2d..5550c6bd 100644 --- a/mlx_audio/codec/models/encodec/__init__.py +++ b/mlx_audio/codec/models/encodec/__init__.py @@ -1 +1 @@ -from .encodec import Encodec +from .encodec import Encodec, EncodecConfig diff --git a/mlx_audio/codec/models/encodec/encodec.py b/mlx_audio/codec/models/encodec/encodec.py index c8747d35..78522f33 100644 --- a/mlx_audio/codec/models/encodec/encodec.py +++ b/mlx_audio/codec/models/encodec/encodec.py @@ -1,6 +1,7 @@ import functools import json import math +from dataclasses import dataclass from pathlib import Path from types import SimpleNamespace from typing import List, Optional, Tuple, Union @@ -11,6 +12,40 @@ from huggingface_hub import snapshot_download +def filter_dataclass_fields(data_dict, dataclass_type): + """Filter a dictionary to only include keys that are fields in the dataclass.""" + valid_fields = {f.name for f in dataclass_type.__dataclass_fields__.values()} + return {k: v for k, v in data_dict.items() if k in valid_fields} + + +@dataclass +class EncodecConfig: + model_type: str = "encodec" + audio_channels: int = 1 + num_filters: int = 32 + kernel_size: int = 7 + num_residual_layers: int = 1 + dilation_growth_rate: int = 2 + codebook_size: int = 1024 + codebook_dim: int = 128 + hidden_size: int = 128 + num_lstm_layers: int = 2 + residual_kernel_size: int = 3 + use_causal_conv: bool = True + normalize: bool = False + pad_mode: str = "reflect" + norm_type: str = "weight_norm" + last_kernel_size: int = 7 + trim_right_ratio: float = 1.0 + compress: int = 2 + upsampling_ratios: List[int] = None + target_bandwidths: List[float] = None + sampling_rate: int = 24000 + chunk_length_s: Optional[float] = None + overlap: Optional[float] = None + architectures: List[str] = None + + def preprocess_audio( raw_audio: Union[mx.array, List[mx.array]], sampling_rate: int = 24000, @@ -513,7 +548,7 @@ def decode(self, codes: 
mx.array) -> mx.array: class Encodec(nn.Module): def __init__(self, config): super().__init__() - self.config = SimpleNamespace(**config) + self.config = config self.encoder = EncodecEncoder(self.config) self.decoder = EncodecDecoder(self.config) self.quantizer = EncodecResidualVectorQuantizer(self.config) @@ -689,6 +724,8 @@ def from_pretrained(cls, path_or_repo: str): with open(path / "config.json", "r") as f: config = json.load(f) + filtered_config = filter_dataclass_fields(config, EncodecConfig) + config = EncodecConfig(**filtered_config) model = cls(config) model.load_weights(str(path / "model.safetensors")) processor = functools.partial( diff --git a/mlx_audio/codec/tests/test_encodec.py b/mlx_audio/codec/tests/test_encodec.py index 83286b57..c91729f2 100644 --- a/mlx_audio/codec/tests/test_encodec.py +++ b/mlx_audio/codec/tests/test_encodec.py @@ -2,33 +2,33 @@ import mlx.core as mx -from ..models.encodec import Encodec - -config = { - "audio_channels": 1, - "chunk_length_s": None, - "codebook_dim": 128, - "codebook_size": 1024, - "compress": 2, - "dilation_growth_rate": 2, - "hidden_size": 128, - "kernel_size": 7, - "last_kernel_size": 7, - "model_type": "encodec", - "norm_type": "weight_norm", - "normalize": False, - "num_filters": 32, - "num_lstm_layers": 2, - "num_residual_layers": 1, - "overlap": None, - "pad_mode": "reflect", - "residual_kernel_size": 3, - "sampling_rate": 24000, - "target_bandwidths": [1.5, 3.0, 6.0, 12.0, 24.0], - "trim_right_ratio": 1.0, - "upsampling_ratios": [8, 5, 4, 2], - "use_causal_conv": True, -} +from ..models.encodec import Encodec, EncodecConfig + +config = EncodecConfig( + audio_channels=1, + chunk_length_s=None, + codebook_dim=128, + codebook_size=1024, + compress=2, + dilation_growth_rate=2, + hidden_size=128, + kernel_size=7, + last_kernel_size=7, + model_type="encodec", + norm_type="weight_norm", + normalize=False, + num_filters=32, + num_lstm_layers=2, + num_residual_layers=1, + overlap=None, + pad_mode="reflect", + 
residual_kernel_size=3, + sampling_rate=24000, + target_bandwidths=[1.5, 3.0, 6.0, 12.0, 24.0], + trim_right_ratio=1.0, + upsampling_ratios=[8, 5, 4, 2], + use_causal_conv=True, +) class TesEncodec(unittest.TestCase): diff --git a/mlx_audio/tts/generate.py b/mlx_audio/tts/generate.py index b61b51ad..5ad49b89 100644 --- a/mlx_audio/tts/generate.py +++ b/mlx_audio/tts/generate.py @@ -24,7 +24,8 @@ def generate_audio( join_audio: bool = False, play: bool = False, verbose: bool = True, - from_cli: bool = False, + temperature: float = 0.7, + **kwargs, ) -> None: """ Generates audio from text using a specified TTS model. @@ -85,6 +86,7 @@ def generate_audio( lang_code=lang_code, ref_audio=ref_audio, ref_text=ref_text, + temperature=temperature, verbose=True, ) @@ -154,7 +156,7 @@ def parse_args(): default=None, help="Text to generate (leave blank to input via stdin)", ) - parser.add_argument("--voice", type=str, default="af_heart", help="Voice name") + parser.add_argument("--voice", type=str, default=None, help="Voice name") parser.add_argument("--speed", type=float, default=1.0, help="Speed of the audio") parser.add_argument("--lang_code", type=str, default="a", help="Language code") parser.add_argument( @@ -177,6 +179,9 @@ def parse_args(): parser.add_argument( "--ref_text", type=str, default=None, help="Caption for reference audio" ) + parser.add_argument( + "--temperature", type=float, default=0.7, help="Temperature for the model" + ) args = parser.parse_args() diff --git a/mlx_audio/tts/models/bark/__init__.py b/mlx_audio/tts/models/bark/__init__.py new file mode 100644 index 00000000..46d4a37d --- /dev/null +++ b/mlx_audio/tts/models/bark/__init__.py @@ -0,0 +1,4 @@ +from .bark import Model, ModelConfig +from .pipeline import Pipeline + +__all__ = ["Model", "Pipeline", "ModelConfig"] diff --git a/mlx_audio/tts/models/bark/bark.py b/mlx_audio/tts/models/bark/bark.py new file mode 100644 index 00000000..00b3632e --- /dev/null +++ b/mlx_audio/tts/models/bark/bark.py 
@@ -0,0 +1,519 @@ +import argparse +import glob +import math +import time +from dataclasses import dataclass +from enum import Enum +from typing import Dict, Optional, Tuple, Union + +import mlx.core as mx +import mlx.nn as nn +import numpy as np +import tqdm +from mlx.utils import tree_map, tree_unflatten +from mlx_lm.models.base import create_causal_mask +from scipy.io.wavfile import write as write_wav +from transformers import BertTokenizer + +from ..base import BaseModelArgs, GenerationResult +from .pipeline import Pipeline + +mx.random.seed(42) + +TEXT_ENCODING_OFFSET = 10_048 +SEMANTIC_PAD_TOKEN = 10_000 +TEXT_PAD_TOKEN = 129595 +SEMANTIC_INFER_TOKEN = 129_599 + +CONTEXT_WINDOW_SIZE = 1024 + +SEMANTIC_RATE_HZ = 49.9 +SEMANTIC_VOCAB_SIZE = 10_000 + +CODEBOOK_SIZE = 1024 +N_COARSE_CODEBOOKS = 2 +N_FINE_CODEBOOKS = 8 +COARSE_RATE_HZ = 75 +COARSE_SEMANTIC_PAD_TOKEN = 12_048 +COARSE_INFER_TOKEN = 12_050 +SAMPLE_RATE = 24_000 + + +def filter_dataclass_fields(data_dict, dataclass_type): + """Filter a dictionary to only include keys that are fields in the dataclass.""" + valid_fields = {f.name for f in dataclass_type.__dataclass_fields__.values()} + return {k: v for k, v in data_dict.items() if k in valid_fields} + + +@dataclass +class SemanticConfig(BaseModelArgs): + bad_words_ids: list[list[int]] = None + block_size: int = 1024 + input_vocab_size: int = 129600 + output_vocab_size: int = 129600 + n_layer: int = 12 + n_head: int = 12 + n_embd: int = 768 + bias: bool = False + model_type: str = "semantic" + dropout: float = 0.0 + architectures: list[str] = None + + +@dataclass +class CoarseAcousticsConfig(BaseModelArgs): + block_size: int = 1024 + input_vocab_size: int = 12096 + output_vocab_size: int = 12096 + n_layer: int = 12 + n_head: int = 12 + n_embd: int = 768 + bias: bool = False + model_type: str = "coarse_acoustics" + dropout: float = 0.0 + + +@dataclass +class FineAcousticsConfig(BaseModelArgs): + block_size: int = 1024 + input_vocab_size: int = 1056 + 
output_vocab_size: int = 1056 + n_layer: int = 12 + n_head: int = 12 + n_embd: int = 768 + bias: bool = False + model_type: str = "fine_acoustics" + n_codes_total: int = 8 + n_codes_given: int = 1 + dropout: float = 0.0 + + +@dataclass +class CodecConfig(BaseModelArgs): + model_type: str = "codec" + sample_rate: int = 24000 + target_bandwidth: float = 6.0 + + +@dataclass +class ModelConfig(BaseModelArgs): + semantic_config: SemanticConfig + coarse_acoustics_config: CoarseAcousticsConfig + fine_acoustics_config: FineAcousticsConfig + codec_config: CodecConfig + block_size: int = 1024 + input_vocab_size: int = 10_048 + output_vocab_size: int = 10_048 + n_layer: int = 12 + n_head: int = 12 + n_embd: int = 768 + dropout: float = 0.0 + bias: bool = False + n_codes_total: Optional[int] = None + n_codes_given: Optional[int] = None + model_size: str = "base" + model_type: str = "bark" + initializer_range: float = 0.02 + codec_path: str = "mlx-community/encodec-24khz-float32" + + +class LayerNorm(nn.Module): + def __init__(self, dims: int, eps: float = 1e-5, bias: bool = True): + super().__init__() + self.bias = mx.zeros((dims,)) if bias else None + self.weight = mx.ones((dims,)) + self.dims = dims + self.eps = eps + + def __call__(self, x): + mean = mx.mean(x, axis=-1, keepdims=True) + var = mx.var(x, axis=-1, keepdims=True) + x = (x - mean) * mx.rsqrt(var + self.eps) + if self.bias is not None: + x = x * self.weight + self.bias + else: + x = x * self.weight + return x + + +class CausalSelfAttention(nn.Module): + def __init__( + self, args: Union[SemanticConfig, CoarseAcousticsConfig, FineAcousticsConfig] + ): + super().__init__() + self.att_proj = nn.Linear(args.n_embd, 3 * args.n_embd, bias=args.bias) + self.out_proj = nn.Linear(args.n_embd, args.n_embd, bias=args.bias) + self.attn_dropout = nn.Dropout(args.dropout) + self.resid_dropout = nn.Dropout(args.dropout) + self.n_head = args.n_head + self.n_embd = args.n_embd + self.dropout = args.dropout + self.bias = ( + 
mx.tril(mx.ones([args.block_size, args.block_size])) + .reshape(1, 1, args.block_size, args.block_size) + .astype(mx.float32) + ) + + def __call__(self, x, past_kv=None, use_cache=False): + B, T, C = x.shape + query, key, value = mx.split(self.att_proj(x), 3, axis=2) + key = key.reshape(B, T, self.n_head, C // self.n_head).transpose(0, 2, 1, 3) + query = query.reshape(B, T, self.n_head, C // self.n_head).transpose(0, 2, 1, 3) + value = value.reshape(B, T, self.n_head, C // self.n_head).transpose(0, 2, 1, 3) + if past_kv is not None: + past_key, past_value = past_kv + key = mx.concatenate([past_key, key], axis=-2) + value = mx.concatenate([past_value, value], axis=-2) + + FULL_T = key.shape[-2] + if use_cache is True: + present = (key, value) + else: + present = None + + y = mx.fast.scaled_dot_product_attention( + query, + key, + value, + scale=1.0 / math.sqrt(key.shape[3]), + mask=self.bias[:, :, FULL_T - T : FULL_T, :FULL_T], + ) + y = self.attn_dropout(y) + y = y.transpose(0, 2, 1, 3).reshape(B, T, C) + y = self.resid_dropout(self.out_proj(y)) + return (y, present) + + +class NonCausalSelfAttention(nn.Module): + def __init__( + self, args: Union[SemanticConfig, CoarseAcousticsConfig, FineAcousticsConfig] + ): + super().__init__() + self.att_proj = nn.Linear(args.n_embd, 3 * args.n_embd, bias=args.bias) + self.out_proj = nn.Linear(args.n_embd, args.n_embd, bias=args.bias) + self.attn_dropout = nn.Dropout(args.dropout) + self.resid_dropout = nn.Dropout(args.dropout) + self.n_head = args.n_head + self.n_embd = args.n_embd + self.dropout = args.dropout + + def __call__(self, x): + B, T, C = x.shape + query, key, value = mx.split(self.att_proj(x), 3, axis=2) + key = key.reshape(B, T, self.n_head, C // self.n_head).transpose(0, 2, 1, 3) + query = query.reshape(B, T, self.n_head, C // self.n_head).transpose(0, 2, 1, 3) + value = value.reshape(B, T, self.n_head, C // self.n_head).transpose(0, 2, 1, 3) + + y = mx.fast.scaled_dot_product_attention( + query, key, value, 
scale=1.0 / math.sqrt(key.shape[3]) + ) + y = self.attn_dropout(y) + y = y.transpose(0, 2, 1, 3).reshape(B, T, C) + y = self.resid_dropout(self.out_proj(y)) + return y + + +class MLP(nn.Module): + def __init__( + self, args: Union[SemanticConfig, CoarseAcousticsConfig, FineAcousticsConfig] + ): + super().__init__() + + self.in_proj = nn.Linear(args.n_embd, 4 * args.n_embd, bias=False) + self.out_proj = nn.Linear(4 * args.n_embd, args.n_embd, bias=False) + self.gelu = nn.GELU() + self.dropout = nn.Dropout(args.dropout) + + def __call__(self, x: mx.array) -> mx.array: + x = self.in_proj(x) + x = self.gelu(x) + x = self.out_proj(x) + x = self.dropout(x) + return x + + +class Block(nn.Module): + def __init__( + self, args: Union[SemanticConfig, CoarseAcousticsConfig], layer_idx: int = 0 + ): + super().__init__() + self.args = args + self.layernorm_1 = LayerNorm(args.n_embd, bias=False) + self.attn = CausalSelfAttention(args) + self.layernorm_2 = LayerNorm(args.n_embd, bias=False) + self.mlp = MLP(args) + self.layer_idx = layer_idx + + def __call__(self, x: mx.array, past_kv=None, use_cache=False): + attn_output, prev_kvs = self.attn( + self.layernorm_1(x), past_kv=past_kv, use_cache=use_cache + ) + x = x + attn_output + x = x + self.mlp(self.layernorm_2(x)) + return (x, prev_kvs) + + +class FineBlock(nn.Module): + def __init__(self, args: FineAcousticsConfig): + super().__init__() + self.args = args + self.layernorm_1 = nn.LayerNorm(args.n_embd) + self.attn = NonCausalSelfAttention(args) + self.layernorm_2 = nn.LayerNorm(args.n_embd) + self.mlp = MLP(args) + + def __call__(self, x: mx.array): + x = x + self.attn(self.layernorm_1(x)) + x = x + self.mlp(self.layernorm_2(x)) + return x + + +class GPT(nn.Module): + def __init__(self, args: Union[SemanticConfig, CoarseAcousticsConfig]): + super().__init__() + self.args = args + self.input_embeds_layer = nn.Embedding(args.input_vocab_size, args.n_embd) + self.position_embeds_layer = nn.Embedding(args.block_size, args.n_embd) 
+ self.drop = nn.Dropout(args.dropout) + self.layers = [Block(args=args) for _ in range(args.n_layer)] + self.layernorm_final = LayerNorm(args.n_embd, bias=False) + self.lm_head = nn.Linear(args.n_embd, args.output_vocab_size, bias=False) + + def __call__( + self, + x: mx.array, + merge_context: bool = False, + past_kv: mx.array = None, + position_ids: mx.array = None, + use_cache: bool = False, + ) -> mx.array: + b, t = x.shape + + if past_kv is not None: + assert t == 1 + tok_emb = self.input_embeds_layer(x) + else: + if merge_context: + assert x.shape[1] >= 256 + 256 + 1 + t = x.shape[1] - 256 + tok_emb = mx.concatenate( + [ + self.input_embeds_layer(x[:, :256]) + + self.input_embeds_layer(x[:, 256 : 256 + 256]), + self.input_embeds_layer(x[:, 256 + 256 :]), + ], + axis=1, + ) + else: + tok_emb = self.input_embeds_layer(x) + + # past length + if past_kv is None: + past_length = 0 + past_kv = tuple([None] * len(self.layers)) + else: + past_length = past_kv[0][0].shape[-2] + + if position_ids is None: + position_ids = mx.arange(past_length, t + past_length) + position_ids = position_ids.reshape(1, -1) # shape (1, t) + + pos_emb = self.position_embeds_layer( + position_ids + ) # position embeddings of shape (1, t, n_embd) + x = self.drop(tok_emb + pos_emb) + + new_kv = () if use_cache else None + + for i, (block, past_layer_kv) in enumerate(zip(self.layers, past_kv)): + x, kv = block(x, past_kv=past_layer_kv, use_cache=use_cache) + + if use_cache: + new_kv = new_kv + (kv,) + + x = self.layernorm_final(x) + + logits = self.lm_head( + x[:, -1:, :] + ) # note: using list [-1] to preserve the time dim + + return (logits, new_kv) + + +class FineGPT(nn.Module): + def __init__(self, args: FineAcousticsConfig): + super().__init__() + self.args = args + self.n_codes_total = args.n_codes_total + self.input_embeds_layers = [ + nn.Embedding(args.input_vocab_size, args.n_embd) + for _ in range(args.n_codes_total) + ] + self.position_embeds_layer = nn.Embedding(args.block_size, 
args.n_embd) + self.drop = nn.Dropout(args.dropout) + self.layers = [FineBlock(args=args) for _ in range(args.n_layer)] + self.layernorm_final = nn.LayerNorm(args.n_embd) + + self.lm_heads = [ + nn.Linear(args.n_embd, args.output_vocab_size, bias=False) + for _ in range(args.n_codes_given, args.n_codes_total) + ] + for i in range(self.n_codes_total - args.n_codes_given): + self.input_embeds_layers[i + 1].weight = self.lm_heads[i].weight + + def __call__(self, pred_idx: mx.array, idx: mx.array) -> mx.array: + b, t, codes = idx.shape + assert ( + t <= self.args.block_size + ), f"Cannot forward sequence of length {t}, block size is only {self.args.block_size}" + assert pred_idx > 0, "cannot predict 0th codebook" + assert codes == self.n_codes_total, (b, t, codes) + pos = mx.arange(0, t).astype(mx.int64).reshape(1, t) # shape (1, t) + tok_embs = [ + self.input_embeds_layers[i](idx[:, :, i].astype(mx.int32)).reshape( + b, t, -1, 1 + ) + for i in range(self.n_codes_total) + ] # token embeddings of shape (b, t, n_embd) + tok_emb = mx.concatenate(tok_embs, axis=-1) + pos_emb = self.position_embeds_layer( + pos + ) # position embeddings of shape (1, t, n_embd) + x = tok_emb[:, :, :, : pred_idx + 1].sum(axis=-1) + x = self.drop(x + pos_emb) + for block in self.layers: + x = block(x) + x = self.layernorm_final(x) + + logits = self.lm_heads[pred_idx - self.args.n_codes_given](x) + return logits + + +class Model(nn.Module): + def __init__(self, config: ModelConfig): + super().__init__() + self.config = config + + # Convert config dictionaries to proper configuration objects if needed + if isinstance(config.semantic_config, dict): + filtered_config = filter_dataclass_fields( + config.semantic_config, SemanticConfig + ) + semantic_config = SemanticConfig(**filtered_config) + else: + semantic_config = config.semantic_config + + if isinstance(config.coarse_acoustics_config, dict): + filtered_config = filter_dataclass_fields( + config.coarse_acoustics_config, CoarseAcousticsConfig 
+ ) + coarse_config = CoarseAcousticsConfig(**filtered_config) + else: + coarse_config = config.coarse_acoustics_config + + if isinstance(config.fine_acoustics_config, dict): + filtered_config = filter_dataclass_fields( + config.fine_acoustics_config, FineAcousticsConfig + ) + fine_config = FineAcousticsConfig(**filtered_config) + else: + fine_config = config.fine_acoustics_config + + self.semantic = GPT(semantic_config) + self.fine_acoustics = FineGPT(fine_config) + self.coarse_acoustics = GPT(coarse_config) + + self.tokenizer = BertTokenizer.from_pretrained("bert-base-multilingual-cased") + + def sanitize(self, weights): + + sanitized_weights = {} + for key, value in weights.items(): + # there's no _orig_mod.transformer + if "_orig_mod.transformer." in key: + key = key.replace("_orig_mod.transformer.", "") + # transformer block mapping + if "h" in key: + layer_count = 24 if self.config.model_size == "large" else 12 + for i in range(layer_count): + prefix = f"h.{i}." + key = key.replace(prefix, f"layers.{i}.") + + # lm_head + if "lm_head" in key: + key = key.replace("_orig_mod.", "") + + if "codec" in key: + pass + else: + sanitized_weights[key] = value + + return sanitized_weights + + def generate(self, text: str, voice: str = None, **kwargs): + pipeline = Pipeline( + model=self, + tokenizer=self.tokenizer, + config=self.config, + ) + + # Track overall generation time + start_time = time.time() + + for segment_idx, (audio, tokens) in enumerate( + pipeline(text, voice=voice, use_kv_caching=True, **kwargs) + ): + # Track per-segment generation time + segment_time = time.time() - start_time + + samples = audio.shape[0] if audio is not None else 0 + assert samples > 0, "No audio generated" + + # Calculate token count + token_count = len(tokens) if tokens is not None else 0 + + # Calculate audio duration in seconds + sample_rate = 24000 # Assuming 24kHz sample rate, adjust if different + audio_duration_seconds = samples / sample_rate * audio.shape[1] + + # Calculate 
milliseconds per sample + ms_per_sample = ( + 1000 / sample_rate + ) # This gives 0.0417 ms per sample at 24kHz + + # Calculate real-time factor (RTF) + rtf = ( + segment_time / audio_duration_seconds + if audio_duration_seconds > 0 + else 0 + ) + + # Format duration as HH:MM:SS.mmm + duration_mins = int(audio_duration_seconds // 60) + duration_secs = int(audio_duration_seconds % 60) + duration_ms = int((audio_duration_seconds % 1) * 1000) + duration_hours = int(audio_duration_seconds // 3600) + duration_str = f"{duration_hours:02d}:{duration_mins:02d}:{duration_secs:02d}.{duration_ms:03d}" + + yield GenerationResult( + audio=audio[0], + samples=samples, + segment_idx=segment_idx, + token_count=token_count, + audio_duration=duration_str, + real_time_factor=round(rtf, 2), + prompt={ + "tokens": token_count, + "tokens-per-sec": ( + round(token_count / segment_time, 2) if segment_time > 0 else 0 + ), + }, + audio_samples={ + "samples": samples, + "samples-per-sec": ( + round(samples / segment_time, 2) if segment_time > 0 else 0 + ), + }, + processing_time_seconds=segment_time, + peak_memory_usage=mx.metal.get_peak_memory() / 1e9, + ) diff --git a/mlx_audio/tts/models/bark/isftnet.py b/mlx_audio/tts/models/bark/isftnet.py new file mode 100644 index 00000000..7d4985cd --- /dev/null +++ b/mlx_audio/tts/models/bark/isftnet.py @@ -0,0 +1,12 @@ +import mlx.core as mx +import mlx.nn as nn + + +# Loads to torch Encodec model +def codec_decode(codec: nn.Module, fine_tokens: mx.array): + arr = fine_tokens.astype(mx.int32)[None] + emb = codec.quantizer.decode(arr) + out = codec.decoder(emb).astype(mx.float32) + audio_arr = mx.squeeze(out, -1) + del arr, emb, out + return audio_arr diff --git a/mlx_audio/tts/models/bark/pipeline.py b/mlx_audio/tts/models/bark/pipeline.py new file mode 100644 index 00000000..01051bdb --- /dev/null +++ b/mlx_audio/tts/models/bark/pipeline.py @@ -0,0 +1,442 @@ +import math +import os +from dataclasses import dataclass +from typing import Optional + 
+import mlx.core as mx +import mlx.nn as nn +import numpy as np +import tqdm + +from mlx_audio.codec.models.encodec.encodec import Encodec + +from ..base import adjust_speed +from .isftnet import codec_decode + +TEXT_ENCODING_OFFSET = 10_048 +SEMANTIC_PAD_TOKEN = 10_000 +TEXT_PAD_TOKEN = 129595 +SEMANTIC_INFER_TOKEN = 129_599 + +CONTEXT_WINDOW_SIZE = 1024 + +SEMANTIC_RATE_HZ = 49.9 +SEMANTIC_VOCAB_SIZE = 10_000 + +CODEBOOK_SIZE = 1024 +N_COARSE_CODEBOOKS = 2 +N_FINE_CODEBOOKS = 8 +COARSE_RATE_HZ = 75 +COARSE_SEMANTIC_PAD_TOKEN = 12_048 +COARSE_INFER_TOKEN = 12_050 +SAMPLE_RATE = 24_000 + +CUR_PATH = os.path.dirname(os.path.abspath(__file__)) + + +SUPPORTED_LANGS = [ + ("English", "en"), + ("German", "de"), + ("Spanish", "es"), + ("French", "fr"), + ("Hindi", "hi"), + ("Italian", "it"), + ("Japanese", "ja"), + ("Korean", "ko"), + ("Polish", "pl"), + ("Portuguese", "pt"), + ("Russian", "ru"), + ("Turkish", "tr"), + ("Chinese", "zh"), +] + +ALLOWED_PROMPTS = {"announcer"} +for _, lang in SUPPORTED_LANGS: + for prefix in ("", f"v2{os.path.sep}"): + for n in range(10): + ALLOWED_PROMPTS.add(f"{prefix}{lang}_speaker_{n}") + + +@dataclass +class Result: + audio: mx.array + tokens: mx.array + + ### MARK: BEGIN BACKWARD COMPAT ### + def __iter__(self): + yield self.audio + yield self.tokens + + def __getitem__(self, index): + return [self.audio, self.tokens][index] + + def __len__(self): + return 2 + + +def _load_voice_prompt(voice_prompt_input): + if isinstance(voice_prompt_input, str) and voice_prompt_input.endswith(".npz"): + voice_prompt = np.load(voice_prompt_input) + elif isinstance(voice_prompt_input, str): + # make sure this works on non-ubuntu + voice_prompt_input = os.path.join(*voice_prompt_input.split("/")) + if voice_prompt_input not in ALLOWED_PROMPTS: + raise ValueError("voice prompt not found") + + path = f"{voice_prompt_input}.npz" + + # TODO: Get the path from the Hugging Face cache directory + # TODO: If not found, download the voice from Hugging Face + # 
TODO: If still not found, raise an error + + if not os.path.exists(path): + raise ValueError("voice prompt not found") + voice_prompt = np.load(path) + elif isinstance(voice_prompt_input, dict): + assert "semantic_prompt" in voice_prompt_input + assert "coarse_prompt" in voice_prompt_input + assert "fine_prompt" in voice_prompt_input + voice_prompt = voice_prompt_input + else: + raise ValueError("voice prompt format unrecognized") + return voice_prompt + + +def _flatten_codebooks(arr, offset_size=CODEBOOK_SIZE): + assert len(arr.shape) == 2 + if offset_size is not None: + for n in range(1, arr.shape[0]): + arr[n, :] += offset_size * n + # MLX doesn't have ravel with order parameter, so we transpose and reshape + # to achieve the same effect as numpy's ravel('F') + flat_arr = arr.transpose().reshape(-1) + return flat_arr + + +class Pipeline: + def __init__(self, model: nn.Module, tokenizer: any, config: any): + self.model = model + self.tokenizer = tokenizer + self.codec_model, _ = Encodec.from_pretrained(config.codec_path) + + def generate_text_semantic( + self, + text: str, + voice: str = "announcer", + temperature: float = 0.7, + use_kv_caching: bool = False, + allow_early_stop: bool = True, + **kwargs, + ): + """Generate semantic tokens from text.""" + verbose = kwargs.get("verbose", False) + if verbose: + print("Generating semantic tokens...") + if voice is not None: + voice_prompt = _load_voice_prompt(voice) + semantic_history = mx.array(voice_prompt["semantic_prompt"]) + assert ( + isinstance(semantic_history, mx.array) + and len(semantic_history.shape) == 1 + and len(semantic_history) > 0 + and semantic_history.min() >= 0 + and semantic_history.max() <= SEMANTIC_VOCAB_SIZE - 1 + ) + else: + semantic_history = None + + encoded_text = ( + mx.array(self.tokenizer.encode(text, add_special_tokens=False)) + + TEXT_ENCODING_OFFSET + ) + if len(encoded_text) > 256: + p = round((len(encoded_text) - 256) / len(encoded_text) * 100, 1) + encoded_text = 
encoded_text[:256] + encoded_text = mx.pad( + encoded_text, + (0, 256 - len(encoded_text)), + constant_values=TEXT_PAD_TOKEN, + ) + if semantic_history is not None: + semantic_history = semantic_history.astype(mx.int64) + # lop off if history is too long, pad if needed + semantic_history = semantic_history[-256:] + semantic_history = mx.pad( + semantic_history, + (0, 256 - len(semantic_history)), + constant_values=SEMANTIC_PAD_TOKEN, + mode="constant", + ) + else: + semantic_history = mx.array([SEMANTIC_PAD_TOKEN] * 256) + + x = ( + mx.concatenate( + [encoded_text, semantic_history, mx.array([SEMANTIC_INFER_TOKEN])] + ) + .reshape(1, -1) + .astype(mx.int64) + ) + n_tot_steps = 768 + kv_cache = None + for i in tqdm.tqdm(range(n_tot_steps), disable=not verbose): + if use_kv_caching and kv_cache is not None: + x_input = x[:, -1:] + else: + x_input = x + logits, kv_cache = self.model.semantic( + x_input, merge_context=True, use_cache=use_kv_caching, past_kv=kv_cache + ) + relevant_logits = logits[0, 0, :SEMANTIC_VOCAB_SIZE] + if allow_early_stop: + # Early stop + relevant_logits = mx.concatenate( + [relevant_logits, logits[0, 0, SEMANTIC_PAD_TOKEN].reshape(1)], + axis=-1, + ) + next_token = mx.random.categorical( + relevant_logits * 1 / (temperature), num_samples=1 + ).astype(mx.int32) + + if next_token == SEMANTIC_VOCAB_SIZE: + print(f"Early stop at step {i} with token {next_token.tolist()}") + break + x = mx.concatenate([x, next_token.reshape(1, -1)], axis=1) + if i == n_tot_steps - 1: + break + out = x.squeeze()[256 + 256 + 1 :] + return out, encoded_text + + def generate_coarse( + self, + x_semantic: mx.array, + voice: str = "announcer", + temperature: float = 0.7, + max_coarse_history: int = 60, # min 60 (faster), max 630 (more context) + sliding_window_len: int = 60, + use_kv_caching: bool = False, + **kwargs, + ): + """Generate coarse tokens from semantic tokens.""" + verbose = kwargs.get("verbose", False) + if verbose: + print("Generating coarse tokens...") + 
semantic_to_coarse_ratio = ( + COARSE_RATE_HZ / SEMANTIC_RATE_HZ * N_COARSE_CODEBOOKS + ) + max_semantic_history = int( + math.floor(max_coarse_history / semantic_to_coarse_ratio) + ) + if voice is not None: + voice_prompt = _load_voice_prompt(voice) + x_semantic_history = mx.array(voice_prompt["semantic_prompt"]) + x_coarse_history = mx.array(voice_prompt["coarse_prompt"]) + assert ( + isinstance(x_semantic_history, mx.array) + and len(x_semantic_history.shape) == 1 + and len(x_semantic_history) > 0 + and x_semantic_history.min() >= 0 + and x_semantic_history.max() <= SEMANTIC_VOCAB_SIZE - 1 + and isinstance(x_coarse_history, mx.array) + and len(x_coarse_history.shape) == 2 + and x_coarse_history.shape[0] == N_COARSE_CODEBOOKS + and x_coarse_history.shape[-1] >= 0 + and x_coarse_history.min() >= 0 + and x_coarse_history.max() <= CODEBOOK_SIZE - 1 + and ( + round(x_coarse_history.shape[-1] / len(x_semantic_history), 1) + == round(semantic_to_coarse_ratio / N_COARSE_CODEBOOKS, 1) + ) + ) + x_coarse_history = ( + _flatten_codebooks(x_coarse_history) + SEMANTIC_VOCAB_SIZE + ) + # trim histories correctly + n_semantic_hist_provided = min( + max_semantic_history, + len(x_semantic_history) - len(x_semantic_history) % 2, + int(math.floor(len(x_coarse_history) / semantic_to_coarse_ratio)), + ) + n_coarse_hist_provided = int( + round(n_semantic_hist_provided * semantic_to_coarse_ratio) + ) + x_semantic_history = x_semantic_history[-n_semantic_hist_provided:].astype( + mx.int32 + ) + x_coarse_history = x_coarse_history[-n_coarse_hist_provided:].astype( + mx.int32 + ) + # TODO: bit of a hack for time alignment (sounds better) + x_coarse_history = x_coarse_history[:-2] + else: + x_semantic_history = mx.array([], dtype=mx.int32) + x_coarse_history = mx.array([], dtype=mx.int32) + + n_steps = int( + round( + math.floor( + len(x_semantic) * semantic_to_coarse_ratio / N_COARSE_CODEBOOKS + ) + * N_COARSE_CODEBOOKS + ) + ) + x_semantic = mx.concatenate([x_semantic_history, 
x_semantic]).astype(mx.int32) + x_coarse = x_coarse_history.astype(mx.int32) + base_semantic_idx = len(x_semantic_history) + # Inference + x_semantic_in = x_semantic.reshape(1, -1) + x_coarse_in = x_coarse.reshape(1, -1) + n_window_steps = int(round(n_steps / sliding_window_len)) + n_step = 0 + for _ in tqdm.tqdm( + range(n_window_steps), total=n_window_steps, disable=not verbose + ): + semantic_idx = base_semantic_idx + int( + round(n_step / semantic_to_coarse_ratio) + ) + x_in = x_semantic_in[:, max(0, semantic_idx - max_semantic_history) :] + x_in = x_in[:, :256] + x_in = mx.pad( + x_in, + ((0, 0), (0, 256 - x_in.shape[-1])), + constant_values=COARSE_SEMANTIC_PAD_TOKEN, + ) + x_in = mx.concatenate( + [ + x_in, + mx.array([COARSE_INFER_TOKEN]).reshape(1, -1), + x_coarse_in[:, -max_coarse_history:], + ], + axis=1, + ) + kv_cache = None + for _ in range(sliding_window_len): + if n_step >= n_steps: + continue + is_major_step = n_step % N_COARSE_CODEBOOKS == 0 + x_input = ( + x_in[:, -1:] if use_kv_caching and kv_cache is not None else x_in + ) + logits, kv_cache = self.model.coarse_acoustics( + x_input, use_cache=use_kv_caching, past_kv=kv_cache + ) + logit_start_idx = ( + SEMANTIC_VOCAB_SIZE + (1 - int(is_major_step)) * CODEBOOK_SIZE + ) + logit_end_idx = ( + SEMANTIC_VOCAB_SIZE + (2 - int(is_major_step)) * CODEBOOK_SIZE + ) + logit_end_idx = min(logit_end_idx, logits.shape[-1]) + relevant_logits = logits[0, 0, logit_start_idx:logit_end_idx] + item_next = mx.random.categorical( + relevant_logits * (1 / temperature), num_samples=1 + ).astype(mx.int32) + + item_next += logit_start_idx + x_coarse_in = mx.concatenate( + [x_coarse_in, item_next.reshape(1, 1)], axis=1 + ) + x_in = mx.concatenate([x_in, item_next.reshape(1, 1)], axis=1) + n_step += 1 + + gen_coarse_arr = x_coarse_in[0, len(x_coarse_history) :] + gen_coarse_audio_arr = ( + gen_coarse_arr.reshape(-1, N_COARSE_CODEBOOKS).T - SEMANTIC_VOCAB_SIZE + ) + for n in range(1, N_COARSE_CODEBOOKS): + 
gen_coarse_audio_arr[n, :] -= n * CODEBOOK_SIZE + + return gen_coarse_audio_arr + + def generate_fine( + self, + x_coarse_gen: mx.array, + temperature: float = 0.7, + **kwargs, + ): + verbose = kwargs.get("verbose", False) + """Generate fine tokens from coarse tokens.""" + if verbose: + print("Generating fine tokens...") + x_fine_history = None + n_coarse = x_coarse_gen.shape[0] + in_arr = mx.concatenate( + [ + x_coarse_gen, + mx.zeros((N_FINE_CODEBOOKS - n_coarse, x_coarse_gen.shape[1])) + + CODEBOOK_SIZE, # padding + ], + axis=0, + ) + n_history = 0 + n_remove_from_end = 0 + # need to pad if too short (since non-causal model) + if in_arr.shape[1] < 1024: + n_remove_from_end = 1024 - in_arr.shape[1] + in_arr = mx.concatenate( + [ + in_arr, + mx.zeros((N_FINE_CODEBOOKS, n_remove_from_end)) + CODEBOOK_SIZE, + ], + axis=1, + ) + # Inference + n_loops = ( + max(0, int(math.ceil((x_coarse_gen.shape[1] - (1024 - n_history)) / 512))) + + 1 + ) + in_arr = in_arr.T + for n in tqdm.tqdm(range(n_loops), disable=not verbose): + start_idx = mx.min(mx.array([n * 512, in_arr.shape[0] - 1024])).item() + start_fill_idx = mx.min( + mx.array([n_history + n * 512, in_arr.shape[0] - 512]) + ).item() + rel_start_fill_idx = start_fill_idx - start_idx + in_buffer = in_arr[start_idx : start_idx + 1024, :][None] + for nn in range(n_coarse, N_FINE_CODEBOOKS): + logits = self.model.fine_acoustics(nn, in_buffer) + if temperature is None: + relevant_logits = logits[0, rel_start_fill_idx:, :CODEBOOK_SIZE] + codebook_preds = mx.argmax(relevant_logits, -1) + else: + relevant_logits = logits[0, :, :CODEBOOK_SIZE] / temperature + codebook_preds = ( + mx.random.categorical( + relevant_logits[rel_start_fill_idx:1024], num_samples=1 + ) + .reshape(-1) + .astype(mx.int32) + ) + in_buffer[0, rel_start_fill_idx:, nn] = codebook_preds + for nn in range(n_coarse, N_FINE_CODEBOOKS): + in_arr[ + start_fill_idx : start_fill_idx + (1024 - rel_start_fill_idx), nn + ] = in_buffer[0, rel_start_fill_idx:, nn] + 
gen_fine_arr = in_arr.squeeze().T + gen_fine_arr = gen_fine_arr[:, n_history:] + if n_remove_from_end > 0: + gen_fine_arr = gen_fine_arr[:, :-n_remove_from_end] + assert gen_fine_arr.shape[-1] == x_coarse_gen.shape[-1] + return gen_fine_arr + + def __call__( + self, + text: str, + voice: str = None, + temperature: float = 0.7, + speed: float = 1.0, + use_kv_caching: bool = False, + **kwargs, + ): + semantic_tokens, tokens = self.generate_text_semantic( + text, voice, temperature, use_kv_caching, **kwargs + ) + coarse_tokens = self.generate_coarse( + semantic_tokens, voice, temperature, use_kv_caching, **kwargs + ) + fine_tokens = self.generate_fine(coarse_tokens, temperature, **kwargs) + # TODO: adjust speed + # audio_arr = adjust_speed(fine_tokens, speed) + audio_arr = codec_decode(self.codec_model, fine_tokens) + + yield Result(audio=audio_arr, tokens=tokens) diff --git a/mlx_audio/tts/models/base.py b/mlx_audio/tts/models/base.py index 535e9adc..f70bb74e 100644 --- a/mlx_audio/tts/models/base.py +++ b/mlx_audio/tts/models/base.py @@ -2,6 +2,7 @@ from dataclasses import dataclass import mlx.core as mx +import numpy as np @dataclass @@ -33,6 +34,40 @@ def check_array_shape(arr): return False +def adjust_speed(audio_array, speed_factor): + """ + Adjust the speed of the audio by resampling + speed_factor > 1: faster + speed_factor < 1: slower + """ + # Ensure we're working with MLX arrays + if not isinstance(audio_array, mx.array): + audio_array = mx.array(audio_array) + + # Calculate new length + old_length = audio_array.shape[0] + new_length = int(old_length / speed_factor) + + # Create new time points + old_indices = mx.arange(old_length) + new_indices = mx.linspace(0, old_length - 1, new_length) + + # Resample using linear interpolation + # Since mx doesn't have interp, we'll implement it directly + indices_floor = mx.floor(new_indices).astype(mx.int32) + indices_ceil = mx.minimum(indices_floor + 1, old_length - 1) + weights_ceil = new_indices - indices_floor + 
weights_floor = 1.0 - weights_ceil + + # Perform the interpolation + result = ( + weights_floor.reshape(-1, 1) * audio_array[indices_floor] + + weights_ceil.reshape(-1, 1) * audio_array[indices_ceil] + ) + + return result + + @dataclass class GenerationResult: audio: mx.array diff --git a/mlx_audio/tts/models/kokoro/istftnet.py b/mlx_audio/tts/models/kokoro/istftnet.py index 5405d873..3ed327cc 100644 --- a/mlx_audio/tts/models/kokoro/istftnet.py +++ b/mlx_audio/tts/models/kokoro/istftnet.py @@ -114,15 +114,13 @@ def __init__( self.groups = groups # Initialize weight magnitude (g) and direction (v) vectors - weight_g = mx.ones((out_channels, 1, 1)) # Scalar magnitude per output channel - weight_v = mx.ones( + self.weight_g = mx.ones( + (out_channels, 1, 1) + ) # Scalar magnitude per output channel + self.weight_v = mx.ones( (out_channels, kernel_size, in_channels) ) # Direction vectors - # Store parameters - self.weight_g = mx.array(weight_g) - self.weight_v = mx.array(weight_v) - self.bias = mx.zeros(in_channels if encode else out_channels) if bias else None def __call__(self, x, conv): diff --git a/mlx_audio/tts/models/kokoro/kokoro.py b/mlx_audio/tts/models/kokoro/kokoro.py index 29f95dd2..f4aca8ca 100644 --- a/mlx_audio/tts/models/kokoro/kokoro.py +++ b/mlx_audio/tts/models/kokoro/kokoro.py @@ -244,11 +244,10 @@ def sanitize(self, weights): def generate( self, text: str, - voice: str = "af_heart", + voice: str = None, speed: float = 1.0, lang_code: str = "af", split_pattern: str = r"\n+", - verbose: bool = False, **kwargs, ): pipeline = KokoroPipeline( @@ -257,6 +256,9 @@ def generate( lang_code=lang_code, ) + if voice is None: + voice = "af_heart" + # Track overall generation time start_time = time.time() diff --git a/mlx_audio/tts/tests/test_models.py b/mlx_audio/tts/tests/test_models.py index 33650158..76def0d2 100644 --- a/mlx_audio/tts/tests/test_models.py +++ b/mlx_audio/tts/tests/test_models.py @@ -334,5 +334,194 @@ def test_result_dataclass(self): 
self.assertIs(items[2], audio) +@patch("importlib.resources.open_text", patched_open_text) +class TestBarkModel(unittest.TestCase): + @patch("mlx_audio.tts.models.bark.bark.BertTokenizer") + def test_init(self, mock_tokenizer): + """Test BarkModel initialization.""" + from mlx_audio.tts.models.bark.bark import ( + CoarseAcousticsConfig, + CodecConfig, + FineAcousticsConfig, + Model, + ModelConfig, + SemanticConfig, + ) + + # Create mock configs + semantic_config = SemanticConfig() + coarse_config = CoarseAcousticsConfig() + fine_config = FineAcousticsConfig() + codec_config = CodecConfig() + + config = ModelConfig( + semantic_config=semantic_config, + coarse_acoustics_config=coarse_config, + fine_acoustics_config=fine_config, + codec_config=codec_config, + ) + + # Initialize model + model = Model(config) + + # Check that components were initialized correctly + self.assertIsNotNone(model.semantic) + self.assertIsNotNone(model.coarse_acoustics) + self.assertIsNotNone(model.fine_acoustics) + self.assertIsNotNone(model.tokenizer) + + def test_sanitize_weights(self): + """Test weight sanitization.""" + from mlx_audio.tts.models.bark.bark import Model, ModelConfig + + # Create a minimal config + config = ModelConfig( + semantic_config={}, + coarse_acoustics_config={}, + fine_acoustics_config={}, + codec_config={}, + ) + + model = Model(config) + + # Test with transformer weights + weights = { + "_orig_mod.transformer.h.0.mlp.weight": mx.zeros((10, 10)), + "_orig_mod.transformer.h.1.mlp.weight": mx.zeros((10, 10)), + "lm_head.weight": mx.zeros((10, 10)), + } + + sanitized = model.sanitize(weights) + + # Check that weights were properly renamed + self.assertIn("layers.0.mlp.weight", sanitized) + self.assertIn("layers.1.mlp.weight", sanitized) + self.assertIn("lm_head.weight", sanitized) + + +@patch("importlib.resources.open_text", patched_open_text) +class TestBarkPipeline(unittest.TestCase): + def setUp(self): + """Set up test fixtures.""" + from 
mlx_audio.tts.models.bark.bark import ( + CoarseAcousticsConfig, + CodecConfig, + FineAcousticsConfig, + Model, + ModelConfig, + SemanticConfig, + ) + from mlx_audio.tts.models.bark.pipeline import Pipeline + + # Create mock model with required attributes + self.mock_model = MagicMock(spec=Model) + + # Add the required mock attributes/methods + self.mock_model.semantic = MagicMock() + self.mock_model.coarse_acoustics = MagicMock() + self.mock_model.fine_acoustics = MagicMock() + self.mock_model.codec_model = MagicMock() + + self.mock_tokenizer = MagicMock() + + # Initialize pipeline + self.pipeline = Pipeline( + model=self.mock_model, + tokenizer=self.mock_tokenizer, + config=ModelConfig( + semantic_config=SemanticConfig(), + coarse_acoustics_config=CoarseAcousticsConfig(), + fine_acoustics_config=FineAcousticsConfig(), + codec_config=CodecConfig(), + ), + ) + + def test_generate_text_semantic(self): + """Test semantic token generation.""" + # Mock tokenizer output + self.mock_tokenizer.encode.return_value = [1, 2, 3] + + # Create logits with proper shape including SEMANTIC_PAD_TOKEN + logits = mx.zeros((1, 1, 129596)) # Large enough to include SEMANTIC_PAD_TOKEN + # Mock model output + self.mock_model.semantic.return_value = ( + logits, # logits with correct shape + None, # kv_cache + ) + + # Test generation + semantic_tokens, text_tokens = self.pipeline.generate_text_semantic( + "test text", + temperature=0.7, + use_kv_caching=True, + voice=None, + ) + + # Verify tokenizer was called + self.mock_tokenizer.encode.assert_called_once_with( + "test text", add_special_tokens=False + ) + + # Verify model was called + self.mock_model.semantic.assert_called() + + # Check output types + self.assertIsInstance(semantic_tokens, mx.array) + self.assertIsInstance(text_tokens, mx.array) + + @patch("mlx.core.random.categorical") # Add this patch since we use mx alias + def test_generate_coarse(self, mock_mlx_categorical): + """Test coarse token generation.""" + # Create mock 
semantic tokens + semantic_tokens = mx.array([1, 2, 3]) + + # Create logits with proper shape + logits = mx.zeros((1, 1, 12096)) + + # Mock both categorical functions to return predictable values + mock_mlx_categorical.return_value = mx.array([10000]) # Return token index + + # Set up the mock to return proper values for each call + self.mock_model.coarse_acoustics.return_value = (logits, None) + + # Test generation with minimal parameters to reduce test time + coarse_tokens = self.pipeline.generate_coarse( + semantic_tokens, + temperature=0.7, + use_kv_caching=True, + voice=None, + max_coarse_history=60, + sliding_window_len=2, # Reduce this to minimum + ) + + # Verify model was called at least once + self.mock_model.coarse_acoustics.assert_called() + + # Check output type and shape + self.assertIsInstance(coarse_tokens, mx.array) + self.assertEqual(coarse_tokens.shape[0], 2) # N_COARSE_CODEBOOKS + + def test_generate_fine(self): + """Test fine token generation.""" + # Create mock coarse tokens + coarse_tokens = mx.zeros((2, 100)) # N_COARSE_CODEBOOKS x sequence_length + + # Mock model output with proper shape + self.mock_model.fine_acoustics.return_value = mx.zeros((1, 1024, 1024)) + + # Test generation + fine_tokens = self.pipeline.generate_fine(coarse_tokens, temperature=0.7) + + # Verify model was called + self.mock_model.fine_acoustics.assert_called() + + # Check output type and shape + self.assertIsInstance(fine_tokens, mx.array) + self.assertEqual( + fine_tokens.shape[0], 8 + ) # N_FINE_CODEBOOKS (corrected from 10 to 8) + self.assertEqual(fine_tokens.shape[1], 100) # sequence_length + + if __name__ == "__main__": unittest.main() diff --git a/mlx_audio/tts/utils.py b/mlx_audio/tts/utils.py index 81df5bc0..c42d312e 100644 --- a/mlx_audio/tts/utils.py +++ b/mlx_audio/tts/utils.py @@ -104,6 +104,7 @@ def load_model(model_path: Path, lazy: bool = False, **kwargs) -> nn.Module: if isinstance(model_path, str): name = 
model_path.split("/")[-1].split("-")[0].lower() model_path = get_model_path(model_path) + config = load_config(model_path, **kwargs) model_type = config.get("model_type", name) @@ -138,7 +139,8 @@ def load_model(model_path: Path, lazy: bool = False, **kwargs) -> nn.Module: weights.update(mx.load(wf)) model_class, model_type = get_model_and_args(model_type=model_type) - model = model_class.Model(config) + model_config = model_class.ModelConfig.from_dict(config) + model = model_class.Model(model_config) quantization = config.get("quantization", None) if quantization is None: weights = model.sanitize(weights) diff --git a/mlx_audio/version.py b/mlx_audio/version.py index 3b93d0be..27fdca49 100644 --- a/mlx_audio/version.py +++ b/mlx_audio/version.py @@ -1 +1 @@ -__version__ = "0.0.2" +__version__ = "0.0.3" diff --git a/requirements.txt b/requirements.txt index 50d9efaa..ef4dff9b 100644 --- a/requirements.txt +++ b/requirements.txt @@ -16,3 +16,4 @@ sounddevice>=0.5.1 soundfile>=0.13.1 fastapi>=0.95.0 uvicorn>=0.22.0 +encodec>=0.1.1