Skip to content

Commit a6cb4d3

Browse files
authored
support 16xx GPU and 4G GPU inference
support 16xx GPU and 4G GPU inference
1 parent 2ac8d55 commit a6cb4d3

File tree

2 files changed

+57
-30
lines changed

2 files changed

+57
-30
lines changed

config.py

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -64,12 +64,25 @@ def has_mps() -> bool:
6464
device = "cpu"
6565
is_half = False
6666

67+
gpu_mem=None
6768
if device not in ["cpu", "mps"]:
68-
gpu_name = torch.cuda.get_device_name(int(device.split(":")[-1]))
69-
if "16" in gpu_name or "MX" in gpu_name:
70-
print("16系显卡/MX系显卡强制单精度")
69+
i_device=int(device.split(":")[-1])
70+
gpu_name = torch.cuda.get_device_name(i_device)
71+
if "16" in gpu_name or "P40"in gpu_name.upper() or "1070"in gpu_name or "1080"in gpu_name:
72+
print("16系显卡强制单精度")
7173
is_half = False
72-
74+
with open("configs/32k.json","r")as f:strr=f.read().replace("true","false")
75+
with open("configs/32k.json","w")as f:f.write(strr)
76+
with open("configs/40k.json","r")as f:strr=f.read().replace("true","false")
77+
with open("configs/40k.json","w")as f:f.write(strr)
78+
with open("configs/48k.json","r")as f:strr=f.read().replace("true","false")
79+
with open("configs/48k.json","w")as f:f.write(strr)
80+
with open("trainset_preprocess_pipeline_print.py","r")as f:strr=f.read().replace("3.7","3.0")
81+
with open("trainset_preprocess_pipeline_print.py","w")as f:f.write(strr)
82+
gpu_mem=int(torch.cuda.get_device_properties(i_device).total_memory/1024/1024/1024+0.4)
83+
if(gpu_mem<=4):
84+
with open("trainset_preprocess_pipeline_print.py","r")as f:strr=f.read().replace("3.7","3.0")
85+
with open("trainset_preprocess_pipeline_print.py","w")as f:f.write(strr)
7386
from multiprocessing import cpu_count
7487

7588
if n_cpu == 0:
@@ -86,3 +99,8 @@ def has_mps() -> bool:
8699
x_query = 6
87100
x_center = 38
88101
x_max = 41
102+
if(gpu_mem!=None and gpu_mem<=4):
103+
x_pad = 1
104+
x_query = 5
105+
x_center = 30
106+
x_max = 32

infer-web.py

