Skip to content

Commit 72a18e6

Browse files
authored
Add files via upload
1 parent 43c4f43 commit 72a18e6

File tree

2 files changed

+765
-0
lines changed

2 files changed

+765
-0
lines changed

modules.py

Lines changed: 308 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,308 @@
1+
import traceback
2+
import logging
3+
4+
logger = logging.getLogger(__name__)
5+
6+
import numpy as np
7+
import soundfile as sf
8+
import torch
9+
from io import BytesIO
10+
11+
from infer.lib.audio import load_audio, wav2
12+
from infer.lib.infer_pack.models import (
13+
SynthesizerTrnMs256NSFsid,
14+
SynthesizerTrnMs256NSFsid_nono,
15+
SynthesizerTrnMs768NSFsid,
16+
SynthesizerTrnMs768NSFsid_nono,
17+
)
18+
from infer.modules.vc.pipeline import Pipeline
19+
from infer.modules.vc.utils import *
20+
21+
22+
class VC:
23+
def __init__(self, config):
24+
self.n_spk = None
25+
self.tgt_sr = None
26+
self.net_g = None
27+
self.pipeline = None
28+
self.cpt = None
29+
self.version = None
30+
self.if_f0 = None
31+
self.version = None
32+
self.hubert_model = None
33+
34+
self.config = config
35+
36+
def get_vc(self, sid, *to_return_protect):
37+
logger.info("Get sid: " + sid)
38+
39+
to_return_protect0 = {
40+
"visible": self.if_f0 != 0,
41+
"value": to_return_protect[0]
42+
if self.if_f0 != 0 and to_return_protect
43+
else 0.5,
44+
"__type__": "update",
45+
}
46+
to_return_protect1 = {
47+
"visible": self.if_f0 != 0,
48+
"value": to_return_protect[1]
49+
if self.if_f0 != 0 and to_return_protect
50+
else 0.33,
51+
"__type__": "update",
52+
}
53+
54+
if sid == "" or sid == []:
55+
if self.hubert_model is not None: # 考虑到轮询, 需要加个判断看是否 sid 是由有模型切换到无模型的
56+
logger.info("Clean model cache")
57+
del (
58+
self.net_g,
59+
self.n_spk,
60+
self.vc,
61+
self.hubert_model,
62+
self.tgt_sr,
63+
) # ,cpt
64+
self.hubert_model = (
65+
self.net_g
66+
) = self.n_spk = self.vc = self.hubert_model = self.tgt_sr = None
67+
if torch.cuda.is_available():
68+
torch.cuda.empty_cache()
69+
###楼下不这么折腾清理不干净
70+
self.if_f0 = self.cpt.get("f0", 1)
71+
self.version = self.cpt.get("version", "v1")
72+
if self.version == "v1":
73+
if self.if_f0 == 1:
74+
self.net_g = SynthesizerTrnMs256NSFsid(
75+
*self.cpt["config"], is_half=self.config.is_half
76+
)
77+
else:
78+
self.net_g = SynthesizerTrnMs256NSFsid_nono(*self.cpt["config"])
79+
elif self.version == "v2":
80+
if self.if_f0 == 1:
81+
self.net_g = SynthesizerTrnMs768NSFsid(
82+
*self.cpt["config"], is_half=self.config.is_half
83+
)
84+
else:
85+
self.net_g = SynthesizerTrnMs768NSFsid_nono(*self.cpt["config"])
86+
del self.net_g, self.cpt
87+
if torch.cuda.is_available():
88+
torch.cuda.empty_cache()
89+
return (
90+
{"visible": False, "__type__": "update"},
91+
{
92+
"visible": True,
93+
"value": to_return_protect0,
94+
"__type__": "update",
95+
},
96+
{
97+
"visible": True,
98+
"value": to_return_protect1,
99+
"__type__": "update",
100+
},
101+
"",
102+
"",
103+
)
104+
person = f'{os.getenv("weight_root")}/{sid}'
105+
logger.info(f"Loading: {person}")
106+
107+
self.cpt = torch.load(person, map_location="cpu")
108+
self.tgt_sr = self.cpt["config"][-1]
109+
self.cpt["config"][-3] = self.cpt["weight"]["emb_g.weight"].shape[0] # n_spk
110+
self.if_f0 = self.cpt.get("f0", 1)
111+
self.version = self.cpt.get("version", "v1")
112+
113+
synthesizer_class = {
114+
("v1", 1): SynthesizerTrnMs256NSFsid,
115+
("v1", 0): SynthesizerTrnMs256NSFsid_nono,
116+
("v2", 1): SynthesizerTrnMs768NSFsid,
117+
("v2", 0): SynthesizerTrnMs768NSFsid_nono,
118+
}
119+
120+
self.net_g = synthesizer_class.get(
121+
(self.version, self.if_f0), SynthesizerTrnMs256NSFsid
122+
)(*self.cpt["config"], is_half=self.config.is_half)
123+
124+
del self.net_g.enc_q
125+
126+
self.net_g.load_state_dict(self.cpt["weight"], strict=False)
127+
self.net_g.eval().to(self.config.device)
128+
if self.config.is_half:
129+
self.net_g = self.net_g.half()
130+
else:
131+
self.net_g = self.net_g.float()
132+
133+
self.pipeline = Pipeline(self.tgt_sr, self.config)
134+
n_spk = self.cpt["config"][-3]
135+
index = {"value": get_index_path_from_model(sid), "__type__": "update"}
136+
logger.info("Select index: " + index["value"])
137+
138+
return (
139+
(
140+
{"visible": True, "maximum": n_spk, "__type__": "update"},
141+
to_return_protect0,
142+
to_return_protect1,
143+
index,
144+
index,
145+
)
146+
if to_return_protect
147+
else {"visible": True, "maximum": n_spk, "__type__": "update"}
148+
)
149+
150+
def vc_single(
151+
self,
152+
sid,
153+
input_audio_path,
154+
f0_up_key,
155+
f0_file,
156+
f0_method,
157+
file_index,
158+
file_index2,
159+
index_rate,
160+
filter_radius,
161+
resample_sr,
162+
rms_mix_rate,
163+
protect,
164+
):
165+
if input_audio_path is None:
166+
return "You need to upload an audio", None
167+
f0_up_key = int(f0_up_key)
168+
try:
169+
audio = load_audio(input_audio_path, 16000)
170+
audio_max = np.abs(audio).max() / 0.95
171+
if audio_max > 1:
172+
audio /= audio_max
173+
times = [0, 0, 0]
174+
175+
if self.hubert_model is None:
176+
self.hubert_model = load_hubert(self.config)
177+
178+
file_index = (
179+
(
180+
file_index.strip(" ")
181+
.strip('"')
182+
.strip("\n")
183+
.strip('"')
184+
.strip(" ")
185+
.replace("trained", "added")
186+
)
187+
if file_index != ""
188+
else file_index2
189+
) # 防止小白写错,自动帮他替换掉
190+
191+
audio_opt = self.pipeline.pipeline(
192+
self.hubert_model,
193+
self.net_g,
194+
sid,
195+
audio,
196+
input_audio_path,
197+
times,
198+
f0_up_key,
199+
f0_method,
200+
file_index,
201+
index_rate,
202+
self.if_f0,
203+
filter_radius,
204+
self.tgt_sr,
205+
resample_sr,
206+
rms_mix_rate,
207+
self.version,
208+
protect,
209+
f0_file,
210+
)
211+
if self.tgt_sr != resample_sr >= 16000:
212+
tgt_sr = resample_sr
213+
else:
214+
tgt_sr = self.tgt_sr
215+
index_info = (
216+
"Index:\n%s." % file_index
217+
if os.path.exists(file_index)
218+
else "Index not used."
219+
)
220+
return (
221+
"Success.\n%s\nTime:\nnpy: %.2fs, f0: %.2fs, infer: %.2fs."
222+
% (index_info, *times),
223+
(tgt_sr, audio_opt),
224+
)
225+
except:
226+
info = traceback.format_exc()
227+
logger.warn(info)
228+
return info, (None, None)
229+
230+
def vc_multi(
231+
self,
232+
sid,
233+
dir_path,
234+
opt_root,
235+
paths,
236+
f0_up_key,
237+
f0_method,
238+
file_index,
239+
file_index2,
240+
index_rate,
241+
filter_radius,
242+
resample_sr,
243+
rms_mix_rate,
244+
protect,
245+
format1,
246+
):
247+
try:
248+
dir_path = (
249+
dir_path.strip(" ").strip('"').strip("\n").strip('"').strip(" ")
250+
) # 防止小白拷路径头尾带了空格和"和回车
251+
opt_root = opt_root.strip(" ").strip('"').strip("\n").strip('"').strip(" ")
252+
os.makedirs(opt_root, exist_ok=True)
253+
try:
254+
if dir_path != "":
255+
paths = [
256+
os.path.join(dir_path, name) for name in os.listdir(dir_path)
257+
]
258+
else:
259+
paths = [path.name for path in paths]
260+
except:
261+
traceback.print_exc()
262+
paths = [path.name for path in paths]
263+
infos = []
264+
for path in paths:
265+
info, opt = self.vc_single(
266+
sid,
267+
path,
268+
f0_up_key,
269+
None,
270+
f0_method,
271+
file_index,
272+
file_index2,
273+
# file_big_npy,
274+
index_rate,
275+
filter_radius,
276+
resample_sr,
277+
rms_mix_rate,
278+
protect,
279+
)
280+
if "Success" in info:
281+
try:
282+
tgt_sr, audio_opt = opt
283+
if format1 in ["wav", "flac"]:
284+
sf.write(
285+
"%s/%s.%s"
286+
% (opt_root, os.path.basename(path), format1),
287+
audio_opt,
288+
tgt_sr,
289+
)
290+
else:
291+
path = "%s/%s.%s" % (opt_root, os.path.basename(path), format1)
292+
with BytesIO() as wavf:
293+
sf.write(
294+
wavf,
295+
audio_opt,
296+
tgt_sr,
297+
format="wav"
298+
)
299+
wavf.seek(0, 0)
300+
with open(path, "wb") as outf:
301+
wav2(wavf, outf, format1)
302+
except:
303+
info += traceback.format_exc()
304+
infos.append("%s->%s" % (os.path.basename(path), info))
305+
yield "\n".join(infos)
306+
yield "\n".join(infos)
307+
except:
308+
yield traceback.format_exc()

0 commit comments

Comments
 (0)