Skip to content

Commit 2969059

Browse files
authored
Attempt to infer V2 models (#927)
1 parent 064fecb commit 2969059

File tree

1 file changed

+34
-16
lines changed

1 file changed

+34
-16
lines changed

infer_cli.py

Lines changed: 34 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,26 @@
1-
import os, sys, pdb, torch
2-
3-
now_dir = os.getcwd()
4-
sys.path.append(now_dir)
5-
import argparse
1+
from scipy.io import wavfile
2+
from fairseq import checkpoint_utils
3+
from lib.audio import load_audio
4+
from lib.infer_pack.models import (
5+
SynthesizerTrnMs256NSFsid,
6+
SynthesizerTrnMs256NSFsid_nono,
7+
SynthesizerTrnMs768NSFsid,
8+
SynthesizerTrnMs768NSFsid_nono,
9+
)
10+
from vc_infer_pipeline import VC
11+
from multiprocessing import cpu_count
12+
import numpy as np
13+
import torch
14+
import sys
615
import glob
16+
import argparse
17+
import os
718
import sys
19+
import pdb
820
import torch
9-
import numpy as np
10-
from multiprocessing import cpu_count
21+
22+
now_dir = os.getcwd()
23+
sys.path.append(now_dir)
1124

1225
####
1326
# USAGE
@@ -119,11 +132,6 @@ def device_config(self) -> tuple:
119132
config = Config(device, is_half)
120133
now_dir = os.getcwd()
121134
sys.path.append(now_dir)
122-
from vc_infer_pipeline import VC
123-
from infer_pack.models import SynthesizerTrnMs256NSFsid, SynthesizerTrnMs256NSFsid_nono
124-
from my_utils import load_audio
125-
from fairseq import checkpoint_utils
126-
from scipy.io import wavfile
127135

128136
hubert_model = None
129137

@@ -224,10 +232,20 @@ def get_vc(model_path):
224232
cpt["config"][-3] = cpt["weight"]["emb_g.weight"].shape[0] # n_spk
225233
if_f0 = cpt.get("f0", 1)
226234
version = cpt.get("version", "v1")
227-
if if_f0 == 1:
228-
net_g = SynthesizerTrnMs256NSFsid(*cpt["config"], is_half=is_half)
229-
else:
230-
net_g = SynthesizerTrnMs256NSFsid_nono(*cpt["config"])
235+
if version == "v1":
236+
if if_f0 == 1:
237+
net_g = SynthesizerTrnMs256NSFsid(
238+
*cpt["config"], is_half=is_half
239+
)
240+
else:
241+
net_g = SynthesizerTrnMs256NSFsid_nono(*cpt["config"])
242+
elif version == "v2":
243+
if if_f0 == 1:
244+
net_g = SynthesizerTrnMs768NSFsid(
245+
*cpt["config"], is_half=is_half
246+
)
247+
else:
248+
net_g = SynthesizerTrnMs768NSFsid_nono(*cpt["config"])
231249
del net_g.enc_q
232250
print(net_g.load_state_dict(cpt["weight"], strict=False))
233251
net_g.eval().to(device)

0 commit comments

Comments
 (0)