forked from index-tts/index-tts
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathwebui.py
More file actions
212 lines (190 loc) · 10.1 KB
/
webui.py
File metadata and controls
212 lines (190 loc) · 10.1 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
import json
import os
import sys
import threading
import time
import warnings
# Silence noisy deprecation/user warnings emitted by third-party libraries at import time.
warnings.filterwarnings("ignore", category=FutureWarning)
warnings.filterwarnings("ignore", category=UserWarning)
# Make the repo root and the bundled `indextts` package importable regardless of
# the working directory the script is launched from.
current_dir = os.path.dirname(os.path.abspath(__file__))
sys.path.append(current_dir)
sys.path.append(os.path.join(current_dir, "indextts"))
import argparse

# Command-line options for the web UI; parsed eagerly so missing checkpoints
# abort before the heavy model imports below.
parser = argparse.ArgumentParser(description="IndexTTS WebUI")
parser.add_argument("--verbose", action="store_true", default=False, help="Enable verbose mode")
parser.add_argument("--port", type=int, default=7860, help="Port to run the web UI on")
parser.add_argument("--host", type=str, default="127.0.0.1", help="Host to run the web UI on")
parser.add_argument("--model_dir", type=str, default="checkpoints", help="Model checkpoints directory")
cmd_args = parser.parse_args()

# Fail fast with a helpful message if the model directory is absent.
if not os.path.exists(cmd_args.model_dir):
    print(f"Model directory {cmd_args.model_dir} does not exist. Please download the model first.")
    sys.exit(1)

# All four checkpoint artifacts are required for inference.
for file in [
    "bigvgan_generator.pth",
    "bpe.model",
    "gpt.pth",
    "config.yaml",
]:
    file_path = os.path.join(cmd_args.model_dir, file)
    if not os.path.exists(file_path):
        print(f"Required file {file_path} does not exist. Please download it.")
        sys.exit(1)
# Heavy imports happen only after the checkpoint sanity checks above have passed.
import gradio as gr
from indextts.infer import IndexTTS
from tools.i18n.i18n import I18nAuto
# UI string localisation (Simplified Chinese).
i18n = I18nAuto(language="zh_CN")
MODE = 'local'
# Single shared TTS engine instance used by every request in this process.
tts = IndexTTS(model_dir=cmd_args.model_dir, cfg_path=os.path.join(cmd_args.model_dir, "config.yaml"),)
# Output and reference-prompt directories are created up front.
os.makedirs("outputs/tasks",exist_ok=True)
os.makedirs("prompts",exist_ok=True)
# Load demo examples from a JSONL file: one JSON object per non-empty line with
# optional keys "prompt_audio", "text" and "infer_mode" (0 = normal, 1 = batch).
example_cases = []
with open("tests/cases.jsonl", "r", encoding="utf-8") as f:
    for line in f:
        line = line.strip()
        if not line:  # skip blank lines
            continue
        example = json.loads(line)
        example_cases.append([
            os.path.join("tests", example.get("prompt_audio", "sample_prompt.wav")),
            example.get("text"),
            # Map the integer mode flag to the radio-button label used by the UI.
            ["普通推理", "批次推理"][example.get("infer_mode", 0)],
        ])
