Skip to content

Commit 36456e3

Browse files
authored
replace warn (#1255)
1 parent 1d86fb7 commit 36456e3

File tree

8 files changed

+321
-14
lines changed

8 files changed

+321
-14
lines changed

infer-web.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -389,14 +389,14 @@ def get_pretrained_models(path_str, f0_str, sr2):
389389
"assets/pretrained%s/%sD%s.pth" % (path_str, f0_str, sr2), os.F_OK
390390
)
391391
if not if_pretrained_generator_exist:
392-
logger.warn(
392+
logger.warning(
393393
"assets/pretrained%s/%sG%s.pth not exist, will not use pretrained model",
394394
path_str,
395395
f0_str,
396396
sr2,
397397
)
398398
if not if_pretrained_discriminator_exist:
399-
logger.warn(
399+
logger.warning(
400400
"assets/pretrained%s/%sD%s.pth not exist, will not use pretrained model",
401401
path_str,
402402
f0_str,

infer/lib/train/data_utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,7 @@ def get_audio(self, filename):
113113
try:
114114
spec = torch.load(spec_filename)
115115
except:
116-
logger.warn("%s %s", spec_filename, traceback.format_exc())
116+
logger.warning("%s %s", spec_filename, traceback.format_exc())
117117
spec = spectrogram_torch(
118118
audio_norm,
119119
self.filter_length,
@@ -305,7 +305,7 @@ def get_audio(self, filename):
305305
try:
306306
spec = torch.load(spec_filename)
307307
except:
308-
logger.warn("%s %s", spec_filename, traceback.format_exc())
308+
logger.warning("%s %s", spec_filename, traceback.format_exc())
309309
spec = spectrogram_torch(
310310
audio_norm,
311311
self.filter_length,

infer/lib/train/utils.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ def go(model, bkey):
3333
try:
3434
new_state_dict[k] = saved_state_dict[k]
3535
if saved_state_dict[k].shape != state_dict[k].shape:
36-
logger.warn(
36+
logger.warning(
3737
"shape-%s-mismatch. need: %s, get: %s",
3838
k,
3939
state_dict[k].shape,
@@ -111,7 +111,7 @@ def load_checkpoint(checkpoint_path, model, optimizer=None, load_opt=1):
111111
try:
112112
new_state_dict[k] = saved_state_dict[k]
113113
if saved_state_dict[k].shape != state_dict[k].shape:
114-
logger.warn(
114+
logger.warning(
115115
"shape-%s-mismatch|need-%s|get-%s",
116116
k,
117117
state_dict[k].shape,
@@ -409,7 +409,7 @@ def get_hparams_from_file(config_path):
409409
def check_git_hash(model_dir):
410410
source_dir = os.path.dirname(os.path.realpath(__file__))
411411
if not os.path.exists(os.path.join(source_dir, ".git")):
412-
logger.warn(
412+
logger.warning(
413413
"{} is not a git repository, therefore hash value comparison will be ignored.".format(
414414
source_dir
415415
)
@@ -422,7 +422,7 @@ def check_git_hash(model_dir):
422422
if os.path.exists(path):
423423
saved_hash = open(path).read()
424424
if saved_hash != cur_hash:
425-
logger.warn(
425+
logger.warning(
426426
"git hash values are different. {}(saved) != {}(current)".format(
427427
saved_hash[:8], cur_hash[:8]
428428
)

infer/modules/train/train.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,7 @@ def main():
9999
n_gpus = 1
100100
if n_gpus < 1:
101101
# patch to unblock people without gpus. there is probably a better way.
102-
logger.warn("NO GPU DETECTED: falling back to CPU - this may take a while")
102+
logger.warning("NO GPU DETECTED: falling back to CPU - this may take a while")
103103
n_gpus = 1
104104
os.environ["MASTER_ADDR"] = "localhost"
105105
os.environ["MASTER_PORT"] = str(randint(20000, 55555))

infer/modules/vc/modules.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -224,7 +224,7 @@ def vc_single(
224224
)
225225
except:
226226
info = traceback.format_exc()
227-
logger.warn(info)
227+
logger.warning(info)
228228
return info, (None, None)
229229

230230
def vc_multi(

modules.py

Lines changed: 307 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,307 @@
1+
import traceback
2+
import logging
3+
4+
logger = logging.getLogger(__name__)
5+
6+
import numpy as np
7+
import soundfile as sf
8+
import torch
9+
from io import BytesIO
10+
11+
from infer.lib.audio import load_audio, wav2
12+
from infer.lib.infer_pack.models import (
13+
SynthesizerTrnMs256NSFsid,
14+
SynthesizerTrnMs256NSFsid_nono,
15+
SynthesizerTrnMs768NSFsid,
16+
SynthesizerTrnMs768NSFsid_nono,
17+
)
18+
from infer.modules.vc.pipeline import Pipeline
19+
from infer.modules.vc.utils import *
20+
21+
22+
class VC:
23+
def __init__(self, config):
24+
self.n_spk = None
25+
self.tgt_sr = None
26+
self.net_g = None
27+
self.pipeline = None
28+
self.cpt = None
29+
self.version = None
30+
self.if_f0 = None
31+
self.version = None
32+
self.hubert_model = None
33+
34+
self.config = config
35+
36+
def get_vc(self, sid, *to_return_protect):
37+
logger.info("Get sid: " + sid)
38+
39+
to_return_protect0 = {
40+
"visible": self.if_f0 != 0,
41+
"value": to_return_protect[0]
42+
if self.if_f0 != 0 and to_return_protect
43+
else 0.5,
44+
"__type__": "update",
45+
}
46+
to_return_protect1 = {
47+
"visible": self.if_f0 != 0,
48+
"value": to_return_protect[1]
49+
if self.if_f0 != 0 and to_return_protect
50+
else 0.33,
51+
"__type__": "update",
52+
}
53+
54+
if sid == "" or sid == []:
55+
if self.hubert_model is not None: # 考虑到轮询, 需要加个判断看是否 sid 是由有模型切换到无模型的
56+
logger.info("Clean model cache")
57+
del (
58+
self.net_g,
59+
self.n_spk,
60+
self.vc,
61+
self.hubert_model,
62+
self.tgt_sr,
63+
) # ,cpt
64+
self.hubert_model = (
65+
self.net_g
66+
) = self.n_spk = self.vc = self.hubert_model = self.tgt_sr = None
67+
if torch.cuda.is_available():
68+
torch.cuda.empty_cache()
69+
###楼下不这么折腾清理不干净
70+
self.if_f0 = self.cpt.get("f0", 1)
71+
self.version = self.cpt.get("version", "v1")
72+
if self.version == "v1":
73+
if self.if_f0 == 1:
74+
self.net_g = SynthesizerTrnMs256NSFsid(
75+
*self.cpt["config"], is_half=self.config.is_half
76+
)
77+
else:
78+
self.net_g = SynthesizerTrnMs256NSFsid_nono(*self.cpt["config"])
79+
elif self.version == "v2":
80+
if self.if_f0 == 1:
81+
self.net_g = SynthesizerTrnMs768NSFsid(
82+
*self.cpt["config"], is_half=self.config.is_half
83+
)
84+
else:
85+
self.net_g = SynthesizerTrnMs768NSFsid_nono(*self.cpt["config"])
86+
del self.net_g, self.cpt
87+
if torch.cuda.is_available():
88+
torch.cuda.empty_cache()
89+
return (
90+
{"visible": False, "__type__": "update"},
91+
{
92+
"visible": True,
93+
"value": to_return_protect0,
94+
"__type__": "update",
95+
},
96+
{
97+
"visible": True,
98+
"value": to_return_protect1,
99+
"__type__": "update",
100+
},
101+
"",
102+
"",
103+
)
104+
person = f'{os.getenv("weight_root")}/{sid}'
105+
logger.info(f"Loading: {person}")
106+
107+
self.cpt = torch.load(person, map_location="cpu")
108+
self.tgt_sr = self.cpt["config"][-1]
109+
self.cpt["config"][-3] = self.cpt["weight"]["emb_g.weight"].shape[0] # n_spk
110+
self.if_f0 = self.cpt.get("f0", 1)
111+
self.version = self.cpt.get("version", "v1")
112+
113+
synthesizer_class = {
114+
("v1", 1): SynthesizerTrnMs256NSFsid,
115+
("v1", 0): SynthesizerTrnMs256NSFsid_nono,
116+
("v2", 1): SynthesizerTrnMs768NSFsid,
117+
("v2", 0): SynthesizerTrnMs768NSFsid_nono,
118+
}
119+
120+
self.net_g = synthesizer_class.get(
121+
(self.version, self.if_f0), SynthesizerTrnMs256NSFsid
122+
)(*self.cpt["config"], is_half=self.config.is_half)
123+
124+
del self.net_g.enc_q
125+
126+
self.net_g.load_state_dict(self.cpt["weight"], strict=False)
127+
self.net_g.eval().to(self.config.device)
128+
if self.config.is_half:
129+
self.net_g = self.net_g.half()
130+
else:
131+
self.net_g = self.net_g.float()
132+
133+
self.pipeline = Pipeline(self.tgt_sr, self.config)
134+
n_spk = self.cpt["config"][-3]
135+
index = {"value": get_index_path_from_model(sid), "__type__": "update"}
136+
logger.info("Select index: " + index["value"])
137+
138+
return (
139+
(
140+
{"visible": True, "maximum": n_spk, "__type__": "update"},
141+
to_return_protect0,
142+
to_return_protect1,
143+
index,
144+
index,
145+
)
146+
if to_return_protect
147+
else {"visible": True, "maximum": n_spk, "__type__": "update"}
148+
)
149+
150+
def vc_single(
151+
self,
152+
sid,
153+
input_audio_path,
154+
f0_up_key,
155+
f0_file,
156+
f0_method,
157+
file_index,
158+
file_index2,
159+
index_rate,
160+
filter_radius,
161+
resample_sr,
162+
rms_mix_rate,
163+
protect,
164+
):
165+
if input_audio_path is None:
166+
return "You need to upload an audio", None
167+
f0_up_key = int(f0_up_key)
168+
try:
169+
audio = load_audio(input_audio_path, 16000)
170+
audio_max = np.abs(audio).max() / 0.95
171+
if audio_max > 1:
172+
audio /= audio_max
173+
times = [0, 0, 0]
174+
175+
if self.hubert_model is None:
176+
self.hubert_model = load_hubert(self.config)
177+
178+
file_index = (
179+
(
180+
file_index.strip(" ")
181+
.strip('"')
182+
.strip("\n")
183+
.strip('"')
184+
.strip(" ")
185+
.replace("trained", "added")
186+
)
187+
if file_index != ""
188+
else file_index2
189+
) # 防止小白写错,自动帮他替换掉
190+
191+
audio_opt = self.pipeline.pipeline(
192+
self.hubert_model,
193+
self.net_g,
194+
sid,
195+
audio,
196+
input_audio_path,
197+
times,
198+
f0_up_key,
199+
f0_method,
200+
file_index,
201+
index_rate,
202+
self.if_f0,
203+
filter_radius,
204+
self.tgt_sr,
205+
resample_sr,
206+
rms_mix_rate,
207+
self.version,
208+
protect,
209+
f0_file,
210+
)
211+
if self.tgt_sr != resample_sr >= 16000:
212+
tgt_sr = resample_sr
213+
else:
214+
tgt_sr = self.tgt_sr
215+
index_info = (
216+
"Index:\n%s." % file_index
217+
if os.path.exists(file_index)
218+
else "Index not used."
219+
)
220+
return (
221+
"Success.\n%s\nTime:\nnpy: %.2fs, f0: %.2fs, infer: %.2fs."
222+
% (index_info, *times),
223+
(tgt_sr, audio_opt),
224+
)
225+
except:
226+
info = traceback.format_exc()
227+
logger.warning(info)
228+
return info, (None, None)
229+
230+
def vc_multi(
231+
self,
232+
sid,
233+
dir_path,
234+
opt_root,
235+
paths,
236+
f0_up_key,
237+
f0_method,
238+
file_index,
239+
file_index2,
240+
index_rate,
241+
filter_radius,
242+
resample_sr,
243+
rms_mix_rate,
244+
protect,
245+
format1,
246+
):
247+
try:
248+
dir_path = (
249+
dir_path.strip(" ").strip('"').strip("\n").strip('"').strip(" ")
250+
) # 防止小白拷路径头尾带了空格和"和回车
251+
opt_root = opt_root.strip(" ").strip('"').strip("\n").strip('"').strip(" ")
252+
os.makedirs(opt_root, exist_ok=True)
253+
try:
254+
if dir_path != "":
255+
paths = [
256+
os.path.join(dir_path, name) for name in os.listdir(dir_path)
257+
]
258+
else:
259+
paths = [path.name for path in paths]
260+
except:
261+
traceback.print_exc()
262+
paths = [path.name for path in paths]
263+
infos = []
264+
for path in paths:
265+
info, opt = self.vc_single(
266+
sid,
267+
path,
268+
f0_up_key,
269+
None,
270+
f0_method,
271+
file_index,
272+
file_index2,
273+
# file_big_npy,
274+
index_rate,
275+
filter_radius,
276+
resample_sr,
277+
rms_mix_rate,
278+
protect,
279+
)
280+
if "Success" in info:
281+
try:
282+
tgt_sr, audio_opt = opt
283+
if format1 in ["wav", "flac"]:
284+
sf.write(
285+
"%s/%s.%s"
286+
% (opt_root, os.path.basename(path), format1),
287+
audio_opt,
288+
tgt_sr,
289+
)
290+
else:
291+
path = "%s/%s.%s" % (
292+
opt_root,
293+
os.path.basename(path),
294+
format1,
295+
)
296+
with BytesIO() as wavf:
297+
sf.write(wavf, audio_opt, tgt_sr, format="wav")
298+
wavf.seek(0, 0)
299+
with open(path, "wb") as outf:
300+
wav2(wavf, outf, format1)
301+
except:
302+
info += traceback.format_exc()
303+
infos.append("%s->%s" % (os.path.basename(path), info))
304+
yield "\n".join(infos)
305+
yield "\n".join(infos)
306+
except:
307+
yield traceback.format_exc()

0 commit comments

Comments
 (0)