11import argparse
2- import json
32import os
43import sys
4+ from typing import Optional
55
66import mlx .core as mx
77import soundfile as sf
1010from .utils import load_model
1111
1212
13- def parse_args ():
14- parser = argparse .ArgumentParser ()
15- parser .add_argument (
16- "--model" ,
17- type = str ,
18- default = "prince-canuma/Kokoro-82M" ,
19- help = "Path or repo id of the model" ,
20- )
21- parser .add_argument (
22- "--text" ,
23- type = str ,
24- default = None ,
25- help = "Text to generate (leave blank to input via stdin)" ,
26- )
27- parser .add_argument ("--voice" , type = str , default = "af_heart" , help = "Voice name" )
28- parser .add_argument ("--speed" , type = float , default = 1.0 , help = "Speed of the audio" )
29- parser .add_argument ("--lang_code" , type = str , default = "a" , help = "Language code" )
30- parser .add_argument (
31- "--file_prefix" , type = str , default = "audio" , help = "Output file name prefix"
32- )
33- parser .add_argument ("--verbose" , action = "store_false" , help = "Print verbose output" )
34- parser .add_argument (
35- "--join_audio" , action = "store_true" , help = "Join all audio files into one"
36- )
37- parser .add_argument ("--play" , action = "store_true" , help = "Play the output audio" )
38- parser .add_argument (
39- "--ref_audio" , type = str , default = None , help = "Path to reference audio"
40- )
41- parser .add_argument (
42- "--ref_text" , type = str , default = None , help = "Caption for reference audio"
43- )
44- args = parser .parse_args ()
45-
46- if args .text is None :
47- if not sys .stdin .isatty ():
48- args .text = sys .stdin .read ().strip ()
49- else :
50- print ("Please enter the text to generate:" )
51- args .text = input ("> " ).strip ()
52-
53- return args
54-
55-
56- def main ():
57- args = parse_args ()
13+ def generate_audio (
14+ text : str ,
15+ model_path : str = "prince-canuma/Kokoro-82M" ,
16+ voice : str = "af_heart" ,
17+ speed : float = 1.0 ,
18+ lang_code : str = "a" ,
19+ ref_audio : Optional [str ] = None ,
20+ ref_text : Optional [str ] = None ,
21+ file_prefix : str = "audio" ,
22+ audio_format : str = "wav" ,
23+ sample_rate : int = 24000 ,
24+ join_audio : bool = False ,
25+ play : bool = False ,
26+ verbose : bool = True ,
27+ from_cli : bool = False ,
28+ ) -> None :
29+ """
30+ Generates audio from text using a specified TTS model.
31+
32+ Parameters:
33+ - text (str): The input text to be converted to speech.
34+ - model (str): The TTS model to use.
35+ - voice (str): The voice style to use.
36+ - speed (float): Playback speed multiplier.
37+ - lang_code (str): The language code.
38+ - ref_audio (mx.array): Reference audio you would like to clone the voice from.
39+ - ref_text (str): Caption for reference audio.
40+ - file_prefix (str): The output file path without extension.
41+ - audio_format (str): Output audio format (e.g., "wav", "flac").
42+ - sample_rate (int): Sampling rate in Hz.
43+ - join_audio (bool): Whether to join multiple audio files into one.
44+ - play (bool): Whether to play the generated audio.
45+ - verbose (bool): Whether to print status messages.
46+
47+ Returns:
48+ - None: The function writes the generated audio to a file.
49+ """
5850 try :
59- # load reference audio for voice matching if specified
51+ # Load reference audio for voice matching if specified
6052
61- ref_audio = None
62- ref_text = None
63-
64- if args .ref_audio :
65- if not os .path .exists (args .ref_audio ):
66- raise FileNotFoundError (
67- f"Reference audio file not found: { args .ref_audio } "
68- )
69- if not args .ref_text :
53+ if ref_audio :
54+ if not os .path .exists (ref_audio ):
55+ raise FileNotFoundError (f"Reference audio file not found: { ref_audio } " )
56+ if not ref_text :
7057 raise ValueError (
7158 "Reference text is required when using reference audio."
7259 )
7360
74- ref_audio , ref_sr = sf .read (args . ref_audio )
61+ ref_audio , ref_sr = sf .read (ref_audio )
7562 if ref_sr != 24000 :
7663 raise ValueError (
7764 f"Reference audio sample rate must be 24000 Hz, but got { ref_sr } Hz."
7865 )
7966 ref_audio = mx .array (ref_audio , dtype = mx .float32 )
80- ref_text = args .ref_text
8167
82- player = AudioPlayer () if args .play else None
68+ # Load AudioPlayer
69+ player = AudioPlayer () if play else None
8370
84- model = load_model (model_path = args .model )
71+ # Load model
72+ model = load_model (model_path = model_path )
8573 print (
86- f"\n \033 [94mModel:\033 [0m { args . model } \n "
87- f"\033 [94mText:\033 [0m { args . text } \n "
88- f"\033 [94mVoice:\033 [0m { args . voice } \n "
89- f"\033 [94mSpeed:\033 [0m { args . speed } x\n "
90- f"\033 [94mLanguage:\033 [0m { args . lang_code } "
74+ f"\n \033 [94mModel:\033 [0m { model_path } \n "
75+ f"\033 [94mText:\033 [0m { text } \n "
76+ f"\033 [94mVoice:\033 [0m { voice } \n "
77+ f"\033 [94mSpeed:\033 [0m { speed } x\n "
78+ f"\033 [94mLanguage:\033 [0m { lang_code } "
9179 )
92- print ( "==========" )
80+
9381 results = model .generate (
94- text = args . text ,
95- voice = args . voice ,
96- speed = args . speed ,
97- lang_code = args . lang_code ,
82+ text = text ,
83+ voice = voice ,
84+ speed = speed ,
85+ lang_code = lang_code ,
9886 ref_audio = ref_audio ,
9987 ref_text = ref_text ,
10088 verbose = True ,
10189 )
102- print (
103- f"\033 [92mAudio generated successfully, saving to\033 [0m { args .file_prefix } !"
104- )
10590
10691 audio_list = []
92+ file_name = f"{ file_prefix } .{ audio_format } "
10793 for i , result in enumerate (results ):
108- if args . play :
94+ if play :
10995 player .queue_audio (result .audio )
110- if args . join_audio :
96+ if join_audio :
11197 audio_list .append (result .audio )
98+
11299 else :
113- sf .write (f"{ args .file_prefix } _{ i :03d} .wav" , result .audio , 24000 )
100+ file_name = f"{ file_prefix } _{ i :03d} .{ audio_format } "
101+ sf .write (file_name , result .audio , 24000 )
102+
103+ if verbose :
114104
115- if args .verbose :
116105 print ("==========" )
117106 print (f"Duration: { result .audio_duration } " )
118107 print (
@@ -127,15 +116,18 @@ def main():
127116 print (f"Real-time factor: { result .real_time_factor :.2f} x" )
128117 print (f"Processing time: { result .processing_time_seconds :.2f} s" )
129118 print (f"Peak memory usage: { result .peak_memory_usage :.2f} GB" )
119+ print (f"✅ Audio successfully generated and saving as: { file_name } " )
130120
131- if args .join_audio :
132- print (f"Joining { len (audio_list )} audio files" )
121+ if join_audio :
122+ if verbose :
123+ print (f"Joining { len (audio_list )} audio files" )
133124 audio = mx .concatenate (audio_list , axis = 0 )
134- sf .write (f"{ args . file_prefix } .wav " , audio , 24000 )
125+ sf .write (f"{ file_prefix } .{ audio_format } " , audio , 24000 )
135126
136- if args . play :
127+ if play :
137128 player .wait_for_drain ()
138129 player .stop ()
130+
139131 except ImportError as e :
140132 print (f"Import error: { e } " )
141133 print (
@@ -148,5 +140,75 @@ def main():
148140 traceback .print_exc ()
149141
150142
143+ def parse_args ():
144+ parser = argparse .ArgumentParser (description = "Generate audio from text using TTS." )
145+ parser .add_argument (
146+ "--model" ,
147+ type = str ,
148+ default = "mlx-community/Kokoro-82M-bf16" ,
149+ help = "Path or repo id of the model" ,
150+ )
151+ parser .add_argument (
152+ "--text" ,
153+ type = str ,
154+ default = None ,
155+ help = "Text to generate (leave blank to input via stdin)" ,
156+ )
157+ parser .add_argument ("--voice" , type = str , default = "af_heart" , help = "Voice name" )
158+ parser .add_argument ("--speed" , type = float , default = 1.0 , help = "Speed of the audio" )
159+ parser .add_argument ("--lang_code" , type = str , default = "a" , help = "Language code" )
160+ parser .add_argument (
161+ "--file_prefix" , type = str , default = "audio" , help = "Output file name prefix"
162+ )
163+ parser .add_argument ("--verbose" , action = "store_false" , help = "Print verbose output" )
164+ parser .add_argument (
165+ "--join_audio" , action = "store_true" , help = "Join all audio files into one"
166+ )
167+ parser .add_argument ("--play" , action = "store_true" , help = "Play the output audio" )
168+ parser .add_argument (
169+ "--audio_format" , type = str , default = "wav" , help = "Output audio format"
170+ )
171+ parser .add_argument (
172+ "--sample_rate" , type = int , default = 24000 , help = "Audio sample rate in Hz"
173+ )
174+ parser .add_argument (
175+ "--ref_audio" , type = str , default = None , help = "Path to reference audio"
176+ )
177+ parser .add_argument (
178+ "--ref_text" , type = str , default = None , help = "Caption for reference audio"
179+ )
180+
181+ args = parser .parse_args ()
182+
183+ if args .text is None :
184+ if not sys .stdin .isatty ():
185+ args .text = sys .stdin .read ().strip ()
186+ else :
187+ print ("Please enter the text to generate:" )
188+ args .text = input ("> " ).strip ()
189+
190+ return args
191+
192+
193+ def main ():
194+ args = parse_args ()
195+
196+ generate_audio (
197+ text = args .text ,
198+ model_path = args .model ,
199+ voice = args .voice ,
200+ speed = args .speed ,
201+ lang_code = args .lang_code ,
202+ ref_audio = args .ref_audio ,
203+ ref_text = args .ref_text ,
204+ file_prefix = args .file_prefix ,
205+ audio_format = args .audio_format ,
206+ sample_rate = args .sample_rate ,
207+ join_audio = args .join_audio ,
208+ play = args .play ,
209+ verbose = args .verbose ,
210+ )
211+
212+
151213if __name__ == "__main__" :
152214 main ()
0 commit comments