def gen_single(prompt, text, infer_mode, max_text_tokens_per_sentence=120, sentences_bucket_max_size=4,
               *args, progress=gr.Progress()):
    """Synthesize speech for *text* using *prompt* as the reference voice.

    Args:
        prompt: filepath of the reference audio clip.
        text: target text to synthesize.
        infer_mode: "普通推理" (normal) or "批次推理" (batched, for long text).
        max_text_tokens_per_sentence: sentence-splitting budget in tokens.
        sentences_bucket_max_size: max sentences per bucket (batch mode only).
        *args: advanced generation parameters, in the fixed order declared by
            `advanced_params` in the UI below.
        progress: gradio progress tracker, forwarded to the TTS engine.

    Returns:
        A gr.update making the output audio component show the generated file.
    """
    # One output file per request, named by the current UNIX timestamp.
    # (The original assigned None and immediately overwrote it — dead code removed.)
    output_path = os.path.join("outputs", f"spk_{int(time.time())}.wav")
    # Let the engine report progress through the gradio UI.
    tts.gr_progress = progress
    do_sample, top_p, top_k, temperature, \
        length_penalty, num_beams, repetition_penalty, max_mel_tokens = args
    kwargs = {
        "do_sample": bool(do_sample),
        "top_p": float(top_p),
        "top_k": int(top_k) if int(top_k) > 0 else None,  # 0 disables top-k filtering
        "temperature": float(temperature),
        "length_penalty": float(length_penalty),
        "num_beams": num_beams,
        "repetition_penalty": float(repetition_penalty),
        "max_mel_tokens": int(max_mel_tokens),
        # "typical_sampling": bool(typical_sampling),
        # "typical_mass": float(typical_mass),
    }
    if infer_mode == "普通推理":
        output = tts.infer(prompt, text, output_path, verbose=cmd_args.verbose,
                           max_text_tokens_per_sentence=int(max_text_tokens_per_sentence),
                           **kwargs)
    else:
        # 批次推理 (batch inference): buckets sentences together for long text.
        output = tts.infer_fast(prompt, text, output_path, verbose=cmd_args.verbose,
                                max_text_tokens_per_sentence=int(max_text_tokens_per_sentence),
                                # Cast to int for consistency with the other slider values
                                # (the original passed the raw value through).
                                sentences_bucket_max_size=int(sentences_bucket_max_size),
                                **kwargs)
    return gr.update(value=output, visible=True)
def update_prompt_audio():
    """Re-enable the generate button once a reference audio has been provided."""
    return gr.update(interactive=True)
