diff --git a/mlx_audio/tts/models/fish_qwen3_omni/fish_speech.py b/mlx_audio/tts/models/fish_qwen3_omni/fish_speech.py index 47ccdd96..79b5d9de 100644 --- a/mlx_audio/tts/models/fish_qwen3_omni/fish_speech.py +++ b/mlx_audio/tts/models/fish_qwen3_omni/fish_speech.py @@ -407,6 +407,14 @@ def sample_rate(self) -> int: def model_type(self) -> str: return "fish_speech" + @classmethod + def model_quant_predicate(cls, path: str, module) -> bool: + import mlx.nn as nn + return ( + not isinstance(module, nn.Embedding) + and "fast_" not in path + ) + def load_weights(self, weights, strict: bool = True): remapped = [] for key, value in weights: @@ -418,6 +426,9 @@ def load_weights(self, weights, strict: bool = True): def sanitize(self, weights: dict[str, mx.array]) -> dict[str, mx.array]: remapped = {} for key, value in weights.items(): + if key.startswith("model."): + remapped[key] = value + continue if key.startswith("text_model.model."): new_key = key[len("text_model.model.") :] elif key.startswith("audio_decoder."): diff --git a/mlx_audio/tts/models/qwen3_tts/README.md b/mlx_audio/tts/models/qwen3_tts/README.md index 7f25a833..b2502ccb 100644 --- a/mlx_audio/tts/models/qwen3_tts/README.md +++ b/mlx_audio/tts/models/qwen3_tts/README.md @@ -4,7 +4,9 @@ Alibaba's state-of-the-art multilingual TTS with three model variants. ## Voice Cloning -Clone any voice using a reference audio sample. Provide the wav file and its transcript: +Clone any voice using a reference audio sample. Provide the wav file and its exact transcript as a literal string (not a file path): + +> **Note:** Voice cloning is only supported on the **Base** models. Do not provide a `voice` (speaker name) when using `ref_audio` and `ref_text`, as it will cause a configuration conflict. ```python from mlx_audio.tts.utils import load_model