Skip to content

Commit dbf2520

Browse files
committed
add test workflows, fix minor bugs
1 parent 44869c4 commit dbf2520

File tree

1 file changed

+392
-0
lines changed

1 file changed

+392
-0
lines changed

server.py

Lines changed: 392 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,392 @@
1+
# ruff: noqa: E402
2+
# %%
3+
print("Starting server...\n")
4+
import tts_webui.utils.setup_or_recover as setup_or_recover
5+
6+
setup_or_recover.setup_or_recover()
7+
8+
import tts_webui.utils.dotenv_init as dotenv_init
9+
10+
dotenv_init.init()
11+
12+
import os
13+
import gradio as gr
14+
from tts_webui.utils.suppress_warnings import suppress_warnings
15+
16+
suppress_warnings()
17+
18+
from tts_webui.config.load_config import default_config
19+
from tts_webui.config.config import config
20+
21+
from tts_webui.css.css import full_css
22+
from tts_webui.history_tab.collections_directories_atom import (
23+
collections_directories_atom,
24+
)
25+
26+
27+
from tts_webui.utils.generic_error_tab_advanced import generic_error_tab_advanced
28+
from tts_webui.extensions_loader.interface_extensions import (
29+
extension_list_tab,
30+
handle_extension_class,
31+
)
32+
from tts_webui.extensions_loader.decorator_extensions import (
33+
extension_decorator_list_tab,
34+
)
35+
36+
37+
def reload_config_and_restart_ui():
38+
os._exit(0)
39+
# print("Reloading config and restarting UI...")
40+
# config = load_config()
41+
# gradio_interface_options = config["gradio_interface_options"] if "gradio_interface_options" in config else {}
42+
# demo.close()
43+
# time.sleep(1)
44+
# demo.launch(**gradio_interface_options)
45+
46+
47+
gradio_interface_options = (
48+
config["gradio_interface_options"]
49+
if "gradio_interface_options" in config
50+
else default_config
51+
)
52+
53+
54+
import time
55+
import importlib
56+
57+
58+
def run_tab(module_name, function_name, name, requirements=None):
59+
print(f"Loading {name} tab...")
60+
start_time = time.time()
61+
try:
62+
module = importlib.import_module(module_name)
63+
func = getattr(module, function_name)
64+
func()
65+
except Exception as e:
66+
generic_error_tab_advanced(e, name=name, requirements=requirements)
67+
finally:
68+
elapsed_time = time.time() - start_time
69+
print(f" Done in {elapsed_time:.2f} seconds. ({name})\n")
70+
71+
72+
def load_tabs(list_of_tabs):
73+
for tab in list_of_tabs:
74+
module_name, function_name, name = tab[:3]
75+
requirements = tab[3] if len(tab) > 3 else None
76+
run_tab(module_name, function_name, name, requirements)
77+
78+
79+
def main_ui(theme_choice="Base"):
80+
themes = {
81+
"Base": gr.themes.Base,
82+
"Default": gr.themes.Default,
83+
"Monochrome": gr.themes.Monochrome,
84+
}
85+
theme: gr.themes.Base = themes[theme_choice](
86+
# primary_hue="blue",
87+
primary_hue="sky",
88+
secondary_hue="sky",
89+
neutral_hue="neutral",
90+
font=[
91+
gr.themes.GoogleFont("Inter"),
92+
"ui-sans-serif",
93+
"system-ui",
94+
"sans-serif",
95+
],
96+
)
97+
theme.set(
98+
embed_radius="*radius_sm",
99+
block_label_radius="*radius_sm",
100+
block_label_right_radius="*radius_sm",
101+
block_radius="*radius_sm",
102+
block_title_radius="*radius_sm",
103+
container_radius="*radius_sm",
104+
checkbox_border_radius="*radius_sm",
105+
input_radius="*radius_sm",
106+
table_radius="*radius_sm",
107+
button_large_radius="*radius_sm",
108+
button_small_radius="*radius_sm",
109+
button_primary_background_fill_hover="*primary_300",
110+
button_primary_background_fill_hover_dark="*primary_600",
111+
button_secondary_background_fill_hover="*secondary_200",
112+
button_secondary_background_fill_hover_dark="*secondary_600",
113+
)
114+
115+
with gr.Blocks(
116+
css=full_css,
117+
title="TTS Generation WebUI",
118+
analytics_enabled=False, # it broke too many times
119+
theme=theme,
120+
) as blocks:
121+
gr.Markdown(
122+
"""
123+
# TTS Generation WebUI (Legacy - Gradio) [React UI](http://localhost:3000) [Feedback / Bug reports](https://forms.gle/2L62owhBsGFzdFBC8) [Discord Server](https://discord.gg/V8BKTVRtJ9)
124+
### _(Text To Speech, Audio & Music Generation, Conversion)_
125+
"""
126+
)
127+
with gr.Tabs():
128+
all_tabs()
129+
130+
return blocks
131+
132+
133+
def all_tabs():
134+
with gr.Tab("Text-to-Speech"), gr.Tabs():
135+
tts_tabs = [
136+
("tts_webui.bark.bark_tab", "bark_tab", "Bark TTS"),
137+
(
138+
"tts_webui.bark.clone.tab_voice_clone",
139+
"tab_voice_clone",
140+
"Bark Voice Clone",
141+
"-r requirements_bark_hubert_quantizer.txt",
142+
),
143+
(
144+
"tts_webui.tortoise.tortoise_tab",
145+
"tortoise_tab",
146+
"Tortoise TTS",
147+
),
148+
(
149+
"tts_webui.seamlessM4T.seamless_tab",
150+
"seamless_tab",
151+
"SeamlessM4Tv2Model",
152+
),
153+
(
154+
"tts_webui.vall_e_x.vall_e_x_tab",
155+
"valle_x_tab",
156+
"Valle-X",
157+
"-r requirements_vall_e.txt",
158+
),
159+
("tts_webui.mms.mms_tab", "mms_tab", "MMS"),
160+
(
161+
"tts_webui.maha_tts.maha_tts_tab",
162+
"maha_tts_tab",
163+
"MahaTTS",
164+
"-r requirements_maha_tts.txt",
165+
),
166+
(
167+
"tts_webui.styletts2.styletts2_tab",
168+
"style_tts2_tab",
169+
"StyleTTS2",
170+
"-r requirements_styletts2.txt",
171+
),
172+
]
173+
load_tabs(tts_tabs)
174+
175+
handle_extension_class("text-to-speech", config)
176+
with gr.Tab("Audio/Music Generation"), gr.Tabs():
177+
audio_music_generation_tabs = [
178+
(
179+
"tts_webui.stable_audio.stable_audio_tab",
180+
"stable_audio_tab",
181+
"Stable Audio",
182+
"-r requirements_stable_audio.txt",
183+
),
184+
(
185+
"tts_webui.magnet.magnet_tab",
186+
"magnet_tab",
187+
"MAGNeT",
188+
"-r requirements_audiocraft.txt",
189+
),
190+
(
191+
"tts_webui.musicgen.musicgen_tab",
192+
"musicgen_tab",
193+
"MusicGen",
194+
"-r requirements_audiocraft.txt",
195+
),
196+
]
197+
load_tabs(audio_music_generation_tabs)
198+
199+
handle_extension_class("audio-music-generation", config)
200+
with gr.Tab("Audio Conversion"), gr.Tabs():
201+
audio_conversion_tabs = [
202+
(
203+
"tts_webui.rvc_tab.rvc_tab",
204+
"rvc_conversion_tab",
205+
"RVC",
206+
"-r requirements_rvc.txt",
207+
),
208+
(
209+
"tts_webui.rvc_tab.uvr5_tab",
210+
"uvr5_tab",
211+
"UVR5",
212+
"-r requirements_rvc.txt",
213+
),
214+
(
215+
"tts_webui.demucs.demucs_tab",
216+
"demucs_tab",
217+
"Demucs",
218+
"-r requirements_audiocraft.txt",
219+
),
220+
("tts_webui.vocos.vocos_tabs", "vocos_tabs", "Vocos"),
221+
]
222+
load_tabs(audio_conversion_tabs)
223+
224+
handle_extension_class("audio-conversion", config)
225+
with gr.Tab("Outputs"), gr.Tabs():
226+
from tts_webui.history_tab.main import history_tab
227+
228+
collections_directories_atom.render()
229+
try:
230+
history_tab()
231+
history_tab(directory="favorites")
232+
history_tab(
233+
directory="outputs",
234+
show_collections=True,
235+
)
236+
except Exception as e:
237+
generic_error_tab_advanced(e, name="History", requirements=None)
238+
239+
outputs_tabs = [
240+
# voices
241+
# ("tts_webui.history_tab.voices_tab", "voices_tab", "Voices"),
242+
]
243+
load_tabs(outputs_tabs)
244+
245+
handle_extension_class("outputs", config)
246+
247+
with gr.Tab("Tools"), gr.Tabs():
248+
tools_tabs = []
249+
load_tabs(tools_tabs)
250+
251+
handle_extension_class("tools", config)
252+
with gr.Tab("Settings"), gr.Tabs():
253+
from tts_webui.settings_tab_gradio import settings_tab_gradio
254+
255+
settings_tab_gradio(reload_config_and_restart_ui, gradio_interface_options)
256+
257+
settings_tabs = [
258+
# (
259+
# "tts_webui.bark.settings_tab_bark",
260+
# "settings_tab_bark",
261+
# "Settings (Bark)",
262+
# ),
263+
(
264+
"tts_webui.utils.model_location_settings_tab",
265+
"model_location_settings_tab",
266+
"Model Location Settings",
267+
),
268+
("tts_webui.utils.gpu_info_tab", "gpu_info_tab", "GPU Info"),
269+
("tts_webui.utils.pip_list_tab", "pip_list_tab", "Installed Packages"),
270+
]
271+
load_tabs(settings_tabs)
272+
273+
extension_list_tab()
274+
extension_decorator_list_tab()
275+
276+
handle_extension_class("settings", config)
277+
278+
279+
def start_gradio_server():
280+
def print_pretty_options(options):
281+
print(" Gradio interface options:")
282+
max_key_length = max(len(key) for key in options.keys())
283+
for key, value in options.items():
284+
if key == "auth" and value is not None:
285+
print(f" {key}:{' ' * (max_key_length - len(key))} {value[0]}:******")
286+
else:
287+
print(f" {key}:{' ' * (max_key_length - len(key))} {value}")
288+
289+
# detect if --share is passed
290+
if "--share" in os.sys.argv:
291+
print("Gradio share mode enabled")
292+
gradio_interface_options["share"] = True
293+
294+
if "--docker" in os.sys.argv:
295+
gradio_interface_options["server_name"] = "0.0.0.0"
296+
print("Info: Docker mode: opening gradio server on all interfaces")
297+
298+
print("Starting Gradio server...")
299+
if "enable_queue" in gradio_interface_options:
300+
del gradio_interface_options["enable_queue"]
301+
if gradio_interface_options["auth"] is not None:
302+
# split username:password into (username, password)
303+
gradio_interface_options["auth"] = tuple(
304+
gradio_interface_options["auth"].split(":")
305+
)
306+
print("Gradio server authentication enabled")
307+
# delete show_tips option
308+
if "show_tips" in gradio_interface_options:
309+
del gradio_interface_options["show_tips"]
310+
# TypeError: Blocks.launch() got an unexpected keyword argument 'file_directories'
311+
if "file_directories" in gradio_interface_options:
312+
del gradio_interface_options["file_directories"]
313+
print_pretty_options(gradio_interface_options)
314+
315+
demo = main_ui()
316+
317+
print("\n\n")
318+
319+
if gradio_interface_options["server_name"] == "0.0.0.0":
320+
print("Notice: Server is open to the internet")
321+
print(
322+
f"Gradio server will be available on http://localhost:{gradio_interface_options['server_port']}"
323+
)
324+
325+
# concurrency_count=gradio_interface_options.get("concurrency_count", 5),
326+
demo.queue().launch(**gradio_interface_options, allowed_paths=["."])
327+
328+
329+
def server_hypervisor():
330+
import subprocess
331+
import signal
332+
import sys
333+
334+
postgres_dir = os.path.join("data", "postgres")
335+
336+
def stop_postgres(postgres_process):
337+
try:
338+
subprocess.check_call(f"pg_ctl stop -D {postgres_dir} -m fast", shell=True)
339+
print("PostgreSQL stopped gracefully.")
340+
except Exception as e:
341+
print(f"Error stopping PostgreSQL: {e}")
342+
postgres_process.terminate()
343+
344+
def signal_handler(signal, frame, postgres_process):
345+
print("Shutting down...")
346+
stop_postgres(postgres_process)
347+
sys.exit(0)
348+
349+
print("Starting React UI...")
350+
subprocess.Popen(
351+
"npm start --prefix react-ui",
352+
env={
353+
**os.environ,
354+
"GRADIO_BACKEND_AUTOMATIC": f"http://127.0.0.1:{gradio_interface_options['server_port']}",
355+
},
356+
shell=True,
357+
)
358+
if "--docker" in os.sys.argv:
359+
print("Info: Docker mode: skipping Postgres")
360+
return
361+
print("Starting Postgres...")
362+
postgres_process = subprocess.Popen(f"postgres -D {postgres_dir} -p 7773", shell=True)
363+
try:
364+
signal.signal(
365+
signal.SIGINT,
366+
lambda sig, frame: signal_handler(sig, frame, postgres_process),
367+
) # Ctrl+C
368+
signal.signal(
369+
signal.SIGTERM,
370+
lambda sig, frame: signal_handler(sig, frame, postgres_process),
371+
) # Termination signals
372+
if os.name != "nt":
373+
signal.signal(
374+
signal.SIGHUP,
375+
lambda sig, frame: signal_handler(sig, frame, postgres_process),
376+
) # Terminal closure
377+
signal.signal(
378+
signal.SIGQUIT,
379+
lambda sig, frame: signal_handler(sig, frame, postgres_process),
380+
) # Quit
381+
except (ValueError, OSError) as e:
382+
print(f"Failed to set signal handlers: {e}")
383+
384+
385+
if __name__ == "__main__":
386+
server_hypervisor()
387+
import webbrowser
388+
389+
if gradio_interface_options["inbrowser"]:
390+
webbrowser.open("http://localhost:3000")
391+
392+
start_gradio_server()

0 commit comments

Comments
 (0)