|
| 1 | +''' |
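| 2 | +Updates after 0416: |
| 3 | + Read the half-precision (half) setting from config |
| 4 | + Rebuild the npy so it no longer has to be filled in by hand |
| 5 | + Support v2 models |
| 6 | + Support models without f0 |
| 7 | + Bug fixes |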
| 8 | +
|
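| 9 | + int16: |
| 10 | + Add support for running without an index |
| 11 | + Switch the f0 algorithm to harvest (as far as I can tell this is the only change that affects CPU usage), but the results are poor without it |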
| 12 | +''' |
1 | 13 | import os, sys, traceback
|
2 |
| - |
3 | 14 | now_dir = os.getcwd()
|
4 | 15 | sys.path.append(now_dir)
|
| 16 | +from config import Config |
| 17 | +is_half=Config().is_half |
5 | 18 | import PySimpleGUI as sg
|
6 | 19 | import sounddevice as sd
|
7 | 20 | import noisereduce as nr
|
|
13 | 26 | import scipy.signal as signal
|
14 | 27 |
|
15 | 28 | # import matplotlib.pyplot as plt
|
16 |
| -from infer_pack.models import SynthesizerTrnMs256NSFsid, SynthesizerTrnMs256NSFsid_nono |
| 29 | +from infer_pack.models import SynthesizerTrnMs256NSFsid, SynthesizerTrnMs256NSFsid_nono,SynthesizerTrnMs768NSFsid,SynthesizerTrnMs768NSFsid_nono |
17 | 30 | from i18n import I18nAuto
|
18 | 31 |
|
19 | 32 | i18n = I18nAuto()
|
@@ -50,20 +63,33 @@ def __init__(
|
50 | 63 | )
|
51 | 64 | self.model = models[0]
|
52 | 65 | self.model = self.model.to(device)
|
53 |
| - self.model = self.model.half() |
| 66 | + if(is_half==True): |
| 67 | + self.model = self.model.half() |
| 68 | + else: |
| 69 | + self.model = self.model.float() |
54 | 70 | self.model.eval()
|
55 | 71 | cpt = torch.load(pth_path, map_location="cpu")
|
56 | 72 | self.tgt_sr = cpt["config"][-1]
|
57 | 73 | cpt["config"][-3] = cpt["weight"]["emb_g.weight"].shape[0] # n_spk
|
58 | 74 | self.if_f0 = cpt.get("f0", 1)
|
59 |
| - if self.if_f0 == 1: |
60 |
| - self.net_g = SynthesizerTrnMs256NSFsid(*cpt["config"], is_half=True) |
61 |
| - else: |
62 |
| - self.net_g = SynthesizerTrnMs256NSFsid_nono(*cpt["config"]) |
| 75 | + self.version = cpt.get("version", "v1") |
| 76 | + if version == "v1": |
| 77 | + if if_f0 == 1: |
| 78 | + net_g = SynthesizerTrnMs256NSFsid(*cpt["config"], is_half=config.is_half) |
| 79 | + else: |
| 80 | + net_g = SynthesizerTrnMs256NSFsid_nono(*cpt["config"]) |
| 81 | + elif version == "v2": |
| 82 | + if if_f0 == 1: |
| 83 | + net_g = SynthesizerTrnMs768NSFsid(*cpt["config"], is_half=config.is_half) |
| 84 | + else: |
| 85 | + net_g = SynthesizerTrnMs768NSFsid_nono(*cpt["config"]) |
63 | 86 | del self.net_g.enc_q
|
64 | 87 | print(self.net_g.load_state_dict(cpt["weight"], strict=False))
|
65 | 88 | self.net_g.eval().to(device)
|
66 |
| - self.net_g.half() |
| 89 | + if(is_half==True): |
| 90 | + self.net_g=self.net_g.half() |
| 91 | + else: |
| 92 | + self.net_g=self.net_g.float() |
67 | 93 | except:
|
68 | 94 | print(traceback.format_exc())
|
69 | 95 |
|
@@ -116,34 +142,33 @@ def infer(self, feats: torch.Tensor) -> np.ndarray:
|
116 | 142 | inputs = {
|
117 | 143 | "source": feats.half().to(device),
|
118 | 144 | "padding_mask": padding_mask.to(device),
|
119 |
| - "output_layer": 9, # layer 9 |
| 145 | + "output_layer": 9 if self.version == "v1" else 12, |
120 | 146 | }
|
121 | 147 | torch.cuda.synchronize()
|
122 | 148 | with torch.no_grad():
|
123 | 149 | logits = self.model.extract_features(**inputs)
|
124 |
| - feats = self.model.final_proj(logits[0]) |
| 150 | + feats = model.final_proj(logits[0]) if self.version == "v1" else logits[0] |
125 | 151 |
|
126 | 152 | ####索引优化
|
127 |
| - if hasattr(self, "index") and hasattr(self, "big_npy") and self.index_rate != 0: |
128 |
| - npy = feats[0].cpu().numpy().astype("float32") |
129 |
| - |
130 |
| - # _, I = self.index.search(npy, 1) |
131 |
| - # npy = self.big_npy[I.squeeze()].astype("float16") |
132 |
| - |
133 |
| - score, ix = self.index.search(npy, k=8) |
134 |
| - weight = np.square(1 / score) |
135 |
| - weight /= weight.sum(axis=1, keepdims=True) |
136 |
| - npy = np.sum( |
137 |
| - self.big_npy[ix] * np.expand_dims(weight, axis=2), axis=1 |
138 |
| - ).astype("float16") |
139 |
| - |
140 |
| - feats = ( |
141 |
| - torch.from_numpy(npy).unsqueeze(0).to(device) * self.index_rate |
142 |
| - + (1 - self.index_rate) * feats |
143 |
| - ) |
144 |
| - else: |
145 |
| - print("index search FAIL or disabled") |
146 |
| - |
| 153 | + try: |
| 154 | + if hasattr(self, "index") and hasattr(self, "big_npy") and self.index_rate != 0: |
| 155 | + npy = feats[0].cpu().numpy().astype("float32") |
| 156 | + score, ix = self.index.search(npy, k=8) |
| 157 | + weight = np.square(1 / score) |
| 158 | + weight /= weight.sum(axis=1, keepdims=True) |
| 159 | + npy = np.sum( |
| 160 | + self.big_npy[ix] * np.expand_dims(weight, axis=2), axis=1 |
| 161 | + ) |
| 162 | + if(is_half==True):npy=npy.astype("float16") |
| 163 | + feats = ( |
| 164 | + torch.from_numpy(npy).unsqueeze(0).to(device) * self.index_rate |
| 165 | + + (1 - self.index_rate) * feats |
| 166 | + ) |
| 167 | + else: |
| 168 | + print("index search FAIL or disabled") |
| 169 | + except: |
| 170 | + traceback.print_exc() |
| 171 | + print("index search FAIL") |
147 | 172 | feats = F.interpolate(feats.permute(0, 2, 1), scale_factor=2).permute(0, 2, 1)
|
148 | 173 | torch.cuda.synchronize()
|
149 | 174 | print(feats.shape)
|
|
0 commit comments