Skip to content

Commit 9253948

Browse files
authored
Rewrite syntax of infer-web.py (#536)
* Fix import location * use any * Correction of if Syntax * Class definitions to the front * format * fix if Syntax
1 parent c5758a8 commit 9253948

File tree

1 file changed

+148
-92
lines changed

1 file changed

+148
-92
lines changed

infer-web.py

Lines changed: 148 additions & 92 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,42 @@
1-
import torch, os, traceback, sys, warnings, shutil, numpy as np
1+
import os
2+
import shutil
3+
import sys
4+
import traceback
5+
import warnings
6+
7+
import numpy as np
8+
import torch
29

310
os.environ["no_proxy"] = "localhost, 127.0.0.1, ::1"
11+
import logging
412
import threading
5-
from time import sleep
13+
from random import shuffle
614
from subprocess import Popen
15+
from time import sleep
16+
717
import faiss
8-
from random import shuffle
18+
import ffmpeg
19+
import gradio as gr
20+
import soundfile as sf
21+
from config import Config
22+
from fairseq import checkpoint_utils
23+
from i18n import I18nAuto
24+
from infer_pack.models import (
25+
SynthesizerTrnMs256NSFsid,
26+
SynthesizerTrnMs256NSFsid_nono,
27+
SynthesizerTrnMs768NSFsid,
28+
SynthesizerTrnMs768NSFsid_nono,
29+
)
30+
from infer_pack.models_onnx import SynthesizerTrnMsNSFsidM
31+
from infer_uvr5 import _audio_pre_, _audio_pre_new
32+
from MDXNet import MDXNetDereverb
33+
from my_utils import load_audio
34+
from train.process_ckpt import change_info, extract_small_model, merge, show_info
35+
from vc_infer_pipeline import VC
36+
37+
# from trainset_preprocess_pipeline import PreProcess
38+
39+
logging.getLogger("numba").setLevel(logging.WARNING)
940

1041
now_dir = os.getcwd()
1142
sys.path.append(now_dir)
@@ -19,41 +50,43 @@
1950
os.environ["TEMP"] = tmp
2051
warnings.filterwarnings("ignore")
2152
torch.manual_seed(114514)
22-
from i18n import I18nAuto
23-
import ffmpeg
24-
from MDXNet import MDXNetDereverb
2553

54+
55+
config = Config()
2656
i18n = I18nAuto()
2757
i18n.print()
2858
# 判断是否有能用来训练和加速推理的N卡
2959
ngpu = torch.cuda.device_count()
3060
gpu_infos = []
3161
mem = []
32-
if (not torch.cuda.is_available()) or ngpu == 0:
33-
if_gpu_ok = False
34-
else:
35-
if_gpu_ok = False
62+
if_gpu_ok = False
63+
64+
if torch.cuda.is_available() or ngpu != 0:
3665
for i in range(ngpu):
3766
gpu_name = torch.cuda.get_device_name(i)
38-
if (
39-
"10" in gpu_name
40-
or "16" in gpu_name
41-
or "20" in gpu_name
42-
or "30" in gpu_name
43-
or "40" in gpu_name
44-
or "A2" in gpu_name.upper()
45-
or "A3" in gpu_name.upper()
46-
or "A4" in gpu_name.upper()
47-
or "P4" in gpu_name.upper()
48-
or "A50" in gpu_name.upper()
49-
or "A60" in gpu_name.upper()
50-
or "70" in gpu_name
51-
or "80" in gpu_name
52-
or "90" in gpu_name
53-
or "M4" in gpu_name.upper()
54-
or "T4" in gpu_name.upper()
55-
or "TITAN" in gpu_name.upper()
56-
): # A10#A100#V100#A40#P40#M40#K80#A4500
67+
if any(
68+
value in gpu_name.upper()
69+
for value in [
70+
"10",
71+
"16",
72+
"20",
73+
"30",
74+
"40",
75+
"A2",
76+
"A3",
77+
"A4",
78+
"P4",
79+
"A50",
80+
"A60",
81+
"70",
82+
"80",
83+
"90",
84+
"M4",
85+
"T4",
86+
"TITAN",
87+
]
88+
):
89+
# A10#A100#V100#A40#P40#M40#K80#A4500
5790
if_gpu_ok = True # 至少有一张能用的N卡
5891
gpu_infos.append("%s\t%s" % (i, gpu_name))
5992
mem.append(
@@ -65,32 +98,13 @@
6598
+ 0.4
6699
)
67100
)
68-
if if_gpu_ok == True and len(gpu_infos) > 0:
101+
if if_gpu_ok and len(gpu_infos) > 0:
69102
gpu_info = "\n".join(gpu_infos)
70103
default_batch_size = min(mem) // 2
71104
else:
72105
gpu_info = i18n("很遗憾您这没有能用的显卡来支持您训练")
73106
default_batch_size = 1
74107
gpus = "-".join([i[0] for i in gpu_infos])
75-
from infer_pack.models import (
76-
SynthesizerTrnMs256NSFsid,
77-
SynthesizerTrnMs256NSFsid_nono,
78-
SynthesizerTrnMs768NSFsid,
79-
SynthesizerTrnMs768NSFsid_nono,
80-
)
81-
import soundfile as sf
82-
from fairseq import checkpoint_utils
83-
import gradio as gr
84-
import logging
85-
from vc_infer_pipeline import VC
86-
from config import Config
87-
from infer_uvr5 import _audio_pre_, _audio_pre_new
88-
from my_utils import load_audio
89-
from train.process_ckpt import show_info, change_info, merge, extract_small_model
90-
91-
config = Config()
92-
# from trainset_preprocess_pipeline import PreProcess
93-
logging.getLogger("numba").setLevel(logging.WARNING)
94108

