
Add Voxtral-4B-TTS-2603 support#607

Merged
lucasnewman merged 21 commits into Blaizzy:main from shreyaskarnik:voxtral-tts-support
Mar 27, 2026
Conversation

@shreyaskarnik
Contributor

@shreyaskarnik shreyaskarnik commented Mar 26, 2026

Summary

Adds support for mistralai/Voxtral-4B-TTS-2603 — a 4B parameter text-to-speech model with 20 voice presets across 9 languages.

Closes #606

Usage

from mlx_audio.tts.utils import load
import sounddevice as sd

model = load("mlx-community/Voxtral-4B-TTS-2603-mlx-bf16")

for result in model.generate(text="Hello, how are you today?", voice="casual_male"):
    sd.play(result.audio, result.sample_rate)
    sd.wait()

Voices: casual_male, casual_female, cheerful_female, neutral_male, neutral_female, fr_male, fr_female, es_male, es_female, de_male, de_female, it_male, it_female, pt_male, pt_female, nl_male, nl_female, ar_male, hi_male, hi_female

Languages: English, French, Spanish, German, Italian, Portuguese, Dutch, Arabic, Hindi

Architecture

  • LM Backbone (~3.4B): Mistral decoder via mlx-lm LlamaModel (26 layers, 3072 dim, GQA 32/8 heads, traditional RoPE)
  • Flow-Matching Acoustic Transformer (~390M): 3-layer bidirectional transformer, 8-step Euler ODE with CFG (alpha=1.2)
  • Audio Tokenizer Decoder (~200M): 4-stage ConvTranspose1d + ALiBi causal transformer vocoder → 24kHz waveform
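The flow-matching step above can be sketched as a fixed-step Euler integration with classifier-free guidance. This is a minimal NumPy sketch, not the actual implementation: `velocity_fn` stands in for the acoustic transformer, and only the step count (8) and guidance scale (alpha=1.2) come from the description above.

```python
import numpy as np

def euler_cfg_sample(velocity_fn, x0, num_steps=8, cfg_alpha=1.2):
    """Fixed-step Euler integration of dx/dt = v(x, t) from t=0 to t=1.

    velocity_fn(x, t, cond) is a stand-in for the acoustic transformer:
    cond=True uses the LM conditioning, cond=False drops it, and
    classifier-free guidance blends the two predictions.
    """
    x = np.asarray(x0, dtype=float)
    dt = 1.0 / num_steps
    for step in range(num_steps):
        t = step * dt
        v_cond = velocity_fn(x, t, cond=True)
        v_uncond = velocity_fn(x, t, cond=False)
        # CFG: move the conditional prediction further from the unconditional one
        v = v_uncond + cfg_alpha * (v_cond - v_uncond)
        x = x + dt * v
    return x
```

With alpha > 1, the guided velocity overshoots the conditional prediction, which is the usual way CFG sharpens conditioning at some cost in diversity.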

Key implementation details

  • Traditional (interleaved) RoPE — Voxtral uses `rope_traditional=True`, not the NeoX style. This was the critical fix that made speech intelligible (verified: cos_sim=0.999984 vs C reference implementation).
  • Time embedding (cos, sin) — flow matching uses (cos, sin) concatenation order, matching vllm-omni convention.
  • Vocoder ALiBi — causal attention with sliding windows [2, 4, 8, 16] per decoder stage.
  • Semantic codebook — EMA statistics computed in float32 for precision.
  • Voice embeddings — loaded from .safetensors (no torch dependency). Voice embeddings replace audio token positions in the prompt.
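The RoPE point above is easy to get wrong, so here is a minimal NumPy sketch of the two conventions (base theta and shapes are illustrative): traditional RoPE rotates adjacent pairs (x[2i], x[2i+1]), while the NeoX style rotates pairs drawn from the first and second halves of the vector.

```python
import numpy as np

def rope_traditional(x, pos, theta=10000.0):
    """Interleaved RoPE: adjacent pairs (x[2i], x[2i+1]) rotate together."""
    d = x.shape[-1]
    freqs = theta ** (-np.arange(0, d, 2) / d)
    angles = pos * freqs
    cos, sin = np.cos(angles), np.sin(angles)
    x1, x2 = x[..., 0::2], x[..., 1::2]
    out = np.empty_like(x)
    out[..., 0::2] = x1 * cos - x2 * sin
    out[..., 1::2] = x1 * sin + x2 * cos
    return out

def rope_neox(x, pos, theta=10000.0):
    """NeoX-style RoPE: pairs (x[i], x[i + d/2]) from the two halves rotate."""
    d = x.shape[-1]
    freqs = theta ** (-np.arange(0, d, 2) / d)
    angles = pos * freqs
    cos, sin = np.cos(angles), np.sin(angles)
    x1, x2 = x[..., : d // 2], x[..., d // 2 :]
    return np.concatenate([x1 * cos - x2 * sin, x1 * sin + x2 * cos], axis=-1)
```

Both are pure rotations, so either one "works" numerically; they only disagree on which channels pair up, which is why a checkpoint trained with one convention produces garbled output under the other.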

Files

File                            Purpose
voxtral_tts/__init__.py         Module entry
voxtral_tts/common.py           Shared FeedForward + pad_to_multiple
voxtral_tts/acoustic_head.py    Flow-matching acoustic transformer
voxtral_tts/audio_tokenizer.py  Vocoder (conv + ALiBi transformer decoder)
voxtral_tts/voxtral_tts.py      Main model, config, weight sanitization, generation
voxtral_tts/README.md           Model README with usage examples
tts/utils.py                    MODEL_REMAPPING entry
utils.py                        params.json fallback, *.safetensors voice pattern

Test results (mlx-community/Voxtral-4B-TTS-2603-mlx-bf16)

Language    Voice        Frames  Duration
English     casual_male  54      4.32s
French      fr_male      31      2.48s
Spanish     es_female    33      2.64s
German      de_female    22      1.76s
Italian     it_male      23      1.84s
Portuguese  pt_male      29      2.32s
Dutch       nl_male      29      2.32s
Arabic      ar_male      30      2.40s
Hindi       hi_female    28      2.24s

shreyaskarnik and others added 8 commits March 26, 2026 13:45
Add support for mistralai/Voxtral-4B-TTS-2603, a 4B parameter
text-to-speech model with 20 voice presets across 9 languages.

Architecture:
- LM backbone (Mistral 3.4B, 26 layers) with KV-cached autoregressive generation
- Flow-matching acoustic transformer (3 layers, 8-step Euler with CFG alpha=1.2)
- Audio tokenizer decoder (4-stage ConvTranspose + ALiBi transformer vocoder)

Features:
- Loads from both original checkpoint (params.json) and MLX-converted format
- 20 voice presets (casual/neutral/cheerful + 9 languages)
- Tokenization via mistral_common SpeechRequest or HF fallback
- Voice embedding injection (sum with audio token embeddings)
- Two-stage generation: LM hidden states → acoustic codes → 24kHz waveform

Known limitations:
- Audio quality needs improvement (flow matching produces warbling)
- Generation is slow (~2 frames/s due to per-frame flow matching)
- End-of-audio detection may not trigger reliably

MLX-converted model: mlx-community/Voxtral-4B-TTS-2603-mlx-bf16

- Fix ALiBi distance computation: use (i - j) not (j - i)
- Fix sliding window mask direction: mask dist > window (past too far back)
- Add sliding window sizes [2, 4, 8, 16] per decoder stage
- Fix semantic codebook decode: compute in float32 to avoid bf16 precision loss
- Add first AUDIO token decode step matching C reference implementation
- Handle both original and MLX-converted weight name formats in sanitize()

Verified against PyTorch reference:
- Acoustic transformer velocity: max diff < 6e-6
- Vocoder conv1d: max diff < 1e-6
- Vocoder conv_transpose1d: max diff < 2e-6
- Vocoder transformer layer: max diff < 2e-5
- Codebook decode: max diff < 1e-6 (was 0.97 before float32 fix)

Remaining issue: LLM hidden states converge to fixed point across
frames, causing repetitive semantic codes. Needs investigation of
the code embedding feedback loop.

The flow matching TimeEmbedding was using (sin, cos) concatenation
order but the checkpoint weights expect (cos, sin) order (matching
vllm-omni and ExecuTorch implementations).

This was the primary cause of garbled audio:
- (sin, cos): 14/36 acoustic values clipped outside [-1,1], flow diverges
- (cos, sin): 0/36 clipped, flow converges to [-0.63, 0.83]

Verified: acoustic transformer output matches PyTorch with max diff < 1e-6.
Semantic code diversity improved from 2 to 6 unique codes.
Audio now sounds voice-like (previously garbled noise).

Remaining: model generates too long (no EOS) and hidden states
converge after ~7 frames.

Key findings from comparing voxtral-tts.c on macOS:
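The (cos, sin) vs (sin, cos) issue above can be illustrated with a minimal sketch (the frequency schedule and dimension are illustrative, not the model's): the two orders produce permuted feature channels, so weights trained against one order silently misread the other.

```python
import numpy as np

def time_embedding(t, dim=8, order="cos_sin"):
    """Sinusoidal embedding of the flow-matching time t in [0, 1].

    The checkpoint weights expect (cos, sin) concatenation; emitting
    (sin, cos) swaps the two halves of the feature vector, which the
    downstream linear layers cannot compensate for.
    """
    half = dim // 2
    freqs = np.exp(-np.log(10000.0) * np.arange(half) / half)
    angles = t * freqs
    if order == "cos_sin":
        return np.concatenate([np.cos(angles), np.sin(angles)])
    return np.concatenate([np.sin(angles), np.cos(angles)])
```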
- C code also produces mostly code 10 (most common VQ entry)
- C code produces "Hello world" non-deterministically (depends on
  flow matching random noise)
- Prompt embeddings match perfectly (cos_sim=1.0 at every position)
- Layer 0 output: cos_sim=0.969 (bf16 rounding starts here)
- Final layer output: cos_sim=0.48 (accumulated across 26 layers)
- Both C and MLX produce incoherent audio most of the time

The model likely requires GPU f16/f32 precision (as in vllm-omni on H100)
to consistently produce coherent speech. The bf16 precision on Apple
Silicon introduces enough numerical difference across 26 transformer
layers to shift argmax winners in the 8194-dim semantic codebook.

Voxtral TTS uses traditional/interleaved RoPE where pairs (x[2i], x[2i+1])
are rotated together, NOT the NeoX/GPT-NeoX style where first/second halves
are split. mlx-lm's LlamaModel defaults to non-traditional (NeoX) which
caused hidden states to diverge completely from the reference C implementation
(cos_sim=0.48 after 26 layers).

With rope_traditional=True:
- Hidden states match C reference: cos_sim=0.999984
- Model produces diverse semantic codes (12 unique vs 2 before)
- Natural EOS detection at frame 26
- Intelligible speech output ("Hello world" clearly audible)

This was the root cause of all audio quality issues.

@lucasnewman
Collaborator

@shreyaskarnik Overall this looks very good, thanks! I left a couple of comments, and we'll also want a README.md pointing to the supported HF repo(s) and some basic usage examples -- see the other models' README files for examples.

- Load voice embeddings from .safetensors (no torch needed)
- Fallback to .pt with graceful ImportError if torch unavailable
- Fix voice embedding docstring (replace, not sum)
- Converted .safetensors voice files available at /tmp/voice_safetensors
  (need mlx-community org access to upload to HF)
@chigkim
Contributor

chigkim commented Mar 27, 2026

The following generates some very weird speech that you can't understand! Mostly noise.
from mlx_audio.tts.utils import load
import sounddevice as sd

model = load("mlx-community/Voxtral-4B-TTS-2603-mlx-bf16")
res = list(model.generate(text="Hello, this is a test of Voxtral text-to-speech!", voice="casual_male"))
sd.play(res[0].audio, 24000)

@shreyaskarnik
Contributor Author

Status: Working! 🎉

The root cause was RoPE convention — Voxtral TTS uses traditional (interleaved) RoPE, not the NeoX style that mlx-lm defaults to for Llama models. One-line fix: rope_traditional=True.

Verified against the C reference (voxtral-tts.c):

  • Hidden states match: cos_sim=0.999984 (was 0.48 before the fix)
  • Natural EOS detection at correct frame counts
  • Intelligible speech across all 9 languages

Test results (mlx-community/Voxtral-4B-TTS-2603-mlx-bf16):

Language    Voice      Frames  Duration
French      fr_male    31      2.48s
Spanish     es_female  33      2.64s
German      de_female  22      1.76s
Italian     it_male    23      1.84s
Portuguese  pt_male    29      2.32s
Dutch       nl_male    29      2.32s
Arabic      ar_male    30      2.40s
Hindi       hi_female  28      2.24s

Also addressed review comments:

  • Voice loading now supports .safetensors (no torch dependency), with .pt fallback
  • Converted .safetensors voice files ready to upload to hub

@shreyaskarnik
Contributor Author

shreyaskarnik commented Mar 27, 2026

The following generates some very weird speech that you can't understand! Mostly noise.

import sounddevice as sd
model = load("mlx-community/Voxtral-4B-TTS-2603-mlx-bf16")
res = list(model.generate(text="Hello, this is a test of Voxtral text-to-speech!", voice="casual_male"))
sd.play(res[0].audio, 24000)

@chigkim it worked for me; there was some noise before. I'm working on improving quality along with throughput. I've uploaded all safetensors weights to mlx-community/Voxtral-4B-TTS-2603-mlx-bf16; can you please retry? Thank you for testing!

@shreyaskarnik
Contributor Author

Quantized variants uploaded + throughput benchmarks

Three variants now available on mlx-community:

Variant  RTF (short)  RTF (long)  Size    Link
4-bit    0.97x        0.74x       ~2.5GB  mlx-community/Voxtral-4B-TTS-2603-mlx-4bit
6-bit    1.15x        1.07x       ~3.5GB  mlx-community/Voxtral-4B-TTS-2603-mlx-6bit
bf16     6.50x        6.32x       ~8GB    mlx-community/Voxtral-4B-TTS-2603-mlx-bf16

RTF = Real-Time Factor (lower = faster, <1.0 = faster than real-time). Benchmarked on Apple Silicon (18GB).

4-bit is ~8x faster than bf16 and runs faster than real-time. Quality is good across all variants and all 9 languages.
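For reference, the frame counts and durations in the tables are consistent with an 80 ms acoustic frame, and RTF here is processing time divided by generated audio duration. A small sketch (the frame constant is inferred from the tables above, not read from the code):

```python
# Each acoustic frame appears to cover 80 ms: 54 frames -> 4.32 s in the
# English row above. This constant is inferred from the tables, not the code.
FRAME_SECONDS = 0.08

def frames_to_duration(frames):
    """Audio duration in seconds for a given number of acoustic frames."""
    return frames * FRAME_SECONDS

def real_time_factor(processing_seconds, audio_seconds):
    """RTF = processing time / audio duration; < 1.0 is faster than real time."""
    return processing_seconds / audio_seconds
```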

@shreyaskarnik
Contributor Author

Thanks so much @lucasnewman for the review; I have addressed your comments and feedback. There is still an issue of breath/noise bursts in the samples I generated. If you can test with some examples, that would be great. cc: @Blaizzy

ALiBi bias should be slope * (j - i), giving negative bias for past
positions (penalizing distant lookback). Was incorrectly using (i - j)
which rewarded distant lookback, causing singing/echo artifacts.

Verified against voxtral-tts.c: score[h,i,j] += slope[h] * (j - i)
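A minimal NumPy sketch of the bias and mask described in this commit (shapes and the single-window handling are illustrative; per the commit, the real vocoder uses a different window per decoder stage):

```python
import numpy as np

def alibi_sliding_scores(scores, slopes, window):
    """Add ALiBi bias and apply a causal sliding-window mask.

    scores: (heads, T, T) raw attention logits.
    Bias is slope * (j - i): zero on the diagonal and increasingly
    negative for more distant past positions, penalizing far lookback.
    Future positions (j > i) and positions farther back than the
    window (i - j > window) are masked out.
    """
    H, T, _ = scores.shape
    i = np.arange(T)[:, None]  # query position
    j = np.arange(T)[None, :]  # key position
    bias = slopes[:, None, None] * (j - i)     # (H, T, T)
    masked = (j > i) | ((i - j) > window)      # future, or too far back
    out = scores + bias
    out[:, masked] = -np.inf
    return out
```

With the sign flipped to (i - j), past positions would receive a positive bias, rewarding distant lookback, which matches the singing/echo artifacts described above.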
@morrygroix

Bug: es_female voice sounds Portuguese

Tested on Mac Studio M3 Ultra (256GB RAM) with both 4-bit and 6-bit quantized models.

Steps to reproduce:

from mlx_audio.tts.utils import load
model = load('mlx-community/Voxtral-4B-TTS-2603-mlx-4bit')
for result in model.generate(
    text='Buenos días. Anoche llovió bastante y las calles amanecieron mojadas. Hoy parece que sale el sol por fin.',
    voice='es_female'
):
    pass

Expected: Spanish female voice
Actual: Strong Portuguese accent, clearly not Spanish

Notes:

  • es_male works correctly — sounds like proper Spanish
  • es_female has the issue in both 4-bit and 6-bit variants (slightly better in 6-bit but still clearly Portuguese)
  • Other female voices (casual_female, neutral_female) read Spanish text with English accent, as expected
  • Suspect the es_female voice embedding may be swapped or blended with pt_female

Great work on the port otherwise — performance is excellent (RTF ~0.21x on M3 Ultra 4-bit). Looking forward to voice cloning support too!

@Benjoyo

Benjoyo commented Mar 27, 2026

Thanks for working on this so quickly!

Unfortunately I only seem to get gibberish using the 4-bit and 6-bit variants with German, English, and different voices. Just "De de de de der de de" and mumbling.

@nickludlam

nickludlam commented Mar 27, 2026

Like @Benjoyo I get a large amount of nonsense audio out when I tried this command on my M1 Studio:

mlx_audio.tts.generate --verbose --model mlx-community/Voxtral-4B-TTS-2603-mlx-bf16 --text "This is a test of support for Mistral's new Voxtral 4B TTS 2603 model in mlx-audio" --output_path ./my_audio --voice cheerful_female

✅ Audio successfully generated and saving as: ./my_audio/audio_000.wav
==========
Duration:              00:01:36.000
Samples/sec:           24000.0
Prompt:                1200 tokens, 7.0 tokens-per-sec
Audio:                 2304000 samples, 24000.0 samples-per-sec
Real-time factor:      0.56x
Processing time:       170.91s
Peak memory usage:     9.56GB

@lucasnewman
Collaborator

@shreyaskarnik I sent you a PR against your branch with some fixes to get everything working properly. If you can merge those in and address the comments above we can merge this.

Fix tokenizer, text normalization, and causal conv padding
@shreyaskarnik
Contributor Author

@lucasnewman thanks a lot! merged your PR!

@lucasnewman
Collaborator

@shreyaskarnik I sent you one more fix to lazy load the voice embeddings, since it's pretty slow to load them all at CLI invocation time
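Lazy loading here can be sketched as a cache that defers disk reads until a voice is first requested. This is a hypothetical illustration, not the actual fix: `loader_fn` stands in for the real embedding read, and the class name is invented.

```python
class VoicePresets:
    """Lazily load voice embeddings so CLI startup stays fast.

    An embedding is read (via loader_fn) only on first access,
    then cached; unknown voice names fail fast with KeyError.
    """

    def __init__(self, loader_fn, names):
        self._loader_fn = loader_fn
        self._names = set(names)
        self._cache = {}

    def __getitem__(self, name):
        if name not in self._names:
            raise KeyError(f"unknown voice: {name}")
        if name not in self._cache:
            # First access: pay the disk-read cost once, then reuse
            self._cache[name] = self._loader_fn(name)
        return self._cache[name]
```

The payoff is that a CLI invocation using one voice loads one embedding instead of all twenty.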

@shreyaskarnik
Contributor Author

awesome perf optimization @lucasnewman merged

@lucasnewman (Collaborator) left a comment

🚀

@lucasnewman lucasnewman merged commit 4b3d404 into Blaizzy:main Mar 27, 2026
10 checks passed
"*.txt",
"*.jsonl",
"*.yaml",
"*.wav",
Contributor Author

@shreyaskarnik shreyaskarnik Mar 27, 2026
@lucasnewman
mlx_audio/tts/models/voxcpm/voxcpm.py:111 still expects audiovae.pth after download.
That can break fresh VoxCPM loads, shall I revert this?

Collaborator

@shreyaskarnik Can you send a new PR?

@ldub

ldub commented Mar 27, 2026

Hey, I've also been working on adding Voxtral support to mlx-audio. My branch is here: ldub#1. I didn't open a PR here since I noticed you already opened one! Reading your PR has helped me learn about this codebase and how this should be done, since this is my first time working on mlx-audio.

However, I had one question: why did you not reuse the official Mistral-supported tokenizer from the mistral_common package? Is it too heavy a dependency for this?

@lucasnewman
Collaborator

However, I had one question: why did you not reuse the official Mistral-supported tokenizer from the mistral_common package? Is it too heavy a dependency for this?

The official tokenizer is used when it's already installed, fwiw, but we prefer to avoid adding new dependencies: we have downstream projects like mlx-audio-swift that can't use Python dependencies, which is why I added the fallback tokenizer.
