Add Voxtral-4B-TTS-2603 support #607
Add support for mistralai/Voxtral-4B-TTS-2603, a 4B parameter text-to-speech model with 20 voice presets across 9 languages.

Architecture:
- LM backbone (Mistral 3.4B, 26 layers) with KV-cached autoregressive generation
- Flow-matching acoustic transformer (3 layers, 8-step Euler with CFG alpha=1.2)
- Audio tokenizer decoder (4-stage ConvTranspose + ALiBi transformer vocoder)

Features:
- Loads from both the original checkpoint (params.json) and the MLX-converted format
- 20 voice presets (casual/neutral/cheerful + 9 languages)
- Tokenization via mistral_common SpeechRequest, with an HF fallback
- Voice embedding injection (summed with the audio token embeddings)
- Two-stage generation: LM hidden states → acoustic codes → 24kHz waveform

Known limitations:
- Audio quality needs improvement (flow matching produces warbling)
- Generation is slow (~2 frames/s due to per-frame flow matching)
- End-of-audio detection may not trigger reliably

MLX-converted model: mlx-community/Voxtral-4B-TTS-2603-mlx-bf16
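The "8-step Euler with CFG alpha=1.2" sampler can be sketched roughly as below. This is a minimal illustration, not the PR's actual code: the exact guidance formula is an assumption, and `velocity_fn` is a hypothetical stand-in for the acoustic transformer's velocity prediction.

```python
def euler_cfg_sample(velocity_fn, x0, num_steps=8, cfg_alpha=1.2):
    """Integrate a flow-matching ODE from t=0 to t=1 with fixed Euler steps.

    velocity_fn(x, t, cond) is a stand-in for the acoustic transformer;
    cond=False selects the unconditional (null-conditioning) branch used
    for classifier-free guidance.
    """
    x = x0
    dt = 1.0 / num_steps
    for i in range(num_steps):
        t = i * dt
        v_cond = velocity_fn(x, t, cond=True)
        v_uncond = velocity_fn(x, t, cond=False)
        # CFG: extrapolate from the unconditional toward the conditional velocity
        v = v_uncond + cfg_alpha * (v_cond - v_uncond)
        x = x + dt * v
    return x
```

With alpha > 1 the conditional velocity is amplified relative to the unconditional one, which is what makes the generated frames track the text conditioning more closely.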
- Fix ALiBi distance computation: use (i - j), not (j - i)
- Fix sliding window mask direction: mask dist > window (past too far back)
- Add sliding window sizes [2, 4, 8, 16] per decoder stage
- Fix semantic codebook decode: compute in float32 to avoid bf16 precision loss
- Add the first AUDIO token decode step, matching the C reference implementation
- Handle both original and MLX-converted weight name formats in sanitize()

Verified against the PyTorch reference:
- Acoustic transformer velocity: max diff < 6e-6
- Vocoder conv1d: max diff < 1e-6
- Vocoder conv_transpose1d: max diff < 2e-6
- Vocoder transformer layer: max diff < 2e-5
- Codebook decode: max diff < 1e-6 (was 0.97 before the float32 fix)

Remaining issue: the LLM hidden states converge to a fixed point across frames, causing repetitive semantic codes. Needs investigation of the code embedding feedback loop.
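The sliding-window mask direction fix ("mask dist > window") can be illustrated with a minimal sketch. The helper name and the pure-Python layout are illustrative only, not the PR's actual implementation:

```python
def sliding_window_mask(seq_len, window):
    """Boolean attention mask for causal sliding-window attention.

    Position i may attend to position j only when 0 <= i - j <= window:
    never the future (j > i), and never further back than `window` steps.
    The bug fixed above was masking in the wrong direction.
    """
    return [
        [0 <= i - j <= window for j in range(seq_len)]
        for i in range(seq_len)
    ]
```

Each decoder stage would use its own window size from [2, 4, 8, 16].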
The flow-matching TimeEmbedding was using the (sin, cos) concatenation order, but the checkpoint weights expect (cos, sin) order (matching the vllm-omni and ExecuTorch implementations). This was the primary cause of garbled audio:
- (sin, cos): 14/36 acoustic values clipped outside [-1, 1], flow diverges
- (cos, sin): 0/36 clipped, flow converges to [-0.63, 0.83]

Verified: the acoustic transformer output matches PyTorch with max diff < 1e-6. Semantic code diversity improved from 2 to 6 unique codes. Audio now sounds voice-like (previously garbled noise).

Remaining: the model generates for too long (no EOS), and hidden states converge after ~7 frames.
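A minimal sketch of the two concatenation orders, assuming the usual transformer sinusoidal frequency schedule (the model's exact frequencies may differ; only the ordering is the point):

```python
import math

def time_embedding(t, dim, order="cos_sin", max_period=10000.0):
    """Sinusoidal time embedding for the flow-matching timestep t.

    The checkpoint weights expect the (cos, sin) concatenation order;
    order="sin_cos" reproduces the buggy ordering for comparison.
    """
    half = dim // 2
    freqs = [math.exp(-math.log(max_period) * i / half) for i in range(half)]
    sins = [math.sin(t * f) for f in freqs]
    coss = [math.cos(t * f) for f in freqs]
    return coss + sins if order == "cos_sin" else sins + coss
```

Because the first linear layer of the embedding MLP was trained against one specific ordering, swapping the halves feeds every weight the wrong input feature, which is why the flow diverged rather than merely degrading.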
Key findings from comparing against voxtral-tts.c on macOS:
- The C code also produces mostly code 10 (the most common VQ entry)
- The C code produces "Hello world" non-deterministically (depends on the flow-matching random noise)
- Prompt embeddings match perfectly (cos_sim=1.0 at every position)
- Layer 0 output: cos_sim=0.969 (bf16 rounding starts here)
- Final layer output: cos_sim=0.48 (accumulated across 26 layers)
- Both C and MLX produce incoherent audio most of the time

The model likely requires GPU f16/f32 precision (as in vllm-omni on an H100) to consistently produce coherent speech. The bf16 precision on Apple Silicon introduces enough numerical difference across 26 transformer layers to shift argmax winners in the 8194-dim semantic codebook.
Voxtral TTS uses traditional/interleaved RoPE where pairs (x[2i], x[2i+1])
are rotated together, NOT the NeoX/GPT-NeoX style where first/second halves
are split. mlx-lm's LlamaModel defaults to non-traditional (NeoX) which
caused hidden states to diverge completely from the reference C implementation
(cos_sim=0.48 after 26 layers).
With rope_traditional=True:
- Hidden states match C reference: cos_sim=0.999984
- Model produces diverse semantic codes (12 unique vs 2 before)
- Natural EOS detection at frame 26
- Intelligible speech output ("Hello world" clearly audible)
This was the root cause of all audio quality issues.
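The two RoPE conventions described above can be contrasted in a minimal pure-Python sketch (single vector, illustrative only; real implementations vectorize this):

```python
import math

def rope_traditional(x, pos, base=10000.0):
    """Traditional/interleaved RoPE: rotate adjacent pairs (x[2i], x[2i+1]).

    This is the convention Voxtral TTS was trained with.
    """
    d = len(x)
    out = list(x)
    for i in range(d // 2):
        theta = pos * base ** (-2 * i / d)
        c, s = math.cos(theta), math.sin(theta)
        a, b = x[2 * i], x[2 * i + 1]
        out[2 * i] = a * c - b * s
        out[2 * i + 1] = a * s + b * c
    return out

def rope_neox(x, pos, base=10000.0):
    """NeoX-style RoPE: pair x[i] with x[i + d/2] (first/second halves).

    This is mlx-lm's Llama default -- wrong for this checkpoint.
    """
    d = len(x)
    half = d // 2
    out = list(x)
    for i in range(half):
        theta = pos * base ** (-2 * i / d)
        c, s = math.cos(theta), math.sin(theta)
        a, b = x[i], x[i + half]
        out[i] = a * c - b * s
        out[i + half] = a * s + b * c
    return out
```

Both apply the same rotations, just to different element pairings, so the mismatch is invisible at position 0 and compounds at every later position and layer.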
@shreyaskarnik Overall this looks very good, thanks! I left a couple of comments, and we'll also want a README.md pointing to the supported HF repo(s) and some basic usage examples -- see the other models' README files for examples.
- Load voice embeddings from .safetensors (no torch needed)
- Fall back to .pt with a graceful ImportError if torch is unavailable
- Fix the voice embedding docstring (replace, not sum)
- Converted .safetensors voice files are available at /tmp/voice_safetensors (need mlx-community org access to upload to HF)
The following generates some very weird speech that you can't understand! Mostly noise.
Status: Working! 🎉

The root cause was the RoPE convention: Voxtral TTS uses traditional (interleaved) RoPE, not the NeoX style that mlx-lm defaults to for Llama models. One-line fix.

Verified against the C reference (voxtral-tts.c):
Test results (mlx-community/Voxtral-4B-TTS-2603-mlx-bf16):
Also addressed review comments:
@chigkim worked for me; there was some noise before. Working on making quality better along with throughput. Uploaded all safetensor weights to
Quantized variants uploaded + throughput benchmarks

Three variants now available on mlx-community:
RTF = Real-Time Factor (lower = faster; <1.0 = faster than real-time). Benchmarked on Apple Silicon (18GB). 4-bit is ~8x faster than bf16 and runs faster than real-time. Quality is good across all variants and all 9 languages.
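For clarity, the RTF metric above is just a ratio of wall-clock generation time to generated audio duration; a trivial sketch (the helper name is illustrative):

```python
def real_time_factor(generation_seconds, audio_seconds):
    """RTF = wall-clock generation time / duration of the generated audio.

    RTF < 1.0 means the model produces audio faster than it plays back,
    which is the threshold for streaming use.
    """
    return generation_seconds / audio_seconds
```

For example, taking 4 s to synthesize 10 s of audio gives an RTF of 0.4, comfortably faster than real-time.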
thanks so much @lucasnewman for the review, I have addressed your comments and feedback. There is still an issue of breath/noise bursts in the samples that I generated. If you can test with some examples that would be great. cc: @Blaizzy
ALiBi bias should be slope * (j - i), giving negative bias for past positions (penalizing distant lookback). Was incorrectly using (i - j) which rewarded distant lookback, causing singing/echo artifacts. Verified against voxtral-tts.c: score[h,i,j] += slope[h] * (j - i)
Bug:
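The corrected ALiBi bias described above can be sketched as follows. The slope schedule here is the standard ALiBi geometric schedule and is an assumption; only the `slope * (j - i)` sign convention is taken from the fix:

```python
def alibi_bias(num_heads, seq_len):
    """ALiBi bias per voxtral-tts.c: score[h][i][j] += slope[h] * (j - i).

    For causal attention j <= i, so (j - i) <= 0: distant past positions
    receive a larger negative bias (penalized lookback). The old (i - j)
    sign rewarded distant lookback instead, causing singing/echo artifacts.
    """
    slopes = [2.0 ** (-8.0 * (h + 1) / num_heads) for h in range(num_heads)]
    return [
        [[slopes[h] * (j - i) for j in range(seq_len)] for i in range(seq_len)]
        for h in range(num_heads)
    ]
```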
Thanks for working on this so quickly! Unfortunately I only seem to get gibberish using 4- and 6-bit with German, English, and different voices. Just "De de de de der de de" and mumbling.
Like @Benjoyo, I get a large amount of nonsense audio out when I tried this command on my M1 Studio:
@shreyaskarnik I sent you a PR against your branch with some fixes to get everything working properly. If you can merge those in and address the comments above, we can merge this.
Fix tokenizer, text normalization, and causal conv padding
@lucasnewman thanks a lot! merged your PR!
@shreyaskarnik I sent you one more fix to lazy load the voice embeddings, since it's pretty slow to load them all at CLI invocation time.
Lazy load voice embeddings
awesome perf optimization @lucasnewman, merged
"*.txt",
"*.jsonl",
"*.yaml",
"*.wav",
@lucasnewman
mlx_audio/tts/models/voxcpm/voxcpm.py:111 still expects audiovae.pth after download.
That can break fresh VoxCPM loads; shall I revert this?
Hey, I've also been working on adding Voxtral support to mlx-audio. My branch is here: ldub#1. I didn't open a PR here since I noticed you already opened one! Reading your PR has helped me learn about this codebase and how this should be done, since this is my first time working on mlx-audio. However, I had one question: why did you not re-use the official Mistral-supported tokenizer from the mistral_common package? Is it too heavy of a dependency for this?
The official tokenizer is used when it's already installed, FWIW, but we prefer to avoid adding new dependencies, as we have downstream projects like mlx-audio-swift that can't use Python dependencies, which is why I added the fallback tokenizer.
Summary
Adds support for mistralai/Voxtral-4B-TTS-2603 — a 4B parameter text-to-speech model with 20 voice presets across 9 languages.
Closes #606
Usage
Voices: casual_male, casual_female, cheerful_female, neutral_male, neutral_female, fr_male, fr_female, es_male, es_female, de_male, de_female, it_male, it_female, pt_male, pt_female, nl_male, nl_female, ar_male, hi_male, hi_female
Languages: English, French, Spanish, German, Italian, Portuguese, Dutch, Arabic, Hindi
Architecture
Key implementation details
Files
- voxtral_tts/__init__.py
- voxtral_tts/common.py
- voxtral_tts/acoustic_head.py
- voxtral_tts/audio_tokenizer.py
- voxtral_tts/voxtral_tts.py
- voxtral_tts/README.md
- tts/utils.py
- utils.py

Test results (mlx-community/Voxtral-4B-TTS-2603-mlx-bf16)