Lines changed: 35 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from time import sleep
66
import torch, os, traceback, sys, warnings, shutil, numpy as np
77
import faiss
8-
8+
from random import shuffle
99
now_dir = os.getcwd()
1010
sys.path.append(now_dir)
1111
tmp = os.path.join(now_dir, "TEMP")
@@ -23,6 +23,7 @@
2323
ncpu = cpu_count()
2424
ngpu = torch.cuda.device_count()
2525
gpu_infos = []
26+
mem=[]
2627
if (not torch.cuda.is_available()) or ngpu == 0:
2728
if_gpu_ok = False
2829
else:
@@ -48,11 +49,13 @@
4849
): # A10#A100#V100#A40#P40#M40#K80#A4500
4950
if_gpu_ok = True # 至少有一张能用的N卡
5051
gpu_infos.append("%s\t%s" % (i, gpu_name))
51-
gpu_info = (
52-
"\n".join(gpu_infos)
53-
if if_gpu_ok == True and len(gpu_infos) > 0
54-
else "很遗憾您这没有能用的显卡来支持您训练"
55-
)
52+
mem.append(int(torch.cuda.get_device_properties(i).total_memory/1024/1024/1024+0.4))
53+
if if_gpu_ok == True and len(gpu_infos) > 0:
54+
gpu_info ="\n".join(gpu_infos)
55+
default_batch_size=min(mem)//2
56+
else:
57+
gpu_info = "很遗憾您这没有能用的显卡来支持您训练"
58+
default_batch_size=1
5659
gpus = "-".join([i[0] for i in gpu_infos])
5760
from infer_pack.models import SynthesizerTrnMs256NSFsid, SynthesizerTrnMs256NSFsid_nono
5861
from scipy.io import wavfile
@@ -564,15 +567,18 @@ def click_train(
564567
)
565568
)
566569
if if_f0_3 == "是":
567-
opt.append(
568-
"%s/logs/mute/0_gt_wavs/mute%s.wav|%s/logs/mute/3_feature256/mute.npy|%s/logs/mute/2a_f0/mute.wav.npy|%s/logs/mute/2b-f0nsf/mute.wav.npy|%s"
569-
% (now_dir, sr2, now_dir, now_dir, now_dir, spk_id5)
570-
)
570+
for _ in range(2):
571+
opt.append(
572+
"%s/logs/mute/0_gt_wavs/mute%s.wav|%s/logs/mute/3_feature256/mute.npy|%s/logs/mute/2a_f0/mute.wav.npy|%s/logs/mute/2b-f0nsf/mute.wav.npy|%s"
573+
% (now_dir, sr2, now_dir, now_dir, now_dir, spk_id5)
574+
)
571575
else:
572-
opt.append(
573-
"%s/logs/mute/0_gt_wavs/mute%s.wav|%s/logs/mute/3_feature256/mute.npy|%s"
574-
% (now_dir, sr2, now_dir, spk_id5)
575-
)
576+
for _ in range(2):
577+
opt.append(
578+
"%s/logs/mute/0_gt_wavs/mute%s.wav|%s/logs/mute/3_feature256/mute.npy|%s"
579+
% (now_dir, sr2, now_dir, spk_id5)
580+
)
581+
shuffle(opt)
576582
with open("%s/filelist.txt" % exp_dir, "w") as f:
577583
f.write("\n".join(opt))
578584
print("write filelist done")
@@ -789,15 +795,18 @@ def get_info_str(strr):
789795
)
790796
)
791797
if if_f0_3 == "是":
792-
opt.append(
793-
"%s/logs/mute/0_gt_wavs/mute%s.wav|%s/logs/mute/3_feature256/mute.npy|%s/logs/mute/2a_f0/mute.wav.npy|%s/logs/mute/2b-f0nsf/mute.wav.npy|%s"
794-
% (now_dir, sr2, now_dir, now_dir, now_dir, spk_id5)
795-
)
798+
for _ in range(2):
799+
opt.append(
800+
"%s/logs/mute/0_gt_wavs/mute%s.wav|%s/logs/mute/3_feature256/mute.npy|%s/logs/mute/2a_f0/mute.wav.npy|%s/logs/mute/2b-f0nsf/mute.wav.npy|%s"
801+
% (now_dir, sr2, now_dir, now_dir, now_dir, spk_id5)
802+
)
796803
else:
797-
opt.append(
798-
"%s/logs/mute/0_gt_wavs/mute%s.wav|%s/logs/mute/3_feature256/mute.npy|%s"
799-
% (now_dir, sr2, now_dir, spk_id5)
800-
)
804+
for _ in range(2):
805+
opt.append(
806+
"%s/logs/mute/0_gt_wavs/mute%s.wav|%s/logs/mute/3_feature256/mute.npy|%s"
807+
% (now_dir, sr2, now_dir, spk_id5)
808+
)
809+
shuffle(opt)
801810
with open("%s/filelist.txt" % exp_dir, "w") as f:
802811
f.write("\n".join(opt))
803812
yield get_info_str("write filelist done")
@@ -1039,7 +1048,7 @@ def export_onnx(ModelPath, ExportedPath, MoeVS=True):
10391048
minimum=0,
10401049
maximum=1,
10411050
label="检索特征占比",
1042-
value=0.65,
1051+
value=0.76,
10431052
interactive=True,
10441053
)
10451054
f0_file = gr.File(label=i18n("F0曲线文件, 可选, 一行一个音高, 代替默认F0及升降调"))
@@ -1253,10 +1262,10 @@ def export_onnx(ModelPath, ExportedPath, MoeVS=True):
12531262
)
12541263
batch_size12 = gr.Slider(
12551264
minimum=0,
1256-
maximum=32,
1265+
maximum=40,
12571266
step=1,
12581267
label="每张显卡的batch_size",
1259-
value=4,
1268+
value=default_batch_size,
12601269
interactive=True,
12611270
)
12621271
if_save_latest13 = gr.Radio(
@@ -1270,7 +1279,7 @@ def export_onnx(ModelPath, ExportedPath, MoeVS=True):
12701279
"是否缓存所有训练集至显存. 10min以下小数据可缓存以加速训练, 大数据缓存会炸显存也加不了多少速"
12711280
),
12721281
choices=["是", "否"],
1273-
value="",
1282+
value="",
12741283
interactive=True,
12751284
)
12761285
with gr.Row():

0 commit comments

Comments
 (0)