Skip to content

Commit 47dc0bb

Browse files
sergenes and Blaizzy authored
🔥 Feature: External API for Audiobook Generation (#19)
* allow external scripts to generate speech audio using `mlx_audio`, e.g. for audiobook projects
* Refactor TTS Generation: Unified CLI & Script Logic, Eliminating Duplication
* reformat with `pre-commit run --all`
* rebasing with the main
* added sample_rate to join audio
* Update README.md
* format
* fix formatting
* add os
* update doc string

---------

Co-authored-by: Prince Canuma <prince.gdt@gmail.com>
1 parent 9d3c69b commit 47dc0bb

File tree

2 files changed

+167
-81
lines changed

2 files changed

+167
-81
lines changed

README.md

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,30 @@ mlx_audio.tts.generate --text "Hello, world" --file_prefix hello
3838
mlx_audio.tts.generate --text "Hello, world" --speed 1.4
3939
```
4040

41+
### How to call from Python
42+
43+
To generate audio from your own Python script, call `generate_audio`:
44+
45+
```python
46+
from mlx_audio.tts.generate import generate_audio
47+
48+
# Example: Generate an audiobook chapter as audio
49+
generate_audio(
50+
text="In the beginning, the universe was created...",
51+
model_path="prince-canuma/Kokoro-82M",
52+
voice="af_heart",
53+
speed=1.2,
54+
lang_code="en",
55+
file_prefix="audiobook_chapter1",
56+
audio_format="wav",
57+
sample_rate=24000,
58+
join_audio=True,
59+
verbose=True # Set to False to disable print messages
60+
)
61+
62+
print("Audiobook chapter successfully generated!")
63+
64+
```
4165

4266
### Web Interface & API Server
4367

mlx_audio/tts/generate.py

Lines changed: 143 additions & 81 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import argparse
2-
import json
32
import os
43
import sys
4+
from typing import Optional
55

66
import mlx.core as mx
77
import soundfile as sf
@@ -10,109 +10,98 @@
1010
from .utils import load_model
1111

1212

13-
def parse_args():
14-
parser = argparse.ArgumentParser()
15-
parser.add_argument(
16-
"--model",
17-
type=str,
18-
default="prince-canuma/Kokoro-82M",
19-
help="Path or repo id of the model",
20-
)
21-
parser.add_argument(
22-
"--text",
23-
type=str,
24-
default=None,
25-
help="Text to generate (leave blank to input via stdin)",
26-
)
27-
parser.add_argument("--voice", type=str, default="af_heart", help="Voice name")
28-
parser.add_argument("--speed", type=float, default=1.0, help="Speed of the audio")
29-
parser.add_argument("--lang_code", type=str, default="a", help="Language code")
30-
parser.add_argument(
31-
"--file_prefix", type=str, default="audio", help="Output file name prefix"
32-
)
33-
parser.add_argument("--verbose", action="store_false", help="Print verbose output")
34-
parser.add_argument(
35-
"--join_audio", action="store_true", help="Join all audio files into one"
36-
)
37-
parser.add_argument("--play", action="store_true", help="Play the output audio")
38-
parser.add_argument(
39-
"--ref_audio", type=str, default=None, help="Path to reference audio"
40-
)
41-
parser.add_argument(
42-
"--ref_text", type=str, default=None, help="Caption for reference audio"
43-
)
44-
args = parser.parse_args()
45-
46-
if args.text is None:
47-
if not sys.stdin.isatty():
48-
args.text = sys.stdin.read().strip()
49-
else:
50-
print("Please enter the text to generate:")
51-
args.text = input("> ").strip()
52-
53-
return args
54-
55-
56-
def main():
57-
args = parse_args()
13+
def generate_audio(
14+
text: str,
15+
model_path: str = "prince-canuma/Kokoro-82M",
16+
voice: str = "af_heart",
17+
speed: float = 1.0,
18+
lang_code: str = "a",
19+
ref_audio: Optional[str] = None,
20+
ref_text: Optional[str] = None,
21+
file_prefix: str = "audio",
22+
audio_format: str = "wav",
23+
sample_rate: int = 24000,
24+
join_audio: bool = False,
25+
play: bool = False,
26+
verbose: bool = True,
27+
from_cli: bool = False,
28+
) -> None:
29+
"""
30+
Generates audio from text using a specified TTS model.
31+
32+
Parameters:
33+
- text (str): The input text to be converted to speech.
34+
- model_path (str): Path or repo id of the TTS model to use.
35+
- voice (str): The voice style to use.
36+
- speed (float): Playback speed multiplier.
37+
- lang_code (str): The language code.
38+
- ref_audio (str): Path to the reference audio file you would like to clone the voice from.
39+
- ref_text (str): Caption for reference audio.
40+
- file_prefix (str): The output file path without extension.
41+
- audio_format (str): Output audio format (e.g., "wav", "flac").
42+
- sample_rate (int): Sampling rate in Hz.
43+
- join_audio (bool): Whether to join multiple audio files into one.
44+
- play (bool): Whether to play the generated audio.
45+
- verbose (bool): Whether to print status messages.
46+
47+
Returns:
48+
- None: The function writes the generated audio to a file.
49+
"""
5850
try:
59-
# load reference audio for voice matching if specified
51+
# Load reference audio for voice matching if specified
6052

61-
ref_audio = None
62-
ref_text = None
63-
64-
if args.ref_audio:
65-
if not os.path.exists(args.ref_audio):
66-
raise FileNotFoundError(
67-
f"Reference audio file not found: {args.ref_audio}"
68-
)
69-
if not args.ref_text:
53+
if ref_audio:
54+
if not os.path.exists(ref_audio):
55+
raise FileNotFoundError(f"Reference audio file not found: {ref_audio}")
56+
if not ref_text:
7057
raise ValueError(
7158
"Reference text is required when using reference audio."
7259
)
7360

74-
ref_audio, ref_sr = sf.read(args.ref_audio)
61+
ref_audio, ref_sr = sf.read(ref_audio)
7562
if ref_sr != 24000:
7663
raise ValueError(
7764
f"Reference audio sample rate must be 24000 Hz, but got {ref_sr} Hz."
7865
)
7966
ref_audio = mx.array(ref_audio, dtype=mx.float32)
80-
ref_text = args.ref_text
8167

82-
player = AudioPlayer() if args.play else None
68+
# Load AudioPlayer
69+
player = AudioPlayer() if play else None
8370

84-
model = load_model(model_path=args.model)
71+
# Load model
72+
model = load_model(model_path=model_path)
8573
print(
86-
f"\n\033[94mModel:\033[0m {args.model}\n"
87-
f"\033[94mText:\033[0m {args.text}\n"
88-
f"\033[94mVoice:\033[0m {args.voice}\n"
89-
f"\033[94mSpeed:\033[0m {args.speed}x\n"
90-
f"\033[94mLanguage:\033[0m {args.lang_code}"
74+
f"\n\033[94mModel:\033[0m {model_path}\n"
75+
f"\033[94mText:\033[0m {text}\n"
76+
f"\033[94mVoice:\033[0m {voice}\n"
77+
f"\033[94mSpeed:\033[0m {speed}x\n"
78+
f"\033[94mLanguage:\033[0m {lang_code}"
9179
)
92-
print("==========")
80+
9381
results = model.generate(
94-
text=args.text,
95-
voice=args.voice,
96-
speed=args.speed,
97-
lang_code=args.lang_code,
82+
text=text,
83+
voice=voice,
84+
speed=speed,
85+
lang_code=lang_code,
9886
ref_audio=ref_audio,
9987
ref_text=ref_text,
10088
verbose=True,
10189
)
102-
print(
103-
f"\033[92mAudio generated successfully, saving to\033[0m {args.file_prefix}!"
104-
)
10590

10691
audio_list = []
92+
file_name = f"{file_prefix}.{audio_format}"
10793
for i, result in enumerate(results):
108-
if args.play:
94+
if play:
10995
player.queue_audio(result.audio)
110-
if args.join_audio:
96+
if join_audio:
11197
audio_list.append(result.audio)
98+
11299
else:
113-
sf.write(f"{args.file_prefix}_{i:03d}.wav", result.audio, 24000)
100+
file_name = f"{file_prefix}_{i:03d}.{audio_format}"
101+
sf.write(file_name, result.audio, 24000)
102+
103+
if verbose:
114104

115-
if args.verbose:
116105
print("==========")
117106
print(f"Duration: {result.audio_duration}")
118107
print(
@@ -127,15 +116,18 @@ def main():
127116
print(f"Real-time factor: {result.real_time_factor:.2f}x")
128117
print(f"Processing time: {result.processing_time_seconds:.2f}s")
129118
print(f"Peak memory usage: {result.peak_memory_usage:.2f}GB")
119+
print(f"✅ Audio successfully generated and saving as: {file_name}")
130120

131-
if args.join_audio:
132-
print(f"Joining {len(audio_list)} audio files")
121+
if join_audio:
122+
if verbose:
123+
print(f"Joining {len(audio_list)} audio files")
133124
audio = mx.concatenate(audio_list, axis=0)
134-
sf.write(f"{args.file_prefix}.wav", audio, 24000)
125+
sf.write(f"{file_prefix}.{audio_format}", audio, 24000)
135126

136-
if args.play:
127+
if play:
137128
player.wait_for_drain()
138129
player.stop()
130+
139131
except ImportError as e:
140132
print(f"Import error: {e}")
141133
print(
@@ -148,5 +140,75 @@ def main():
148140
traceback.print_exc()
149141

150142

143+
def parse_args(argv=None):
    """
    Parse the command-line options of the TTS generation CLI.

    Parameters:
    - argv (list of str, optional): Argument list to parse. When None
      (the default), argparse reads ``sys.argv[1:]``.

    Returns:
    - argparse.Namespace: Parsed options with the attributes ``model``,
      ``text``, ``voice``, ``speed``, ``lang_code``, ``file_prefix``,
      ``verbose``, ``join_audio``, ``play``, ``audio_format``,
      ``sample_rate``, ``ref_audio`` and ``ref_text``.

    If ``--text`` is omitted, the text is read from stdin when stdin is not
    a terminal; otherwise the user is prompted for it interactively.
    """
    parser = argparse.ArgumentParser(description="Generate audio from text using TTS.")
    parser.add_argument(
        "--model",
        type=str,
        default="mlx-community/Kokoro-82M-bf16",
        help="Path or repo id of the model",
    )
    parser.add_argument(
        "--text",
        type=str,
        default=None,
        help="Text to generate (leave blank to input via stdin)",
    )
    parser.add_argument("--voice", type=str, default="af_heart", help="Voice name")
    parser.add_argument("--speed", type=float, default=1.0, help="Speed of the audio")
    parser.add_argument("--lang_code", type=str, default="a", help="Language code")
    parser.add_argument(
        "--file_prefix", type=str, default="audio", help="Output file name prefix"
    )
    # Verbose output stays enabled by default. The flag used to be declared
    # with action="store_false", so passing --verbose turned output *off*,
    # contradicting its own help text. --verbose now keeps it on and --quiet
    # (same destination) is the explicit way to turn it off.
    parser.add_argument(
        "--verbose",
        dest="verbose",
        action="store_true",
        default=True,
        help="Print verbose output (enabled by default)",
    )
    parser.add_argument(
        "--quiet",
        dest="verbose",
        action="store_false",
        help="Suppress verbose output",
    )
    parser.add_argument(
        "--join_audio", action="store_true", help="Join all audio files into one"
    )
    parser.add_argument("--play", action="store_true", help="Play the output audio")
    parser.add_argument(
        "--audio_format", type=str, default="wav", help="Output audio format"
    )
    parser.add_argument(
        "--sample_rate", type=int, default=24000, help="Audio sample rate in Hz"
    )
    parser.add_argument(
        "--ref_audio", type=str, default=None, help="Path to reference audio"
    )
    parser.add_argument(
        "--ref_text", type=str, default=None, help="Caption for reference audio"
    )

    args = parser.parse_args(argv)

    if args.text is None:
        if not sys.stdin.isatty():
            # Text was piped in, e.g. `cat chapter.txt | mlx_audio.tts.generate`.
            args.text = sys.stdin.read().strip()
        else:
            print("Please enter the text to generate:")
            args.text = input("> ").strip()

    return args
191+
192+
193+
def main():
    """
    Command-line entry point.

    Parses the CLI options and forwards them, unchanged, to `generate_audio`.
    """
    cli = parse_args()

    # Every option except the model maps one-to-one onto a generate_audio
    # keyword; the CLI calls it --model while the function calls it model_path.
    passthrough = (
        "text",
        "voice",
        "speed",
        "lang_code",
        "ref_audio",
        "ref_text",
        "file_prefix",
        "audio_format",
        "sample_rate",
        "join_audio",
        "play",
        "verbose",
    )
    options = {name: getattr(cli, name) for name in passthrough}

    generate_audio(model_path=cli.model, **options)
211+
212+
151213
if __name__ == "__main__":
152214
main()

0 commit comments

Comments
 (0)