95109

96110
class ToolButton(gr.Button, gr.components.FormComponent):
@@ -164,7 +178,7 @@ def vc_single(
164178
if audio_max > 1:
165179
audio /= audio_max
166180
times = [0, 0, 0]
167-
if hubert_model == None:
181+
if not hubert_model:
168182
load_hubert()
169183
if_f0 = cpt.get("f0", 1)
170184
file_index = (
@@ -203,7 +217,7 @@ def vc_single(
203217
protect,
204218
f0_file=f0_file,
205219
)
206-
if resample_sr >= 16000 and tgt_sr != resample_sr:
220+
if tgt_sr != resample_sr >= 16000:
207221
tgt_sr = resample_sr
208222
index_info = (
209223
"Using index:%s." % file_index
@@ -385,7 +399,7 @@ def get_vc(sid):
385399
global n_spk, tgt_sr, net_g, vc, cpt, version
386400
if sid == "" or sid == []:
387401
global hubert_model
388-
if hubert_model != None: # 考虑到轮询, 需要加个判断看是否 sid 是由有模型切换到无模型的
402+
if hubert_model is not None: # 考虑到轮询, 需要加个判断看是否 sid 是由有模型切换到无模型的
389403
print("clean_empty_cache")
390404
del net_g, n_spk, vc, hubert_model, tgt_sr # ,cpt
391405
hubert_model = net_g = n_spk = vc = hubert_model = tgt_sr = None
@@ -471,7 +485,7 @@ def clean():
471485

472486
def if_done(done, p):
473487
while 1:
474-
if p.poll() == None:
488+
if p.poll() is None:
475489
sleep(0.5)
476490
else:
477491
break
@@ -484,7 +498,7 @@ def if_done_multi(done, ps):
484498
# 只要有一个进程未结束都不停
485499
flag = 1
486500
for p in ps:
487-
if p.poll() == None:
501+
if p.poll() is None:
488502
flag = 0
489503
sleep(0.5)
490504
break
@@ -519,7 +533,7 @@ def preprocess_dataset(trainset_dir, exp_dir, sr, n_p):
519533
with open("%s/logs/%s/preprocess.log" % (now_dir, exp_dir), "r") as f:
520534
yield (f.read())
521535
sleep(1)
522-
if done[0] == True:
536+
if done[0]:
523537
break
524538
with open("%s/logs/%s/preprocess.log" % (now_dir, exp_dir), "r") as f:
525539
log = f.read()
@@ -557,7 +571,7 @@ def extract_f0_feature(gpus, n_p, f0method, if_f0, exp_dir, version19):
557571
) as f:
558572
yield (f.read())
559573
sleep(1)
560-
if done[0] == True:
574+
if done[0]:
561575
break
562576
with open("%s/logs/%s/extract_f0_feature.log" % (now_dir, exp_dir), "r") as f:
563577
log = f.read()
@@ -605,7 +619,7 @@ def extract_f0_feature(gpus, n_p, f0method, if_f0, exp_dir, version19):
605619
with open("%s/logs/%s/extract_f0_feature.log" % (now_dir, exp_dir), "r") as f:
606620
yield (f.read())
607621
sleep(1)
608-
if done[0] == True:
622+
if done[0]:
609623
break
610624
with open("%s/logs/%s/extract_f0_feature.log" % (now_dir, exp_dir), "r") as f:
611625
log = f.read()
@@ -616,51 +630,98 @@ def extract_f0_feature(gpus, n_p, f0method, if_f0, exp_dir, version19):
616630
def change_sr2(sr2, if_f0_3, version19):
617631
path_str = "" if version19 == "v1" else "_v2"
618632
f0_str = "f0" if if_f0_3 else ""
619-
if_pretrained_generator_exist = os.access("pretrained%s/%sG%s.pth" % (path_str, f0_str, sr2), os.F_OK)
620-
if_pretrained_discriminator_exist = os.access("pretrained%s/%sD%s.pth" % (path_str, f0_str, sr2), os.F_OK)
621-
if (if_pretrained_generator_exist == False):
622-
print("pretrained%s/%sG%s.pth" % (path_str, f0_str, sr2), "not exist, will not use pretrained model")
623-
if (if_pretrained_discriminator_exist == False):
624-
print("pretrained%s/%sD%s.pth" % (path_str, f0_str, sr2), "not exist, will not use pretrained model")
633+
if_pretrained_generator_exist = os.access(
634+
"pretrained%s/%sG%s.pth" % (path_str, f0_str, sr2), os.F_OK
635+
)
636+
if_pretrained_discriminator_exist = os.access(
637+
"pretrained%s/%sD%s.pth" % (path_str, f0_str, sr2), os.F_OK
638+
)
639+
if if_pretrained_generator_exist is not False:
640+
print(
641+
"pretrained%s/%sG%s.pth" % (path_str, f0_str, sr2),
642+
"not exist, will not use pretrained model",
643+
)
644+
if if_pretrained_discriminator_exist is not False:
645+
print(
646+
"pretrained%s/%sD%s.pth" % (path_str, f0_str, sr2),
647+
"not exist, will not use pretrained model",
648+
)
625649
return (
626-
("pretrained%s/%sG%s.pth" % (path_str, f0_str, sr2)) if if_pretrained_generator_exist else "",
627-
("pretrained%s/%sD%s.pth" % (path_str, f0_str, sr2)) if if_pretrained_discriminator_exist else "",
628-
{"visible": True, "__type__": "update"}
650+
"pretrained%s/%sG%s.pth" % (path_str, f0_str, sr2)
651+
if if_pretrained_generator_exist
652+
else "",
653+
"pretrained%s/%sD%s.pth" % (path_str, f0_str, sr2)
654+
if if_pretrained_discriminator_exist
655+
else "",
656+
{"visible": True, "__type__": "update"},
629657
)
630658

659+
631660
def change_version19(sr2, if_f0_3, version19):
632661
path_str = "" if version19 == "v1" else "_v2"
633662
f0_str = "f0" if if_f0_3 else ""
634-
if_pretrained_generator_exist = os.access("pretrained%s/%sG%s.pth" % (path_str, f0_str, sr2), os.F_OK)
635-
if_pretrained_discriminator_exist = os.access("pretrained%s/%sD%s.pth" % (path_str, f0_str, sr2), os.F_OK)
636-
if (if_pretrained_generator_exist == False):
637-
print("pretrained%s/%sG%s.pth" % (path_str, f0_str, sr2), "not exist, will not use pretrained model")
638-
if (if_pretrained_discriminator_exist == False):
639-
print("pretrained%s/%sD%s.pth" % (path_str, f0_str, sr2), "not exist, will not use pretrained model")
663+
if_pretrained_generator_exist = os.access(
664+
"pretrained%s/%sG%s.pth" % (path_str, f0_str, sr2), os.F_OK
665+
)
666+
if_pretrained_discriminator_exist = os.access(
667+
"pretrained%s/%sD%s.pth" % (path_str, f0_str, sr2), os.F_OK
668+
)
669+
if not if_pretrained_generator_exist:
670+
print(
671+
"pretrained%s/%sG%s.pth" % (path_str, f0_str, sr2),
672+
"not exist, will not use pretrained model",
673+
)
674+
if not if_pretrained_discriminator_exist:
675+
print(
676+
"pretrained%s/%sD%s.pth" % (path_str, f0_str, sr2),
677+
"not exist, will not use pretrained model",
678+
)
640679
return (
641-
("pretrained%s/%sG%s.pth" % (path_str, f0_str, sr2)) if if_pretrained_generator_exist else "",
642-
("pretrained%s/%sD%s.pth" % (path_str, f0_str, sr2)) if if_pretrained_discriminator_exist else "",
680+
"pretrained%s/%sG%s.pth" % (path_str, f0_str, sr2)
681+
if if_pretrained_generator_exist
682+
else "",
683+
"pretrained%s/%sD%s.pth" % (path_str, f0_str, sr2)
684+
if if_pretrained_discriminator_exist
685+
else "",
643686
)
644687

645688

646689
def change_f0(if_f0_3, sr2, version19): # f0method8,pretrained_G14,pretrained_D15
647690
path_str = "" if version19 == "v1" else "_v2"
648-
if_pretrained_generator_exist = os.access("pretrained%s/f0G%s.pth" % (path_str, sr2), os.F_OK)
649-
if_pretrained_discriminator_exist = os.access("pretrained%s/f0D%s.pth" % (path_str, sr2), os.F_OK)
650-
if (if_pretrained_generator_exist == False):
651-
print("pretrained%s/f0G%s.pth" % (path_str, sr2), "not exist, will not use pretrained model")
652-
if (if_pretrained_discriminator_exist == False):
653-
print("pretrained%s/f0D%s.pth" % (path_str, sr2), "not exist, will not use pretrained model")
691+
if_pretrained_generator_exist = os.access(
692+
"pretrained%s/f0G%s.pth" % (path_str, sr2), os.F_OK
693+
)
694+
if_pretrained_discriminator_exist = os.access(
695+
"pretrained%s/f0D%s.pth" % (path_str, sr2), os.F_OK
696+
)
697+
if not if_pretrained_generator_exist:
698+
print(
699+
"pretrained%s/f0G%s.pth" % (path_str, sr2),
700+
"not exist, will not use pretrained model",
701+
)
702+
if not if_pretrained_discriminator_exist:
703+
print(
704+
"pretrained%s/f0D%s.pth" % (path_str, sr2),
705+
"not exist, will not use pretrained model",
706+
)
654707
if if_f0_3:
655708
return (
656709
{"visible": True, "__type__": "update"},
657-
"pretrained%s/f0G%s.pth" % (path_str, sr2) if if_pretrained_generator_exist else "",
658-
"pretrained%s/f0D%s.pth" % (path_str, sr2) if if_pretrained_discriminator_exist else "",
710+
"pretrained%s/f0G%s.pth" % (path_str, sr2)
711+
if if_pretrained_generator_exist
712+
else "",
713+
"pretrained%s/f0D%s.pth" % (path_str, sr2)
714+
if if_pretrained_discriminator_exist
715+
else "",
659716
)
660717
return (
661718
{"visible": False, "__type__": "update"},
662-
("pretrained%s/G%s.pth" % (path_str, sr2)) if if_pretrained_generator_exist else "",
663-
("pretrained%s/D%s.pth" % (path_str, sr2)) if if_pretrained_discriminator_exist else "",
719+
("pretrained%s/G%s.pth" % (path_str, sr2))
720+
if if_pretrained_generator_exist
721+
else "",
722+
("pretrained%s/D%s.pth" % (path_str, sr2))
723+
if if_pretrained_discriminator_exist
724+
else "",
664725
)
665726

666727

@@ -809,7 +870,7 @@ def train_index(exp_dir1, version19):
809870
if version19 == "v1"
810871
else "%s/3_feature768" % (exp_dir)
811872
)
812-
if os.path.exists(feature_dir) == False:
873+
if not os.path.exists(feature_dir):
813874
return "请先进行特征提取!"
814875
listdir_res = list(os.listdir(feature_dir))
815876
if len(listdir_res) == 0:
@@ -1014,7 +1075,7 @@ def get_info_str(strr):
10141075
if gpus16:
10151076
cmd = (
10161077
config.python_cmd
1017-
+" train_nsf_sim_cache_sid_load_pretrain.py -e %s -sr %s -f0 %s -bs %s -g %s -te %s -se %s %s %s -l %s -c %s -sw %s -v %s"
1078+
+ " train_nsf_sim_cache_sid_load_pretrain.py -e %s -sr %s -f0 %s -bs %s -g %s -te %s -se %s %s %s -l %s -c %s -sw %s -v %s"
10181079
% (
10191080
exp_dir1,
10201081
sr2,
@@ -1098,10 +1159,7 @@ def get_info_str(strr):
10981159

10991160
# ckpt_path2.change(change_info_,[ckpt_path2],[sr__,if_f0__])
11001161
def change_info_(ckpt_path):
1101-
if (
1102-
os.path.exists(ckpt_path.replace(os.path.basename(ckpt_path), "train.log"))
1103-
== False
1104-
):
1162+
if not os.path.exists(ckpt_path.replace(os.path.basename(ckpt_path), "train.log")):
11051163
return {"__type__": "update"}, {"__type__": "update"}, {"__type__": "update"}
11061164
try:
11071165
with open(
@@ -1116,8 +1174,6 @@ def change_info_(ckpt_path):
11161174
return {"__type__": "update"}, {"__type__": "update"}, {"__type__": "update"}
11171175

11181176

1119-
from infer_pack.models_onnx import SynthesizerTrnMsNSFsidM
1120-
11211177

11221178
def export_onnx(ModelPath, ExportedPath):
11231179
cpt = torch.load(ModelPath, map_location="cpu")

0 commit comments

Comments
 (0)