with gr.Blocks(title="IndexTTS Demo") as demo:
    mutex = threading.Lock()
    gr.HTML('''
<h2><center>IndexTTS: An Industrial-Level Controllable and Efficient Zero-Shot Text-To-Speech System</h2>
<h2><center>(一款工业级可控且高效的零样本文本转语音系统)</h2>
<p align="center">
<a href='https://arxiv.org/abs/2502.05512'><img src='https://img.shields.io/badge/ArXiv-2502.05512-red'></a>
</p>
''')
    with gr.Tab("音频生成"):
        with gr.Row():
            os.makedirs("prompts", exist_ok=True)
            # Reference voice: upload a file or record from the microphone.
            prompt_audio = gr.Audio(label="参考音频", key="prompt_audio",
                                    sources=["upload", "microphone"], type="filepath")
            prompt_list = os.listdir("prompts")
            default = ''
            if prompt_list:
                default = prompt_list[0]
            with gr.Column():
                input_text_single = gr.TextArea(label="文本", key="input_text_single", placeholder="请输入目标文本", info="当前模型版本{}".format(tts.model_version or "1.0"))
                infer_mode = gr.Radio(choices=["普通推理", "批次推理"], label="推理模式", info="批次推理:更适合长句,性能翻倍", value="普通推理")
                gen_button = gr.Button("生成语音", key="gen_button", interactive=True)
                output_audio = gr.Audio(label="生成结果", visible=True, key="output_audio")
        with gr.Accordion("高级生成参数设置", open=False):
            with gr.Row():
                with gr.Column(scale=1):
                    gr.Markdown("**GPT2 采样设置** _参数会影响音频多样性和生成速度详见[Generation strategies](https://huggingface.co/docs/transformers/main/en/generation_strategies)_")
                    with gr.Row():
                        do_sample = gr.Checkbox(label="do_sample", value=True, info="是否进行采样")
                        temperature = gr.Slider(label="temperature", minimum=0.1, maximum=2.0, value=1.0, step=0.1)
                    with gr.Row():
                        top_p = gr.Slider(label="top_p", minimum=0.0, maximum=1.0, value=0.8, step=0.01)
                        top_k = gr.Slider(label="top_k", minimum=0, maximum=100, value=30, step=1)
                        num_beams = gr.Slider(label="num_beams", value=3, minimum=1, maximum=10, step=1)
                    with gr.Row():
                        repetition_penalty = gr.Number(label="repetition_penalty", precision=None, value=10.0, minimum=0.1, maximum=20.0, step=0.1)
                        length_penalty = gr.Number(label="length_penalty", precision=None, value=0.0, minimum=-2.0, maximum=2.0, step=0.1)
                    max_mel_tokens = gr.Slider(label="max_mel_tokens", value=600, minimum=50, maximum=tts.cfg.gpt.max_mel_tokens, step=10, info="生成Token最大数量,过小导致音频被截断", key="max_mel_tokens")
                    # with gr.Row():
                    #     typical_sampling = gr.Checkbox(label="typical_sampling", value=False, info="不建议使用")
                    #     typical_mass = gr.Slider(label="typical_mass", value=0.9, minimum=0.0, maximum=1.0, step=0.1)
                with gr.Column(scale=2):
                    gr.Markdown("**分句设置** _参数会影响音频质量和生成速度_")
                    with gr.Row():
                        max_text_tokens_per_sentence = gr.Slider(
                            label="分句最大Token数", value=120, minimum=20, maximum=tts.cfg.gpt.max_text_tokens, step=2, key="max_text_tokens_per_sentence",
                            info="建议80~200之间,值越大,分句越长;值越小,分句越碎;过小过大都可能导致音频质量不高",
                        )
                        sentences_bucket_max_size = gr.Slider(
                            label="分句分桶的最大容量(批次推理生效)", value=4, minimum=1, maximum=16, step=1, key="sentences_bucket_max_size",
                            info="建议2-8之间,值越大,一批次推理包含的分句数越多,过大可能导致内存溢出",
                        )
                    with gr.Accordion("预览分句结果", open=True) as sentences_settings:
                        sentences_preview = gr.Dataframe(
                            headers=["序号", "分句内容", "Token数"],
                            key="sentences_preview",
                            wrap=True,
                        )
        # Order matters: gen_single unpacks *args in exactly this sequence.
        advanced_params = [
            do_sample, top_p, top_k, temperature,
            length_penalty, num_beams, repetition_penalty, max_mel_tokens,
            # typical_sampling, typical_mass,
        ]
        if len(example_cases) > 0:
            gr.Examples(
                examples=example_cases,
                inputs=[prompt_audio, input_text_single, infer_mode],
            )

    def on_input_text_change(text, max_tokens_per_sentence):
        """Preview how *text* will be split into sentences with the current token budget."""
        if text and len(text) > 0:
            text_tokens_list = tts.tokenizer.tokenize(text)
            sentences = tts.tokenizer.split_sentences(text_tokens_list, max_tokens_per_sentence=int(max_tokens_per_sentence))
            data = []
            for i, s in enumerate(sentences):
                sentence_str = ''.join(s)
                tokens_count = len(s)
                data.append([i, sentence_str, tokens_count])
            return {
                sentences_preview: gr.update(value=data, visible=True, type="array"),
            }
        else:
            # Clear the preview table. The original built a pandas DataFrame here,
            # but pandas was never imported, so clearing the text box raised a
            # NameError; an empty row list achieves the same visible result.
            return {
                sentences_preview: gr.update(value=[], type="array"),
            }

    # Recompute the sentence preview whenever the text or the token budget changes.
    input_text_single.change(
        on_input_text_change,
        inputs=[input_text_single, max_text_tokens_per_sentence],
        outputs=[sentences_preview]
    )
    max_text_tokens_per_sentence.change(
        on_input_text_change,
        inputs=[input_text_single, max_text_tokens_per_sentence],
        outputs=[sentences_preview]
    )
    prompt_audio.upload(update_prompt_audio,
                        inputs=[],
                        outputs=[gen_button])
    gen_button.click(gen_single,
                     inputs=[prompt_audio, input_text_single, infer_mode,
                             max_text_tokens_per_sentence, sentences_bucket_max_size,
                             *advanced_params,
                             ],
                     outputs=[output_audio])
if __name__ == "__main__":
    # Queue up to 20 concurrent requests, then serve on the configured host/port.
    demo.queue(20)
    demo.launch(server_name=cmd_args.host, server_port=cmd_args.port)