Skip to content

Commit a9a77f2

Browse files
authored
fix-no-f0-model-protect-issue
fix-no-f0-model-protect-issue
1 parent ec0c39d commit a9a77f2

File tree

2 files changed

+16
-11
lines changed

2 files changed

+16
-11
lines changed

infer-web.py

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import sys
44
now_dir = os.getcwd()
55
sys.path.append(now_dir)
6-
import traceback
6+
import traceback,pdb
77
import warnings
88

99
import numpy as np
@@ -396,7 +396,7 @@ def uvr(model_name, inp_root, save_root_vocal, paths, save_root_ins, agg, format
396396

397397

398398
# 一个选项卡全局只能有一个音色
399-
def get_vc(sid):
399+
def get_vc(sid,to_return_protect0,to_return_protect1):
400400
global n_spk, tgt_sr, net_g, vc, cpt, version
401401
if sid == "" or sid == []:
402402
global hubert_model
@@ -434,6 +434,11 @@ def get_vc(sid):
434434
tgt_sr = cpt["config"][-1]
435435
cpt["config"][-3] = cpt["weight"]["emb_g.weight"].shape[0] # n_spk
436436
if_f0 = cpt.get("f0", 1)
437+
if(if_f0==0):
438+
to_return_protect0=to_return_protect1={"visible": False, "value": 0.5, "__type__": "update"}
439+
else:
440+
to_return_protect0 ={"visible": True, "value": to_return_protect0, "__type__": "update"}
441+
to_return_protect1 ={"visible": True, "value": to_return_protect1, "__type__": "update"}
437442
version = cpt.get("version", "v1")
438443
if version == "v1":
439444
if if_f0 == 1:
@@ -454,7 +459,7 @@ def get_vc(sid):
454459
net_g = net_g.float()
455460
vc = VC(tgt_sr, config)
456461
n_spk = cpt["config"][-3]
457-
return {"visible": True, "maximum": n_spk, "__type__": "update"}
462+
return {"visible": True, "maximum": n_spk, "__type__": "update"},to_return_protect0,to_return_protect1
458463

459464

460465
def change_choices():
@@ -1247,11 +1252,6 @@ def export_onnx(ModelPath, ExportedPath):
12471252
interactive=True,
12481253
)
12491254
clean_button.click(fn=clean, inputs=[], outputs=[sid0])
1250-
sid0.change(
1251-
fn=get_vc,
1252-
inputs=[sid0],
1253-
outputs=[spk_item],
1254-
)
12551255
with gr.Group():
12561256
gr.Markdown(
12571257
value=i18n("男转女推荐+12key, 女转男推荐-12key, 如果音域爆炸导致音色失真也可以自己调整到合适音域. ")
@@ -1475,6 +1475,11 @@ def export_onnx(ModelPath, ExportedPath):
14751475
],
14761476
[vc_output3],
14771477
)
1478+
sid0.change(
1479+
fn=get_vc,
1480+
inputs=[sid0,protect0,protect1],
1481+
outputs=[spk_item,protect0,protect1],
1482+
)
14781483
with gr.TabItem(i18n("伴奏人声分离&去混响&去回声")):
14791484
with gr.Group():
14801485
gr.Markdown(

vc_infer_pipeline.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -184,7 +184,7 @@ def vc(
184184
with torch.no_grad():
185185
logits = model.extract_features(**inputs)
186186
feats = model.final_proj(logits[0]) if version == "v1" else logits[0]
187-
if protect < 0.5:
187+
if protect < 0.5 and pitch!=None and pitchf!=None:
188188
feats0 = feats.clone()
189189
if (
190190
isinstance(index, type(None)) == False
@@ -211,7 +211,7 @@ def vc(
211211
)
212212

213213
feats = F.interpolate(feats.permute(0, 2, 1), scale_factor=2).permute(0, 2, 1)
214-
if protect < 0.5:
214+
if protect < 0.5 and pitch!=None and pitchf!=None:
215215
feats0 = F.interpolate(feats0.permute(0, 2, 1), scale_factor=2).permute(
216216
0, 2, 1
217217
)
@@ -223,7 +223,7 @@ def vc(
223223
pitch = pitch[:, :p_len]
224224
pitchf = pitchf[:, :p_len]
225225

226-
if protect < 0.5:
226+
if protect < 0.5 and pitch!=None and pitchf!=None:
227227
pitchff = pitchf.clone()
228228
pitchff[pitchf > 0] = 1
229229
pitchff[pitchf < 1] = protect

0 commit comments

Comments
 (0)