Skip to content

Commit a73408b

Browse files
committed
fix(vibevoice): preserve quantization metadata in sanitize() for quantized model loading
sanitize() drops weight keys not found in the model's current parameter shapes. Since the model isn't quantized yet at sanitize time, quantization metadata keys (.scales, .biases) are silently removed. Later, apply_quantization() checks for these keys to decide which layers to quantize -- finds nothing -- skips quantization -- and loading fails with a shape mismatch. Preserve .scales and .biases keys through sanitization, matching the existing pattern in chatterbox/s3gen. Same class of bug as Blaizzy#584 (fish_qwen3_omni sanitize fix).
1 parent 6c513de commit a73408b

File tree

2 files changed

+28
-2
lines changed

2 files changed

+28
-2
lines changed

mlx_audio/tts/models/vibevoice/vibevoice.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -250,8 +250,9 @@ def transform_key(key: str) -> str:
250250

251251
# Check if key exists in model
252252
if new_key not in curr_shapes:
253-
# Debug: uncomment to see missing keys
254-
# print(f"Warning: Key {new_key} (from {k}) not found in model")
253+
# Preserve quantization metadata -- model isn't quantized yet at sanitize time
254+
if new_key.endswith((".scales", ".biases")):
255+
new_weights[new_key] = v
255256
continue
256257

257258
target_shape = curr_shapes[new_key]

mlx_audio/tts/tests/test_models.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1318,6 +1318,31 @@ def test_sanitize_huggingface_keys(self):
13181318
self.assertNotIn("model.prediction_head.t_embedder.mlp.0.weight", sanitized)
13191319
self.assertNotIn("model.prediction_head.adaLN_modulation.1.weight", sanitized)
13201320

1321+
def test_sanitize_preserves_quantization_metadata(self):
1322+
"""Test that sanitize preserves .scales and .biases for quantized models."""
1323+
from mlx.utils import tree_flatten
1324+
1325+
from mlx_audio.tts.models.vibevoice.vibevoice import Model
1326+
1327+
config = self._default_config
1328+
model = Model(config)
1329+
1330+
# Start with the model's own weights
1331+
weights = dict(tree_flatten(model.parameters()))
1332+
1333+
# Add mock quantization metadata for the key from the bug report:
1334+
# "Expected shape (151936, 896) but received shape (151936, 224)
1335+
# for parameter language_model.embed_tokens.weight"
1336+
quant_key = "language_model.embed_tokens.weight"
1337+
weights[f"{quant_key}.scales"] = mx.ones((1,))
1338+
weights[f"{quant_key}.biases"] = mx.ones((1,))
1339+
1340+
sanitized = model.sanitize(weights)
1341+
1342+
# Quantization metadata must survive sanitization
1343+
self.assertIn(f"{quant_key}.scales", sanitized)
1344+
self.assertIn(f"{quant_key}.biases", sanitized)
1345+
13211346
def test_config_defaults(self):
13221347
"""Test VibeVoiceModel uses correct config defaults."""
13231348
from mlx_audio.tts.models.vibevoice.config import ModelConfig

0 commit comments

Comments
 (0)