Skip to content

Commit 615c30c

Browse files
authored
Update gui.py
1 parent 79a79c3 commit 615c30c

File tree

1 file changed

+55
-30
lines changed

1 file changed

+55
-30
lines changed

gui.py

Lines changed: 55 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,20 @@
1+
'''
2+
0416后的更新:
3+
引入config中half
4+
重建npy而不用填写
5+
v2支持
6+
无f0模型支持
7+
修复
8+
9+
int16:
10+
增加无索引支持
11+
f0算法改harvest(怎么看就只有这个会影响CPU占用),但是不这么改效果不好
12+
'''
113
import os, sys, traceback
2-
314
now_dir = os.getcwd()
415
sys.path.append(now_dir)
16+
from config import Config
17+
is_half=Config().is_half
518
import PySimpleGUI as sg
619
import sounddevice as sd
720
import noisereduce as nr
@@ -13,7 +26,7 @@
1326
import scipy.signal as signal
1427

1528
# import matplotlib.pyplot as plt
16-
from infer_pack.models import SynthesizerTrnMs256NSFsid, SynthesizerTrnMs256NSFsid_nono
29+
from infer_pack.models import SynthesizerTrnMs256NSFsid, SynthesizerTrnMs256NSFsid_nono,SynthesizerTrnMs768NSFsid,SynthesizerTrnMs768NSFsid_nono
1730
from i18n import I18nAuto
1831

1932
i18n = I18nAuto()
@@ -50,20 +63,33 @@ def __init__(
5063
)
5164
self.model = models[0]
5265
self.model = self.model.to(device)
53-
self.model = self.model.half()
66+
if(is_half==True):
67+
self.model = self.model.half()
68+
else:
69+
self.model = self.model.float()
5470
self.model.eval()
5571
cpt = torch.load(pth_path, map_location="cpu")
5672
self.tgt_sr = cpt["config"][-1]
5773
cpt["config"][-3] = cpt["weight"]["emb_g.weight"].shape[0] # n_spk
5874
self.if_f0 = cpt.get("f0", 1)
59-
if self.if_f0 == 1:
60-
self.net_g = SynthesizerTrnMs256NSFsid(*cpt["config"], is_half=True)
61-
else:
62-
self.net_g = SynthesizerTrnMs256NSFsid_nono(*cpt["config"])
75+
self.version = cpt.get("version", "v1")
76+
if version == "v1":
77+
if if_f0 == 1:
78+
net_g = SynthesizerTrnMs256NSFsid(*cpt["config"], is_half=config.is_half)
79+
else:
80+
net_g = SynthesizerTrnMs256NSFsid_nono(*cpt["config"])
81+
elif version == "v2":
82+
if if_f0 == 1:
83+
net_g = SynthesizerTrnMs768NSFsid(*cpt["config"], is_half=config.is_half)
84+
else:
85+
net_g = SynthesizerTrnMs768NSFsid_nono(*cpt["config"])
6386
del self.net_g.enc_q
6487
print(self.net_g.load_state_dict(cpt["weight"], strict=False))
6588
self.net_g.eval().to(device)
66-
self.net_g.half()
89+
if(is_half==True):
90+
self.net_g=self.net_g.half()
91+
else:
92+
self.net_g=self.net_g.float()
6793
except:
6894
print(traceback.format_exc())
6995

@@ -116,34 +142,33 @@ def infer(self, feats: torch.Tensor) -> np.ndarray:
116142
inputs = {
117143
"source": feats.half().to(device),
118144
"padding_mask": padding_mask.to(device),
119-
"output_layer": 9, # layer 9
145+
"output_layer": 9 if self.version == "v1" else 12,
120146
}
121147
torch.cuda.synchronize()
122148
with torch.no_grad():
123149
logits = self.model.extract_features(**inputs)
124-
feats = self.model.final_proj(logits[0])
150+
feats = model.final_proj(logits[0]) if self.version == "v1" else logits[0]
125151

126152
####索引优化
127-
if hasattr(self, "index") and hasattr(self, "big_npy") and self.index_rate != 0:
128-
npy = feats[0].cpu().numpy().astype("float32")
129-
130-
# _, I = self.index.search(npy, 1)
131-
# npy = self.big_npy[I.squeeze()].astype("float16")
132-
133-
score, ix = self.index.search(npy, k=8)
134-
weight = np.square(1 / score)
135-
weight /= weight.sum(axis=1, keepdims=True)
136-
npy = np.sum(
137-
self.big_npy[ix] * np.expand_dims(weight, axis=2), axis=1
138-
).astype("float16")
139-
140-
feats = (
141-
torch.from_numpy(npy).unsqueeze(0).to(device) * self.index_rate
142-
+ (1 - self.index_rate) * feats
143-
)
144-
else:
145-
print("index search FAIL or disabled")
146-
153+
try:
154+
if hasattr(self, "index") and hasattr(self, "big_npy") and self.index_rate != 0:
155+
npy = feats[0].cpu().numpy().astype("float32")
156+
score, ix = self.index.search(npy, k=8)
157+
weight = np.square(1 / score)
158+
weight /= weight.sum(axis=1, keepdims=True)
159+
npy = np.sum(
160+
self.big_npy[ix] * np.expand_dims(weight, axis=2), axis=1
161+
)
162+
if(is_half==True):npy=npy.astype("float16")
163+
feats = (
164+
torch.from_numpy(npy).unsqueeze(0).to(device) * self.index_rate
165+
+ (1 - self.index_rate) * feats
166+
)
167+
else:
168+
print("index search FAIL or disabled")
169+
except:
170+
traceback.print_exc()
171+
print("index search FAIL")
147172
feats = F.interpolate(feats.permute(0, 2, 1), scale_factor=2).permute(0, 2, 1)
148173
torch.cuda.synchronize()
149174
print(feats.shape)

0 commit comments

Comments
 (0)