Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 24 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,30 @@ mlx_audio.tts.generate --text "Hello, world" --file_prefix hello
mlx_audio.tts.generate --text "Hello, world" --speed 1.4
```

### How to call from Python

To generate audio from text in your own Python code, call `generate_audio`:

```python
from mlx_audio.tts.generate import generate_audio

# Example: Generate an audiobook chapter as audio
generate_audio(
text="In the beginning, the universe was created...",
model_path="prince-canuma/Kokoro-82M",
voice="af_heart",
speed=1.2,
lang_code="a",  # same single-letter code the CLI uses by default
file_prefix="audiobook_chapter1",
audio_format="wav",
sample_rate=24000,
join_audio=True,
verbose=True # Set to False to disable print messages
)

print("Audiobook chapter successfully generated!")

```

### Web Interface & API Server

Expand Down
224 changes: 143 additions & 81 deletions mlx_audio/tts/generate.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import argparse
import json
import os
import sys
from typing import Optional

import mlx.core as mx
import soundfile as sf
Expand All @@ -10,109 +10,98 @@
from .utils import load_model


def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"--model",
type=str,
default="prince-canuma/Kokoro-82M",
help="Path or repo id of the model",
)
parser.add_argument(
"--text",
type=str,
default=None,
help="Text to generate (leave blank to input via stdin)",
)
parser.add_argument("--voice", type=str, default="af_heart", help="Voice name")
parser.add_argument("--speed", type=float, default=1.0, help="Speed of the audio")
parser.add_argument("--lang_code", type=str, default="a", help="Language code")
parser.add_argument(
"--file_prefix", type=str, default="audio", help="Output file name prefix"
)
parser.add_argument("--verbose", action="store_false", help="Print verbose output")
parser.add_argument(
"--join_audio", action="store_true", help="Join all audio files into one"
)
parser.add_argument("--play", action="store_true", help="Play the output audio")
parser.add_argument(
"--ref_audio", type=str, default=None, help="Path to reference audio"
)
parser.add_argument(
"--ref_text", type=str, default=None, help="Caption for reference audio"
)
args = parser.parse_args()

if args.text is None:
if not sys.stdin.isatty():
args.text = sys.stdin.read().strip()
else:
print("Please enter the text to generate:")
args.text = input("> ").strip()

return args


def main():
args = parse_args()
def generate_audio(
text: str,
model_path: str = "prince-canuma/Kokoro-82M",
voice: str = "af_heart",
speed: float = 1.0,
lang_code: str = "a",
ref_audio: Optional[str] = None,
ref_text: Optional[str] = None,
file_prefix: str = "audio",
audio_format: str = "wav",
sample_rate: int = 24000,
join_audio: bool = False,
play: bool = False,
verbose: bool = True,
from_cli: bool = False,
) -> None:
"""
Generates audio from text using a specified TTS model.

Parameters:
- text (str): The input text to be converted to speech.
- model_path (str): Path or repo id of the TTS model to use.
- voice (str): The voice style to use.
- speed (float): Playback speed multiplier.
- lang_code (str): The language code.
- ref_audio (str, optional): Path to the reference audio file you would like to clone the voice from.
- ref_text (str, optional): Caption for reference audio. Required when ref_audio is given.
- file_prefix (str): The output file path without extension.
- audio_format (str): Output audio format (e.g., "wav", "flac").
- sample_rate (int): Sampling rate in Hz.
- join_audio (bool): Whether to join multiple audio files into one.
- play (bool): Whether to play the generated audio.
- verbose (bool): Whether to print status messages.

Returns:
- None: The function writes the generated audio to a file.
"""
try:
# load reference audio for voice matching if specified
# Load reference audio for voice matching if specified

ref_audio = None
ref_text = None

if args.ref_audio:
if not os.path.exists(args.ref_audio):
raise FileNotFoundError(
f"Reference audio file not found: {args.ref_audio}"
)
if not args.ref_text:
if ref_audio:
if not os.path.exists(ref_audio):
raise FileNotFoundError(f"Reference audio file not found: {ref_audio}")
if not ref_text:
raise ValueError(
"Reference text is required when using reference audio."
)

ref_audio, ref_sr = sf.read(args.ref_audio)
ref_audio, ref_sr = sf.read(ref_audio)
if ref_sr != 24000:
raise ValueError(
f"Reference audio sample rate must be 24000 Hz, but got {ref_sr} Hz."
)
ref_audio = mx.array(ref_audio, dtype=mx.float32)
ref_text = args.ref_text

player = AudioPlayer() if args.play else None
# Load AudioPlayer
player = AudioPlayer() if play else None

model = load_model(model_path=args.model)
# Load model
model = load_model(model_path=model_path)
print(
f"\n\033[94mModel:\033[0m {args.model}\n"
f"\033[94mText:\033[0m {args.text}\n"
f"\033[94mVoice:\033[0m {args.voice}\n"
f"\033[94mSpeed:\033[0m {args.speed}x\n"
f"\033[94mLanguage:\033[0m {args.lang_code}"
f"\n\033[94mModel:\033[0m {model_path}\n"
f"\033[94mText:\033[0m {text}\n"
f"\033[94mVoice:\033[0m {voice}\n"
f"\033[94mSpeed:\033[0m {speed}x\n"
f"\033[94mLanguage:\033[0m {lang_code}"
)
print("==========")

results = model.generate(
text=args.text,
voice=args.voice,
speed=args.speed,
lang_code=args.lang_code,
text=text,
voice=voice,
speed=speed,
lang_code=lang_code,
ref_audio=ref_audio,
ref_text=ref_text,
verbose=True,
)
print(
f"\033[92mAudio generated successfully, saving to\033[0m {args.file_prefix}!"
)

audio_list = []
file_name = f"{file_prefix}.{audio_format}"
for i, result in enumerate(results):
if args.play:
if play:
player.queue_audio(result.audio)
if args.join_audio:
if join_audio:
audio_list.append(result.audio)

else:
sf.write(f"{args.file_prefix}_{i:03d}.wav", result.audio, 24000)
file_name = f"{file_prefix}_{i:03d}.{audio_format}"
sf.write(file_name, result.audio, 24000)

if verbose:

if args.verbose:
print("==========")
print(f"Duration: {result.audio_duration}")
print(
Expand All @@ -127,15 +116,18 @@ def main():
print(f"Real-time factor: {result.real_time_factor:.2f}x")
print(f"Processing time: {result.processing_time_seconds:.2f}s")
print(f"Peak memory usage: {result.peak_memory_usage:.2f}GB")
print(f"✅ Audio successfully generated and saving as: {file_name}")

if args.join_audio:
print(f"Joining {len(audio_list)} audio files")
if join_audio:
if verbose:
print(f"Joining {len(audio_list)} audio files")
audio = mx.concatenate(audio_list, axis=0)
sf.write(f"{args.file_prefix}.wav", audio, 24000)
sf.write(f"{file_prefix}.{audio_format}", audio, 24000)

if args.play:
if play:
player.wait_for_drain()
player.stop()

except ImportError as e:
print(f"Import error: {e}")
print(
Expand All @@ -148,5 +140,75 @@ def main():
traceback.print_exc()


def parse_args(argv: Optional[list] = None) -> argparse.Namespace:
    """Parse the command-line options of the TTS generation script.

    Parameters:
    - argv (list of str, optional): Argument strings to parse. When None
      (the default) argparse reads them from ``sys.argv[1:]``.

    Returns:
    - argparse.Namespace with the attributes ``model``, ``text``, ``voice``,
      ``speed``, ``lang_code``, ``file_prefix``, ``verbose``, ``join_audio``,
      ``play``, ``audio_format``, ``sample_rate``, ``ref_audio`` and
      ``ref_text``.

    Raises:
    - SystemExit: via ``parser.error`` (exit status 2) when the arguments are
      invalid or when no non-blank text was supplied.
    """
    parser = argparse.ArgumentParser(description="Generate audio from text using TTS.")
    parser.add_argument(
        "--model",
        type=str,
        default="mlx-community/Kokoro-82M-bf16",
        help="Path or repo id of the model",
    )
    parser.add_argument(
        "--text",
        type=str,
        default=None,
        help="Text to generate (leave blank to input via stdin)",
    )
    parser.add_argument("--voice", type=str, default="af_heart", help="Voice name")
    parser.add_argument("--speed", type=float, default=1.0, help="Speed of the audio")
    parser.add_argument("--lang_code", type=str, default="a", help="Language code")
    parser.add_argument(
        "--file_prefix", type=str, default="audio", help="Output file name prefix"
    )
    # NOTE: store_false means verbose is ON by default and passing --verbose
    # turns it OFF. The flag's behaviour is kept for backward compatibility;
    # only the help text is corrected so it no longer claims the opposite.
    parser.add_argument(
        "--verbose",
        action="store_false",
        help="Disable verbose output (verbose output is printed by default)",
    )
    parser.add_argument(
        "--join_audio", action="store_true", help="Join all audio files into one"
    )
    parser.add_argument("--play", action="store_true", help="Play the output audio")
    parser.add_argument(
        "--audio_format", type=str, default="wav", help="Output audio format"
    )
    parser.add_argument(
        "--sample_rate", type=int, default=24000, help="Audio sample rate in Hz"
    )
    parser.add_argument(
        "--ref_audio", type=str, default=None, help="Path to reference audio"
    )
    parser.add_argument(
        "--ref_text", type=str, default=None, help="Caption for reference audio"
    )

    args = parser.parse_args(argv)

    # No --text given: take it from piped stdin, or prompt when interactive.
    if args.text is None:
        if not sys.stdin.isatty():
            args.text = sys.stdin.read().strip()
        else:
            print("Please enter the text to generate:")
            args.text = input("> ").strip()

    # Fail early with a usage error instead of loading a model only to
    # synthesize nothing from blank input.
    if not args.text.strip():
        parser.error("No text provided: pass --text or supply text on stdin.")

    return args


def main():
    """Command-line entry point: parse the CLI options and synthesize the audio."""
    cli = parse_args()

    # Map each parsed option onto the keyword generate_audio expects; the only
    # rename is --model, which is passed as model_path.
    settings = {
        "text": cli.text,
        "model_path": cli.model,
        "voice": cli.voice,
        "speed": cli.speed,
        "lang_code": cli.lang_code,
        "ref_audio": cli.ref_audio,
        "ref_text": cli.ref_text,
        "file_prefix": cli.file_prefix,
        "audio_format": cli.audio_format,
        "sample_rate": cli.sample_rate,
        "join_audio": cli.join_audio,
        "play": cli.play,
        "verbose": cli.verbose,
    }
    generate_audio(**settings)


# Run the CLI only when this module is executed as a script, not on import.
if __name__ == "__main__":
    main()