Skip to content

Commit f3cd320

Browse files
authored
Add VibeVoice (#295)
* add vibe-voice * fix voice crackling (silu -> gelu) and add ddpm_steps * remove unused * format * fix convert * add tests * use voice argument * set default cfg_scale * format * format * remove unused * remove unused * fix quant predicates and convert * add multi-speaker
1 parent bc8cbf4 commit f3cd320

File tree

13 files changed

+2519
-11
lines changed

13 files changed

+2519
-11
lines changed

mlx_audio/tts/generate.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -208,6 +208,8 @@ def generate_audio(
208208
voice: str = "af_heart",
209209
speed: float = 1.0,
210210
lang_code: str = "a",
211+
cfg_scale: Optional[float] = None,
212+
ddpm_steps: Optional[int] = None,
211213
ref_audio: Optional[str] = None,
212214
ref_text: Optional[str] = None,
213215
stt_model: Optional[Union[str, nn.Module]] = "mlx-community/whisper-large-v3-turbo",
@@ -299,13 +301,15 @@ def generate_audio(
299301
f"\033[94mLanguage:\033[0m {lang_code}"
300302
)
301303

302-
results = model.generate(
304+
gen_kwargs = dict(
303305
text=text,
304306
voice=voice,
305307
speed=speed,
306308
lang_code=lang_code,
307309
ref_audio=ref_audio,
308310
ref_text=ref_text,
311+
cfg_scale=cfg_scale,
312+
ddpm_steps=ddpm_steps,
309313
temperature=temperature,
310314
max_tokens=max_tokens,
311315
verbose=verbose,
@@ -314,6 +318,8 @@ def generate_audio(
314318
**kwargs,
315319
)
316320

321+
results = model.generate(**gen_kwargs)
322+
317323
audio_list = []
318324
file_name = f"{file_prefix}.{audio_format}"
319325
for i, result in enumerate(results):
@@ -393,6 +399,19 @@ def parse_args():
393399
help="Text to generate (leave blank to input via stdin)",
394400
)
395401
parser.add_argument("--voice", type=str, default=None, help="Voice name")
402+
parser.add_argument(
403+
"--cfg_scale",
404+
type=float,
405+
default=1.5,
406+
help="Classifier-free guidance scale. Lower (≈1.0-1.5) is often more stable.",
407+
)
408+
parser.add_argument(
409+
"--ddpm_steps",
410+
type=int,
411+
default=None,
412+
help="Override diffusion steps. Higher = better quality, slower (try 30-50).",
413+
)
414+
396415
parser.add_argument("--speed", type=float, default=1.0, help="Speed of the audio")
397416
parser.add_argument(
398417
"--gender", type=str, default="male", help="Gender of the voice [male, female]"

mlx_audio/tts/models/sesame/sesame.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -441,7 +441,7 @@ def __init__(
441441

442442
self._sample_rate = mimi.cfg.sample_rate
443443

444-
def model_quant_predicate(self, p, m, config):
444+
def model_quant_predicate(self, p, m):
445445
"""
446446
Model modules to skip during quantization
447447
"""

mlx_audio/tts/models/spark/spark.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,7 @@ def sample_rate(self):
9595
def layers(self):
9696
return self.model.layers
9797

98-
def model_quant_predicate(self, p, m, config):
98+
def model_quant_predicate(self, p, m):
9999
"""
100100
Model modules to skip during quantization
101101
"""
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
from .config import (
2+
AcousticTokenizerConfig,
3+
DiffusionHeadConfig,
4+
ModelConfig,
5+
Qwen2DecoderConfig,
6+
)
7+
from .vibevoice import Model
8+
9+
__all__ = [
10+
"Model",
11+
"ModelConfig",
12+
"AcousticTokenizerConfig",
13+
"DiffusionHeadConfig",
14+
"Qwen2DecoderConfig",
15+
]

0 commit comments

Comments
 (0)