Skip to content

Commit 4de5d0d

Browse files
更新批量推理脚本,可以不用 webui (#518)
* 增加批量推理脚本 * 更新批量推理脚本
1 parent c74727d commit 4de5d0d

File tree

1 file changed

+181
-0
lines changed

1 file changed

+181
-0
lines changed

infer_batch_rvc.py

Lines changed: 181 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,181 @@
1+
'''
2+
v1
3+
runtime\python.exe myinfer-v2-0528.py 0 "E:\codes\py39\RVC-beta\todo-songs" "E:\codes\py39\logs\mi-test\added_IVF677_Flat_nprobe_7.index" harvest "E:\codes\py39\RVC-beta\output" "E:\codes\py39\test-20230416b\weights\mi-test.pth" 0.66 cuda:0 True 3 0 1 0.33
4+
v2
5+
runtime\python.exe myinfer-v2-0528.py 0 "E:\codes\py39\RVC-beta\todo-songs" "E:\codes\py39\test-20230416b\logs\mi-test-v2\aadded_IVF677_Flat_nprobe_1_v2.index" harvest "E:\codes\py39\RVC-beta\output_v2" "E:\codes\py39\test-20230416b\weights\mi-test-v2.pth" 0.66 cuda:0 True 3 0 1 0.33
6+
'''
7+
import os,sys,pdb,torch
8+
now_dir = os.getcwd()
9+
sys.path.append(now_dir)
10+
import argparse
11+
import glob
12+
import sys
13+
import torch
14+
import tqdm as tq
15+
from multiprocessing import cpu_count
16+
class Config:
    """Inference runtime configuration: device / precision selection and the
    padding & windowing sizes consumed by the VC pipeline.

    Attributes set here are read by ``VC`` (see ``vc_infer_pipeline``):
    x_pad, x_query, x_center, x_max.
    """

    def __init__(self, device, is_half):
        self.device = device      # e.g. "cuda:0", "mps" or "cpu"
        self.is_half = is_half    # use fp16 where the device supports it
        self.n_cpu = 0            # 0 = auto-detect in device_config()
        self.gpu_name = None
        self.gpu_mem = None       # total GPU memory in GiB (rounded), CUDA only
        self.x_pad, self.x_query, self.x_center, self.x_max = self.device_config()

    def device_config(self) -> tuple:
        """Probe the requested device and return ``(x_pad, x_query, x_center, x_max)``.

        Side effects (upstream RVC behaviour, kept as-is): on 16xx/10xx/P40
        GPUs and on cards with <= 4 GiB memory this rewrites the bundled
        config JSONs and the preprocessing script on disk to force single
        precision / smaller buffers.
        """
        if torch.cuda.is_available():
            i_device = int(self.device.split(":")[-1])
            self.gpu_name = torch.cuda.get_device_name(i_device)
            if (
                ("16" in self.gpu_name and "V100" not in self.gpu_name.upper())
                or "P40" in self.gpu_name.upper()
                or "1060" in self.gpu_name
                or "1070" in self.gpu_name
                or "1080" in self.gpu_name
            ):
                # 16-series / 10-series cards and the P40 have no usable fp16.
                print("16系/10系显卡和P40强制单精度")
                self.is_half = False
                for config_file in ["32k.json", "40k.json", "48k.json"]:
                    with open(f"configs/{config_file}", "r") as f:
                        strr = f.read().replace("true", "false")
                    with open(f"configs/{config_file}", "w") as f:
                        f.write(strr)
                with open("trainset_preprocess_pipeline_print.py", "r") as f:
                    strr = f.read().replace("3.7", "3.0")
                with open("trainset_preprocess_pipeline_print.py", "w") as f:
                    f.write(strr)
            else:
                self.gpu_name = None
            self.gpu_mem = int(
                torch.cuda.get_device_properties(i_device).total_memory
                / 1024
                / 1024
                / 1024
                + 0.4
            )
            if self.gpu_mem <= 4:
                with open("trainset_preprocess_pipeline_print.py", "r") as f:
                    strr = f.read().replace("3.7", "3.0")
                with open("trainset_preprocess_pipeline_print.py", "w") as f:
                    f.write(strr)
        elif torch.backends.mps.is_available():
            print("没有发现支持的N卡, 使用MPS进行推理")
            self.device = "mps"
        else:
            print("没有发现支持的N卡, 使用CPU进行推理")
            self.device = "cpu"
            # BUGFIX: was ``self.is_half = True`` — fp16 inference is not
            # supported on CPU and crashes downstream; force fp32.
            self.is_half = False

        if self.n_cpu == 0:
            self.n_cpu = cpu_count()

        if self.is_half:
            # 6 GB VRAM profile
            x_pad = 3
            x_query = 10
            x_center = 60
            x_max = 65
        else:
            # 5 GB VRAM profile
            x_pad = 1
            x_query = 6
            x_center = 38
            x_max = 41

        # Cards with very little memory get the tightest windows regardless.
        if self.gpu_mem is not None and self.gpu_mem <= 4:
            x_pad = 1
            x_query = 5
            x_center = 30
            x_max = 32

        return x_pad, x_query, x_center, x_max
92+
93+
# ---- Positional command-line arguments (see module docstring for examples) ----
f0up_key = sys.argv[1]            # transpose, in semitones
input_path = sys.argv[2]          # directory of input .wav files
index_path = sys.argv[3]          # faiss index file
f0method = sys.argv[4]            # f0 extraction method: "harvest" or "pm"
opt_path = sys.argv[5]            # output directory
model_path = sys.argv[6]          # .pth model weights
index_rate = float(sys.argv[7])
device = sys.argv[8]              # e.g. "cuda:0", "mps" or "cpu"
# BUGFIX: was ``bool(sys.argv[9])`` — any non-empty string (including the
# literal "False") is truthy, so fp16 could never actually be disabled.
is_half = sys.argv[9].lower() in ("true", "1", "yes")
filter_radius = int(sys.argv[10])
resample_sr = int(sys.argv[11])   # 0 = no resampling
rms_mix_rate = float(sys.argv[12])
protect = float(sys.argv[13])
print(sys.argv)
config = Config(device, is_half)
now_dir = os.getcwd()
sys.path.append(now_dir)
110+
from vc_infer_pipeline import VC
111+
from infer_pack.models import (
112+
SynthesizerTrnMs256NSFsid,
113+
SynthesizerTrnMs256NSFsid_nono,
114+
SynthesizerTrnMs768NSFsid,
115+
SynthesizerTrnMs768NSFsid_nono,
116+
)
117+
from my_utils import load_audio
118+
from fairseq import checkpoint_utils
119+
from scipy.io import wavfile
120+
121+
hubert_model = None  # lazily initialised by load_hubert()


def load_hubert():
    """Load the HuBERT feature extractor from ``hubert_base.pt`` onto the
    configured device, in the configured precision, and cache it at module
    level in ``hubert_model``."""
    global hubert_model
    models, saved_cfg, task = checkpoint_utils.load_model_ensemble_and_task(
        ["hubert_base.pt"],
        suffix="",
    )
    hubert_model = models[0].to(device)
    hubert_model = hubert_model.half() if is_half else hubert_model.float()
    hubert_model.eval()
130+
131+
def vc_single(sid, input_audio, f0_up_key, f0_file, f0_method, file_index, index_rate):
    """Run voice conversion on one audio file.

    Args:
        sid: speaker id passed to the synthesizer.
        input_audio: path to the input audio file, or None.
        f0_up_key: transpose in semitones (coerced to int).
        f0_file: optional external f0 curve file.
        f0_method: "harvest" or "pm".
        file_index: faiss index path.
        index_rate: index blend ratio.

    Returns:
        The converted waveform, or ``("You need to upload an audio", None)``
        when ``input_audio`` is None.
    """
    global tgt_sr, net_g, vc, hubert_model, version
    if input_audio is None:
        return "You need to upload an audio", None
    f0_up_key = int(f0_up_key)
    audio = load_audio(input_audio, 16000)  # pipeline expects 16 kHz input
    times = [0, 0, 0]  # timing buckets filled in by vc.pipeline
    if hubert_model is None:  # BUGFIX/idiom: was ``== None``; also lazy-load once
        load_hubert()
    if_f0 = cpt.get("f0", 1)
    audio_opt = vc.pipeline(
        hubert_model,
        net_g,
        sid,
        audio,
        input_audio,
        times,
        f0_up_key,
        f0_method,
        file_index,
        index_rate,
        if_f0,
        filter_radius,
        tgt_sr,
        resample_sr,
        rms_mix_rate,
        version,
        protect,
        f0_file=f0_file,
    )
    print(times)
    return audio_opt
143+
144+
145+
def get_vc(model_path):
    """Load the synthesizer checkpoint at *model_path* and initialise the
    module-level inference state: ``cpt``, ``tgt_sr``, ``net_g``, ``vc``,
    ``n_spk`` and ``version``."""
    global n_spk, tgt_sr, net_g, vc, cpt, device, is_half, version
    print("loading pth %s" % model_path)
    cpt = torch.load(model_path, map_location="cpu")
    tgt_sr = cpt["config"][-1]
    cpt["config"][-3] = cpt["weight"]["emb_g.weight"].shape[0]  # n_spk
    if_f0 = cpt.get("f0", 1)
    version = cpt.get("version", "v1")
    # Pick the synthesizer class for this checkpoint's version / f0 mode.
    if version == "v1":
        synth_cls = SynthesizerTrnMs256NSFsid if if_f0 == 1 else SynthesizerTrnMs256NSFsid_nono
    elif version == "v2":
        synth_cls = SynthesizerTrnMs768NSFsid if if_f0 == 1 else SynthesizerTrnMs768NSFsid_nono
    extra_kwargs = {"is_half": is_half} if if_f0 == 1 else {}
    net_g = synth_cls(*cpt["config"], **extra_kwargs)
    del net_g.enc_q  # training-only posterior encoder; not needed for inference
    # Printing the load result is an upstream quirk: without it the state
    # dict reportedly does not load cleanly.
    print(net_g.load_state_dict(cpt["weight"], strict=False))
    net_g.eval().to(device)
    net_g = net_g.half() if is_half else net_g.float()
    vc = VC(tgt_sr, config)
    n_spk = cpt["config"][-3]
    # return {"visible": True, "maximum": n_spk, "__type__": "update"}
171+
172+
173+
# ---- Batch inference over every .wav file in the input directory ----
get_vc(model_path)
# BUGFIX: wavfile.write fails if the output directory does not exist yet.
os.makedirs(opt_path, exist_ok=True)
audios = os.listdir(input_path)
for file in tq.tqdm(audios):
    if file.endswith('.wav'):
        file_path = os.path.join(input_path, file)
        wav_opt = vc_single(0, file_path, f0up_key, None, f0method, index_path, index_rate)
        out_path = os.path.join(opt_path, file)
        wavfile.write(out_path, tgt_sr, wav_opt)
181+

0 commit comments

Comments
 (0)