
Commit f349adc

Add support for training without specifying a pretrained model, add support for selecting v2 48k as a training setting, and automatically clear the pretrained model path when the user does not have a pretrained model in the designated folder. (#528)
* Support detection of pretrained models; support training without a pretrained model path in the web UI
1 parent eb1a88c commit f349adc

File tree

2 files changed, +62 −42 lines changed

* infer-web.py
* train_nsf_sim_cache_sid_load_pretrain.py

infer-web.py

Lines changed: 45 additions & 31 deletions
@@ -614,43 +614,53 @@ def extract_f0_feature(gpus, n_p, f0method, if_f0, exp_dir, version19):


 def change_sr2(sr2, if_f0_3, version19):
-    vis_v = True if sr2 == "40k" else False
-    if sr2 != "40k":
-        version19 = "v1"
     path_str = "" if version19 == "v1" else "_v2"
-    version_state = {"visible": vis_v, "__type__": "update"}
-    if vis_v == False:
-        version_state["value"] = "v1"
     f0_str = "f0" if if_f0_3 else ""
+    if_pretrained_generator_exist = os.access("pretrained%s/%sG%s.pth" % (path_str, f0_str, sr2), os.F_OK)
+    if_pretrained_discriminator_exist = os.access("pretrained%s/%sD%s.pth" % (path_str, f0_str, sr2), os.F_OK)
+    if (if_pretrained_generator_exist == False):
+        print("pretrained%s/%sG%s.pth" % (path_str, f0_str, sr2), "not exist, will not use pretrained model")
+    if (if_pretrained_discriminator_exist == False):
+        print("pretrained%s/%sD%s.pth" % (path_str, f0_str, sr2), "not exist, will not use pretrained model")
     return (
-        "pretrained%s/%sG%s.pth" % (path_str, f0_str, sr2),
-        "pretrained%s/%sD%s.pth" % (path_str, f0_str, sr2),
-        version_state,
+        ("pretrained%s/%sG%s.pth" % (path_str, f0_str, sr2)) if if_pretrained_generator_exist else "",
+        ("pretrained%s/%sD%s.pth" % (path_str, f0_str, sr2)) if if_pretrained_discriminator_exist else "",
+        {"visible": True, "__type__": "update"}
     )

-
 def change_version19(sr2, if_f0_3, version19):
     path_str = "" if version19 == "v1" else "_v2"
     f0_str = "f0" if if_f0_3 else ""
-    return "pretrained%s/%sG%s.pth" % (
-        path_str,
-        f0_str,
-        sr2,
-    ), "pretrained%s/%sD%s.pth" % (path_str, f0_str, sr2)
+    if_pretrained_generator_exist = os.access("pretrained%s/%sG%s.pth" % (path_str, f0_str, sr2), os.F_OK)
+    if_pretrained_discriminator_exist = os.access("pretrained%s/%sD%s.pth" % (path_str, f0_str, sr2), os.F_OK)
+    if (if_pretrained_generator_exist == False):
+        print("pretrained%s/%sG%s.pth" % (path_str, f0_str, sr2), "not exist, will not use pretrained model")
+    if (if_pretrained_discriminator_exist == False):
+        print("pretrained%s/%sD%s.pth" % (path_str, f0_str, sr2), "not exist, will not use pretrained model")
+    return (
+        ("pretrained%s/%sG%s.pth" % (path_str, f0_str, sr2)) if if_pretrained_generator_exist else "",
+        ("pretrained%s/%sD%s.pth" % (path_str, f0_str, sr2)) if if_pretrained_discriminator_exist else "",
+    )


 def change_f0(if_f0_3, sr2, version19): # f0method8,pretrained_G14,pretrained_D15
     path_str = "" if version19 == "v1" else "_v2"
+    if_pretrained_generator_exist = os.access("pretrained%s/f0G%s.pth" % (path_str, sr2), os.F_OK)
+    if_pretrained_discriminator_exist = os.access("pretrained%s/f0D%s.pth" % (path_str, sr2), os.F_OK)
+    if (if_pretrained_generator_exist == False):
+        print("pretrained%s/f0G%s.pth" % (path_str, sr2), "not exist, will not use pretrained model")
+    if (if_pretrained_discriminator_exist == False):
+        print("pretrained%s/f0D%s.pth" % (path_str, sr2), "not exist, will not use pretrained model")
     if if_f0_3:
         return (
             {"visible": True, "__type__": "update"},
-            "pretrained%s/f0G%s.pth" % (path_str, sr2),
-            "pretrained%s/f0D%s.pth" % (path_str, sr2),
+            "pretrained%s/f0G%s.pth" % (path_str, sr2) if if_pretrained_generator_exist else "",
+            "pretrained%s/f0D%s.pth" % (path_str, sr2) if if_pretrained_discriminator_exist else "",
         )
     return (
         {"visible": False, "__type__": "update"},
-        "pretrained%s/G%s.pth" % (path_str, sr2),
-        "pretrained%s/D%s.pth" % (path_str, sr2),
+        ("pretrained%s/G%s.pth" % (path_str, sr2)) if if_pretrained_generator_exist else "",
+        ("pretrained%s/D%s.pth" % (path_str, sr2)) if if_pretrained_discriminator_exist else "",
     )

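All three web-UI callbacks in this hunk repeat the same idea: probe the expected checkpoint with os.access and hand back an empty path when it is missing, so the UI fields are cleared instead of pointing at a non-existent file. A minimal sketch of that shared pattern; the helper name get_pretrained_path and the example values are hypothetical and not part of this commit:

```python
import os

def get_pretrained_path(path_str: str, f0_str: str, sr2: str, kind: str) -> str:
    """Return the pretrained checkpoint path, or "" when the file is absent.

    kind is "G" (generator) or "D" (discriminator); path_str is "" for v1
    and "_v2" for v2, mirroring the naming used in infer-web.py.
    """
    path = "pretrained%s/%s%s%s.pth" % (path_str, f0_str, kind, sr2)
    if not os.access(path, os.F_OK):
        # Same behaviour as the commit: warn and fall back to training from scratch.
        print(path, "not exist, will not use pretrained model")
        return ""
    return path

# Example: a v2, f0-enabled, 48k setup would look for pretrained_v2/f0G48k.pth.
pretrained_G = get_pretrained_path("_v2", "f0", "48k", "G")
pretrained_D = get_pretrained_path("_v2", "f0", "48k", "D")
```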
@@ -741,10 +751,14 @@ def click_train(
     # 生成config#无需生成config
     # cmd = python_cmd + " train_nsf_sim_cache_sid_load_pretrain.py -e mi-test -sr 40k -f0 1 -bs 4 -g 0 -te 10 -se 5 -pg pretrained/f0G40k.pth -pd pretrained/f0D40k.pth -l 1 -c 0"
     print("use gpus:", gpus16)
+    if pretrained_G14 == "":
+        print("no pretrained Generator")
+    if pretrained_D15 == "":
+        print("no pretrained Discriminator")
     if gpus16:
         cmd = (
             config.python_cmd
-            + " train_nsf_sim_cache_sid_load_pretrain.py -e %s -sr %s -f0 %s -bs %s -g %s -te %s -se %s -pg %s -pd %s -l %s -c %s -sw %s -v %s"
+            + " train_nsf_sim_cache_sid_load_pretrain.py -e %s -sr %s -f0 %s -bs %s -g %s -te %s -se %s %s %s -l %s -c %s -sw %s -v %s"
             % (
                 exp_dir1,
                 sr2,
@@ -753,8 +767,8 @@ def click_train(
                 gpus16,
                 total_epoch11,
                 save_epoch10,
-                pretrained_G14,
-                pretrained_D15,
+                ("-pg %s" % pretrained_G14) if pretrained_G14 != "" else "",
+                ("-pd %s" % pretrained_D15) if pretrained_D15 != "" else "",
                 1 if if_save_latest13 == i18n("是") else 0,
                 1 if if_cache_gpu17 == i18n("是") else 0,
                 1 if if_save_every_weights18 == i18n("是") else 0,
@@ -764,16 +778,16 @@ def click_train(
     else:
         cmd = (
             config.python_cmd
-            + " train_nsf_sim_cache_sid_load_pretrain.py -e %s -sr %s -f0 %s -bs %s -te %s -se %s -pg %s -pd %s -l %s -c %s -sw %s -v %s"
+            + " train_nsf_sim_cache_sid_load_pretrain.py -e %s -sr %s -f0 %s -bs %s -te %s -se %s %s %s -l %s -c %s -sw %s -v %s"
             % (
                 exp_dir1,
                 sr2,
                 1 if if_f0_3 else 0,
                 batch_size12,
                 total_epoch11,
                 save_epoch10,
-                pretrained_G14,
-                pretrained_D15,
+                ("-pg %s" % pretrained_G14) if pretrained_G14 != "" else "\b",
+                ("-pd %s" % pretrained_D15) if pretrained_D15 != "" else "\b",
                 1 if if_save_latest13 == i18n("是") else 0,
                 1 if if_cache_gpu17 == i18n("是") else 0,
                 1 if if_save_every_weights18 == i18n("是") else 0,
@@ -1000,7 +1014,7 @@ def get_info_str(strr):
     if gpus16:
         cmd = (
             config.python_cmd
-            + " train_nsf_sim_cache_sid_load_pretrain.py -e %s -sr %s -f0 %s -bs %s -g %s -te %s -se %s -pg %s -pd %s -l %s -c %s -sw %s -v %s"
+            +" train_nsf_sim_cache_sid_load_pretrain.py -e %s -sr %s -f0 %s -bs %s -g %s -te %s -se %s %s %s -l %s -c %s -sw %s -v %s"
             % (
                 exp_dir1,
                 sr2,
@@ -1009,8 +1023,8 @@ def get_info_str(strr):
                 gpus16,
                 total_epoch11,
                 save_epoch10,
-                pretrained_G14,
-                pretrained_D15,
+                ("-pg %s" % pretrained_G14) if pretrained_G14 != "" else "",
+                ("-pd %s" % pretrained_D15) if pretrained_D15 != "" else "",
                 1 if if_save_latest13 == i18n("是") else 0,
                 1 if if_cache_gpu17 == i18n("是") else 0,
                 1 if if_save_every_weights18 == i18n("是") else 0,
@@ -1020,16 +1034,16 @@ def get_info_str(strr):
     else:
         cmd = (
             config.python_cmd
-            + " train_nsf_sim_cache_sid_load_pretrain.py -e %s -sr %s -f0 %s -bs %s -te %s -se %s -pg %s -pd %s -l %s -c %s -sw %s -v %s"
+            + " train_nsf_sim_cache_sid_load_pretrain.py -e %s -sr %s -f0 %s -bs %s -te %s -se %s %s %s -l %s -c %s -sw %s -v %s"
            % (
                 exp_dir1,
                 sr2,
                 1 if if_f0_3 else 0,
                 batch_size12,
                 total_epoch11,
                 save_epoch10,
-                pretrained_G14,
-                pretrained_D15,
+                ("-pg %s" % pretrained_G14) if pretrained_G14 != "" else "",
+                ("-pd %s" % pretrained_D15) if pretrained_D15 != "" else "",
                 1 if if_save_latest13 == i18n("是") else 0,
                 1 if if_cache_gpu17 == i18n("是") else 0,
                 1 if if_save_every_weights18 == i18n("是") else 0,

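The click_train and one-click-training changes above fold the -pg/-pd flags into the interpolated arguments so they can be dropped from the command when no pretrained path is set (note the odd "\b" fallback in the CPU branch, kept verbatim from the patch). A rough sketch of how the command string then comes together; the variable values below are illustrative stand-ins, not the actual Gradio inputs:

```python
# Illustrative stand-ins for the values that click_train receives from the UI.
python_cmd = "python"
exp_dir1, sr2, if_f0_3 = "mi-test", "48k", True
batch_size12, total_epoch11, save_epoch10 = 4, 10, 5
pretrained_G14 = ""  # cleared by change_sr2() because the checkpoint was missing
pretrained_D15 = ""

cmd = (
    python_cmd
    + " train_nsf_sim_cache_sid_load_pretrain.py -e %s -sr %s -f0 %s -bs %s -te %s -se %s %s %s -l 1 -c 0 -sw 0 -v v2"
    % (
        exp_dir1,
        sr2,
        1 if if_f0_3 else 0,
        batch_size12,
        total_epoch11,
        save_epoch10,
        ("-pg %s" % pretrained_G14) if pretrained_G14 != "" else "",
        ("-pd %s" % pretrained_D15) if pretrained_D15 != "" else "",
    )
)
print(cmd)
# python train_nsf_sim_cache_sid_load_pretrain.py -e mi-test -sr 48k -f0 1 -bs 4 -te 10 -se 5   -l 1 -c 0 -sw 0 -v v2
```

With both checkpoints missing the command simply carries two empty arguments, so -pg/-pd never reach the training script's argument parser; with a valid checkpoint the usual "-pg pretrained_v2/f0G48k.pth" fragment reappears.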
train_nsf_sim_cache_sid_load_pretrain.py

Lines changed: 17 additions & 11 deletions
@@ -191,18 +191,24 @@ def run(rank, n_gpus, hps):
         # traceback.print_exc()
         epoch_str = 1
         global_step = 0
-        if rank == 0:
-            logger.info("loaded pretrained %s %s" % (hps.pretrainG, hps.pretrainD))
-        print(
-            net_g.module.load_state_dict(
-                torch.load(hps.pretrainG, map_location="cpu")["model"]
-            )
-        ) ##测试不加载优化器
-        print(
-            net_d.module.load_state_dict(
-                torch.load(hps.pretrainD, map_location="cpu")["model"]
+        if hps.pretrainG != "":
+
+            if rank == 0:
+                logger.info("loaded pretrained %s" % (hps.pretrainG))
+            print(
+                net_g.module.load_state_dict(
+                    torch.load(hps.pretrainG, map_location="cpu")["model"]
+                )
+            ) ##测试不加载优化器
+        if hps.pretrainD != "":
+
+            if rank == 0:
+                logger.info("loaded pretrained %s" % (hps.pretrainD))
+            print(
+                net_d.module.load_state_dict(
+                    torch.load(hps.pretrainD, map_location="cpu")["model"]
+                )
             )
-        )

         scheduler_g = torch.optim.lr_scheduler.ExponentialLR(
             optim_g, gamma=hps.train.lr_decay, last_epoch=epoch_str - 2

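On the training-script side the guard simply skips load_state_dict when the corresponding path is empty, so the generator and discriminator keep their fresh initialization instead of crashing on a missing file. A condensed, hedged sketch of that logic; maybe_load_pretrained and the stand-in net/hps objects are illustrative, not the repository's actual API:

```python
import torch

def maybe_load_pretrained(net, pretrain_path: str, logger=None) -> None:
    """Load pretrained weights into net only when a path was supplied.

    An empty string means "train from scratch": the call is skipped,
    mirroring the behaviour introduced by this commit.
    """
    if pretrain_path == "":
        return
    if logger is not None:
        logger.info("loaded pretrained %s" % pretrain_path)
    # The checkpoint stores the weights under the "model" key, as in the
    # upstream script; printing the result echoes the load status, as the
    # original code does.
    ckpt = torch.load(pretrain_path, map_location="cpu")
    print(net.load_state_dict(ckpt["model"]))

# Usage inside run(), where the paths come from the now-optional -pg/-pd flags:
# maybe_load_pretrained(net_g.module, hps.pretrainG, logger)
# maybe_load_pretrained(net_d.module, hps.pretrainD, logger)
```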