Skip to content

Commit e7f204b

Browse files
authored
train index: automatically apply k-means when the feature shape is too large
train index: automatically apply k-means when the feature shape is too large
1 parent 75264d0 commit e7f204b

File tree

1 file changed

+47
-12
lines changed

1 file changed

+47
-12
lines changed

infer-web.py

Lines changed: 47 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -35,8 +35,7 @@
3535
from my_utils import load_audio
3636
from train.process_ckpt import change_info, extract_small_model, merge, show_info
3737
from vc_infer_pipeline import VC
38-
39-
# from trainset_preprocess_pipeline import PreProcess
38+
from sklearn.cluster import MiniBatchKMeans
4039

4140
logging.getLogger("numba").setLevel(logging.WARNING)
4241

@@ -653,9 +652,13 @@ def change_sr2(sr2, if_f0_3, version19):
653652
"not exist, will not use pretrained model",
654653
)
655654
return (
656-
"pretrained%s/%sG%s.pth" % (path_str, f0_str, sr2) if if_pretrained_generator_exist else "",
657-
"pretrained%s/%sD%s.pth" % (path_str, f0_str, sr2) if if_pretrained_discriminator_exist else "",
658-
{"visible": True, "__type__": "update"}
655+
"pretrained%s/%sG%s.pth" % (path_str, f0_str, sr2)
656+
if if_pretrained_generator_exist
657+
else "",
658+
"pretrained%s/%sD%s.pth" % (path_str, f0_str, sr2)
659+
if if_pretrained_discriminator_exist
660+
else "",
661+
{"visible": True, "__type__": "update"},
659662
)
660663

661664

@@ -679,8 +682,12 @@ def change_version19(sr2, if_f0_3, version19):
679682
"not exist, will not use pretrained model",
680683
)
681684
return (
682-
"pretrained%s/%sG%s.pth" % (path_str, f0_str, sr2) if if_pretrained_generator_exist else "",
683-
"pretrained%s/%sD%s.pth" % (path_str, f0_str, sr2) if if_pretrained_discriminator_exist else "",
685+
"pretrained%s/%sG%s.pth" % (path_str, f0_str, sr2)
686+
if if_pretrained_generator_exist
687+
else "",
688+
"pretrained%s/%sD%s.pth" % (path_str, f0_str, sr2)
689+
if if_pretrained_discriminator_exist
690+
else "",
684691
)
685692

686693

@@ -714,8 +721,12 @@ def change_f0(if_f0_3, sr2, version19): # f0method8,pretrained_G14,pretrained_D
714721
)
715722
return (
716723
{"visible": False, "__type__": "update"},
717-
"pretrained%s/G%s.pth" % (path_str, sr2) if if_pretrained_generator_exist else "",
718-
"pretrained%s/D%s.pth" % (path_str, sr2) if if_pretrained_discriminator_exist else "",
724+
("pretrained%s/G%s.pth" % (path_str, sr2))
725+
if if_pretrained_generator_exist
726+
else "",
727+
("pretrained%s/D%s.pth" % (path_str, sr2))
728+
if if_pretrained_discriminator_exist
729+
else "",
719730
)
720731

721732

@@ -869,6 +880,7 @@ def train_index(exp_dir1, version19):
869880
listdir_res = list(os.listdir(feature_dir))
870881
if len(listdir_res) == 0:
871882
return "请先进行特征提取!"
883+
infos = []
872884
npys = []
873885
for name in sorted(listdir_res):
874886
phone = np.load("%s/%s" % (feature_dir, name))
@@ -877,10 +889,20 @@ def train_index(exp_dir1, version19):
877889
big_npy_idx = np.arange(big_npy.shape[0])
878890
np.random.shuffle(big_npy_idx)
879891
big_npy = big_npy[big_npy_idx]
892+
# if(big_npy.shape[0]>2e5):
893+
if(1):
894+
infos.append("Trying doing kmeans %s shape to 10k centers."%big_npy.shape[0])
895+
yield "\n".join(infos)
896+
try:
897+
big_npy = MiniBatchKMeans(n_clusters=10000, verbose=True, batch_size=256 * config.n_cpu, compute_labels=False, init="random").fit(big_npy).cluster_centers_
898+
except:
899+
info=traceback.format_exc()
900+
print(info)
901+
infos.append(info)
902+
yield "\n".join(infos)
903+
880904
np.save("%s/total_fea.npy" % exp_dir, big_npy)
881-
# n_ivf = big_npy.shape[0] // 39
882905
n_ivf = min(int(16 * np.sqrt(big_npy.shape[0])), big_npy.shape[0] // 39)
883-
infos = []
884906
infos.append("%s,%s" % (big_npy.shape, n_ivf))
885907
yield "\n".join(infos)
886908
index = faiss.index_factory(256 if version19 == "v1" else 768, "IVF%s,Flat" % n_ivf)
@@ -1120,6 +1142,19 @@ def get_info_str(strr):
11201142
big_npy_idx = np.arange(big_npy.shape[0])
11211143
np.random.shuffle(big_npy_idx)
11221144
big_npy = big_npy[big_npy_idx]
1145+
1146+
# if(big_npy.shape[0]>2e5):
1147+
if(1):
1148+
info="Trying doing kmeans %s shape to 10k centers."%big_npy.shape[0]
1149+
print(info)
1150+
yield get_info_str(info)
1151+
try:
1152+
big_npy = MiniBatchKMeans(n_clusters=10000, verbose=True, batch_size=256 * config.n_cpu, compute_labels=False, init="random").fit(big_npy).cluster_centers_
1153+
except:
1154+
info=traceback.format_exc()
1155+
print(info)
1156+
yield get_info_str(info)
1157+
11231158
np.save("%s/total_fea.npy" % model_log_dir, big_npy)
11241159

11251160
# n_ivf = big_npy.shape[0] // 39
@@ -1565,7 +1600,7 @@ def export_onnx(ModelPath, ExportedPath):
15651600
maximum=config.n_cpu,
15661601
step=1,
15671602
label=i18n("提取音高和处理数据使用的CPU进程数"),
1568-
value=config.n_cpu,
1603+
value=int(np.ceil(config.n_cpu/1.5)),
15691604
interactive=True,
15701605
)
15711606
with gr.Group(): # 暂时单人的, 后面支持最多4人的#数据处理

0 commit comments

Comments
 (0)