|
| 1 | +# ruff: noqa: E402 |
| 2 | +# %% |
| 3 | +print("Starting server...\n") |
| 4 | +import tts_webui.utils.setup_or_recover as setup_or_recover |
| 5 | + |
| 6 | +setup_or_recover.setup_or_recover() |
| 7 | + |
| 8 | +import tts_webui.utils.dotenv_init as dotenv_init |
| 9 | + |
| 10 | +dotenv_init.init() |
| 11 | + |
| 12 | +import os |
| 13 | +import gradio as gr |
| 14 | +from tts_webui.utils.suppress_warnings import suppress_warnings |
| 15 | + |
| 16 | +suppress_warnings() |
| 17 | + |
| 18 | +from tts_webui.config.load_config import default_config |
| 19 | +from tts_webui.config.config import config |
| 20 | + |
| 21 | +from tts_webui.css.css import full_css |
| 22 | +from tts_webui.history_tab.collections_directories_atom import ( |
| 23 | + collections_directories_atom, |
| 24 | +) |
| 25 | + |
| 26 | + |
| 27 | +from tts_webui.utils.generic_error_tab_advanced import generic_error_tab_advanced |
| 28 | +from tts_webui.extensions_loader.interface_extensions import ( |
| 29 | + extension_list_tab, |
| 30 | + handle_extension_class, |
| 31 | +) |
| 32 | +from tts_webui.extensions_loader.decorator_extensions import ( |
| 33 | + extension_decorator_list_tab, |
| 34 | +) |
| 35 | + |
| 36 | + |
| 37 | +def reload_config_and_restart_ui(): |
| 38 | + os._exit(0) |
| 39 | + # print("Reloading config and restarting UI...") |
| 40 | + # config = load_config() |
| 41 | + # gradio_interface_options = config["gradio_interface_options"] if "gradio_interface_options" in config else {} |
| 42 | + # demo.close() |
| 43 | + # time.sleep(1) |
| 44 | + # demo.launch(**gradio_interface_options) |
| 45 | + |
| 46 | + |
| 47 | +gradio_interface_options = ( |
| 48 | + config["gradio_interface_options"] |
| 49 | + if "gradio_interface_options" in config |
| 50 | + else default_config |
| 51 | +) |
| 52 | + |
| 53 | + |
| 54 | +import time |
| 55 | +import importlib |
| 56 | + |
| 57 | + |
| 58 | +def run_tab(module_name, function_name, name, requirements=None): |
| 59 | + print(f"Loading {name} tab...") |
| 60 | + start_time = time.time() |
| 61 | + try: |
| 62 | + module = importlib.import_module(module_name) |
| 63 | + func = getattr(module, function_name) |
| 64 | + func() |
| 65 | + except Exception as e: |
| 66 | + generic_error_tab_advanced(e, name=name, requirements=requirements) |
| 67 | + finally: |
| 68 | + elapsed_time = time.time() - start_time |
| 69 | + print(f" Done in {elapsed_time:.2f} seconds. ({name})\n") |
| 70 | + |
| 71 | + |
| 72 | +def load_tabs(list_of_tabs): |
| 73 | + for tab in list_of_tabs: |
| 74 | + module_name, function_name, name = tab[:3] |
| 75 | + requirements = tab[3] if len(tab) > 3 else None |
| 76 | + run_tab(module_name, function_name, name, requirements) |
| 77 | + |
| 78 | + |
| 79 | +def main_ui(theme_choice="Base"): |
| 80 | + themes = { |
| 81 | + "Base": gr.themes.Base, |
| 82 | + "Default": gr.themes.Default, |
| 83 | + "Monochrome": gr.themes.Monochrome, |
| 84 | + } |
| 85 | + theme: gr.themes.Base = themes[theme_choice]( |
| 86 | + # primary_hue="blue", |
| 87 | + primary_hue="sky", |
| 88 | + secondary_hue="sky", |
| 89 | + neutral_hue="neutral", |
| 90 | + font=[ |
| 91 | + gr.themes.GoogleFont("Inter"), |
| 92 | + "ui-sans-serif", |
| 93 | + "system-ui", |
| 94 | + "sans-serif", |
| 95 | + ], |
| 96 | + ) |
| 97 | + theme.set( |
| 98 | + embed_radius="*radius_sm", |
| 99 | + block_label_radius="*radius_sm", |
| 100 | + block_label_right_radius="*radius_sm", |
| 101 | + block_radius="*radius_sm", |
| 102 | + block_title_radius="*radius_sm", |
| 103 | + container_radius="*radius_sm", |
| 104 | + checkbox_border_radius="*radius_sm", |
| 105 | + input_radius="*radius_sm", |
| 106 | + table_radius="*radius_sm", |
| 107 | + button_large_radius="*radius_sm", |
| 108 | + button_small_radius="*radius_sm", |
| 109 | + button_primary_background_fill_hover="*primary_300", |
| 110 | + button_primary_background_fill_hover_dark="*primary_600", |
| 111 | + button_secondary_background_fill_hover="*secondary_200", |
| 112 | + button_secondary_background_fill_hover_dark="*secondary_600", |
| 113 | + ) |
| 114 | + |
| 115 | + with gr.Blocks( |
| 116 | + css=full_css, |
| 117 | + title="TTS Generation WebUI", |
| 118 | + analytics_enabled=False, # it broke too many times |
| 119 | + theme=theme, |
| 120 | + ) as blocks: |
| 121 | + gr.Markdown( |
| 122 | + """ |
| 123 | + # TTS Generation WebUI (Legacy - Gradio) [React UI](http://localhost:3000) [Feedback / Bug reports](https://forms.gle/2L62owhBsGFzdFBC8) [Discord Server](https://discord.gg/V8BKTVRtJ9) |
| 124 | + ### _(Text To Speech, Audio & Music Generation, Conversion)_ |
| 125 | + """ |
| 126 | + ) |
| 127 | + with gr.Tabs(): |
| 128 | + all_tabs() |
| 129 | + |
| 130 | + return blocks |
| 131 | + |
| 132 | + |
| 133 | +def all_tabs(): |
| 134 | + with gr.Tab("Text-to-Speech"), gr.Tabs(): |
| 135 | + tts_tabs = [ |
| 136 | + ("tts_webui.bark.bark_tab", "bark_tab", "Bark TTS"), |
| 137 | + ( |
| 138 | + "tts_webui.bark.clone.tab_voice_clone", |
| 139 | + "tab_voice_clone", |
| 140 | + "Bark Voice Clone", |
| 141 | + "-r requirements_bark_hubert_quantizer.txt", |
| 142 | + ), |
| 143 | + ( |
| 144 | + "tts_webui.tortoise.tortoise_tab", |
| 145 | + "tortoise_tab", |
| 146 | + "Tortoise TTS", |
| 147 | + ), |
| 148 | + ( |
| 149 | + "tts_webui.seamlessM4T.seamless_tab", |
| 150 | + "seamless_tab", |
| 151 | + "SeamlessM4Tv2Model", |
| 152 | + ), |
| 153 | + ( |
| 154 | + "tts_webui.vall_e_x.vall_e_x_tab", |
| 155 | + "valle_x_tab", |
| 156 | + "Valle-X", |
| 157 | + "-r requirements_vall_e.txt", |
| 158 | + ), |
| 159 | + ("tts_webui.mms.mms_tab", "mms_tab", "MMS"), |
| 160 | + ( |
| 161 | + "tts_webui.maha_tts.maha_tts_tab", |
| 162 | + "maha_tts_tab", |
| 163 | + "MahaTTS", |
| 164 | + "-r requirements_maha_tts.txt", |
| 165 | + ), |
| 166 | + ( |
| 167 | + "tts_webui.styletts2.styletts2_tab", |
| 168 | + "style_tts2_tab", |
| 169 | + "StyleTTS2", |
| 170 | + "-r requirements_styletts2.txt", |
| 171 | + ), |
| 172 | + ] |
| 173 | + load_tabs(tts_tabs) |
| 174 | + |
| 175 | + handle_extension_class("text-to-speech", config) |
| 176 | + with gr.Tab("Audio/Music Generation"), gr.Tabs(): |
| 177 | + audio_music_generation_tabs = [ |
| 178 | + ( |
| 179 | + "tts_webui.stable_audio.stable_audio_tab", |
| 180 | + "stable_audio_tab", |
| 181 | + "Stable Audio", |
| 182 | + "-r requirements_stable_audio.txt", |
| 183 | + ), |
| 184 | + ( |
| 185 | + "tts_webui.magnet.magnet_tab", |
| 186 | + "magnet_tab", |
| 187 | + "MAGNeT", |
| 188 | + "-r requirements_audiocraft.txt", |
| 189 | + ), |
| 190 | + ( |
| 191 | + "tts_webui.musicgen.musicgen_tab", |
| 192 | + "musicgen_tab", |
| 193 | + "MusicGen", |
| 194 | + "-r requirements_audiocraft.txt", |
| 195 | + ), |
| 196 | + ] |
| 197 | + load_tabs(audio_music_generation_tabs) |
| 198 | + |
| 199 | + handle_extension_class("audio-music-generation", config) |
| 200 | + with gr.Tab("Audio Conversion"), gr.Tabs(): |
| 201 | + audio_conversion_tabs = [ |
| 202 | + ( |
| 203 | + "tts_webui.rvc_tab.rvc_tab", |
| 204 | + "rvc_conversion_tab", |
| 205 | + "RVC", |
| 206 | + "-r requirements_rvc.txt", |
| 207 | + ), |
| 208 | + ( |
| 209 | + "tts_webui.rvc_tab.uvr5_tab", |
| 210 | + "uvr5_tab", |
| 211 | + "UVR5", |
| 212 | + "-r requirements_rvc.txt", |
| 213 | + ), |
| 214 | + ( |
| 215 | + "tts_webui.demucs.demucs_tab", |
| 216 | + "demucs_tab", |
| 217 | + "Demucs", |
| 218 | + "-r requirements_audiocraft.txt", |
| 219 | + ), |
| 220 | + ("tts_webui.vocos.vocos_tabs", "vocos_tabs", "Vocos"), |
| 221 | + ] |
| 222 | + load_tabs(audio_conversion_tabs) |
| 223 | + |
| 224 | + handle_extension_class("audio-conversion", config) |
| 225 | + with gr.Tab("Outputs"), gr.Tabs(): |
| 226 | + from tts_webui.history_tab.main import history_tab |
| 227 | + |
| 228 | + collections_directories_atom.render() |
| 229 | + try: |
| 230 | + history_tab() |
| 231 | + history_tab(directory="favorites") |
| 232 | + history_tab( |
| 233 | + directory="outputs", |
| 234 | + show_collections=True, |
| 235 | + ) |
| 236 | + except Exception as e: |
| 237 | + generic_error_tab_advanced(e, name="History", requirements=None) |
| 238 | + |
| 239 | + outputs_tabs = [ |
| 240 | + # voices |
| 241 | + # ("tts_webui.history_tab.voices_tab", "voices_tab", "Voices"), |
| 242 | + ] |
| 243 | + load_tabs(outputs_tabs) |
| 244 | + |
| 245 | + handle_extension_class("outputs", config) |
| 246 | + |
| 247 | + with gr.Tab("Tools"), gr.Tabs(): |
| 248 | + tools_tabs = [] |
| 249 | + load_tabs(tools_tabs) |
| 250 | + |
| 251 | + handle_extension_class("tools", config) |
| 252 | + with gr.Tab("Settings"), gr.Tabs(): |
| 253 | + from tts_webui.settings_tab_gradio import settings_tab_gradio |
| 254 | + |
| 255 | + settings_tab_gradio(reload_config_and_restart_ui, gradio_interface_options) |
| 256 | + |
| 257 | + settings_tabs = [ |
| 258 | + # ( |
| 259 | + # "tts_webui.bark.settings_tab_bark", |
| 260 | + # "settings_tab_bark", |
| 261 | + # "Settings (Bark)", |
| 262 | + # ), |
| 263 | + ( |
| 264 | + "tts_webui.utils.model_location_settings_tab", |
| 265 | + "model_location_settings_tab", |
| 266 | + "Model Location Settings", |
| 267 | + ), |
| 268 | + ("tts_webui.utils.gpu_info_tab", "gpu_info_tab", "GPU Info"), |
| 269 | + ("tts_webui.utils.pip_list_tab", "pip_list_tab", "Installed Packages"), |
| 270 | + ] |
| 271 | + load_tabs(settings_tabs) |
| 272 | + |
| 273 | + extension_list_tab() |
| 274 | + extension_decorator_list_tab() |
| 275 | + |
| 276 | + handle_extension_class("settings", config) |
| 277 | + |
| 278 | + |
| 279 | +def start_gradio_server(): |
| 280 | + def print_pretty_options(options): |
| 281 | + print(" Gradio interface options:") |
| 282 | + max_key_length = max(len(key) for key in options.keys()) |
| 283 | + for key, value in options.items(): |
| 284 | + if key == "auth" and value is not None: |
| 285 | + print(f" {key}:{' ' * (max_key_length - len(key))} {value[0]}:******") |
| 286 | + else: |
| 287 | + print(f" {key}:{' ' * (max_key_length - len(key))} {value}") |
| 288 | + |
| 289 | + # detect if --share is passed |
| 290 | + if "--share" in os.sys.argv: |
| 291 | + print("Gradio share mode enabled") |
| 292 | + gradio_interface_options["share"] = True |
| 293 | + |
| 294 | + if "--docker" in os.sys.argv: |
| 295 | + gradio_interface_options["server_name"] = "0.0.0.0" |
| 296 | + print("Info: Docker mode: opening gradio server on all interfaces") |
| 297 | + |
| 298 | + print("Starting Gradio server...") |
| 299 | + if "enable_queue" in gradio_interface_options: |
| 300 | + del gradio_interface_options["enable_queue"] |
| 301 | + if gradio_interface_options["auth"] is not None: |
| 302 | + # split username:password into (username, password) |
| 303 | + gradio_interface_options["auth"] = tuple( |
| 304 | + gradio_interface_options["auth"].split(":") |
| 305 | + ) |
| 306 | + print("Gradio server authentication enabled") |
| 307 | + # delete show_tips option |
| 308 | + if "show_tips" in gradio_interface_options: |
| 309 | + del gradio_interface_options["show_tips"] |
| 310 | + # TypeError: Blocks.launch() got an unexpected keyword argument 'file_directories' |
| 311 | + if "file_directories" in gradio_interface_options: |
| 312 | + del gradio_interface_options["file_directories"] |
| 313 | + print_pretty_options(gradio_interface_options) |
| 314 | + |
| 315 | + demo = main_ui() |
| 316 | + |
| 317 | + print("\n\n") |
| 318 | + |
| 319 | + if gradio_interface_options["server_name"] == "0.0.0.0": |
| 320 | + print("Notice: Server is open to the internet") |
| 321 | + print( |
| 322 | + f"Gradio server will be available on http://localhost:{gradio_interface_options['server_port']}" |
| 323 | + ) |
| 324 | + |
| 325 | + # concurrency_count=gradio_interface_options.get("concurrency_count", 5), |
| 326 | + demo.queue().launch(**gradio_interface_options, allowed_paths=["."]) |
| 327 | + |
| 328 | + |
| 329 | +def server_hypervisor(): |
| 330 | + import subprocess |
| 331 | + import signal |
| 332 | + import sys |
| 333 | + |
| 334 | + postgres_dir = os.path.join("data", "postgres") |
| 335 | + |
| 336 | + def stop_postgres(postgres_process): |
| 337 | + try: |
| 338 | + subprocess.check_call(f"pg_ctl stop -D {postgres_dir} -m fast", shell=True) |
| 339 | + print("PostgreSQL stopped gracefully.") |
| 340 | + except Exception as e: |
| 341 | + print(f"Error stopping PostgreSQL: {e}") |
| 342 | + postgres_process.terminate() |
| 343 | + |
| 344 | + def signal_handler(signal, frame, postgres_process): |
| 345 | + print("Shutting down...") |
| 346 | + stop_postgres(postgres_process) |
| 347 | + sys.exit(0) |
| 348 | + |
| 349 | + print("Starting React UI...") |
| 350 | + subprocess.Popen( |
| 351 | + "npm start --prefix react-ui", |
| 352 | + env={ |
| 353 | + **os.environ, |
| 354 | + "GRADIO_BACKEND_AUTOMATIC": f"http://127.0.0.1:{gradio_interface_options['server_port']}", |
| 355 | + }, |
| 356 | + shell=True, |
| 357 | + ) |
| 358 | + if "--docker" in os.sys.argv: |
| 359 | + print("Info: Docker mode: skipping Postgres") |
| 360 | + return |
| 361 | + print("Starting Postgres...") |
| 362 | + postgres_process = subprocess.Popen(f"postgres -D {postgres_dir} -p 7773", shell=True) |
| 363 | + try: |
| 364 | + signal.signal( |
| 365 | + signal.SIGINT, |
| 366 | + lambda sig, frame: signal_handler(sig, frame, postgres_process), |
| 367 | + ) # Ctrl+C |
| 368 | + signal.signal( |
| 369 | + signal.SIGTERM, |
| 370 | + lambda sig, frame: signal_handler(sig, frame, postgres_process), |
| 371 | + ) # Termination signals |
| 372 | + if os.name != "nt": |
| 373 | + signal.signal( |
| 374 | + signal.SIGHUP, |
| 375 | + lambda sig, frame: signal_handler(sig, frame, postgres_process), |
| 376 | + ) # Terminal closure |
| 377 | + signal.signal( |
| 378 | + signal.SIGQUIT, |
| 379 | + lambda sig, frame: signal_handler(sig, frame, postgres_process), |
| 380 | + ) # Quit |
| 381 | + except (ValueError, OSError) as e: |
| 382 | + print(f"Failed to set signal handlers: {e}") |
| 383 | + |
| 384 | + |
| 385 | +if __name__ == "__main__": |
| 386 | + server_hypervisor() |
| 387 | + import webbrowser |
| 388 | + |
| 389 | + if gradio_interface_options["inbrowser"]: |
| 390 | + webbrowser.open("http://localhost:3000") |
| 391 | + |
| 392 | + start_gradio_server() |
0 commit comments