Skip to content

Commit 8dc195b

Browse files
committed
fix
1 parent f27a991 commit 8dc195b

File tree

2 files changed

+29
-16
lines changed

2 files changed

+29
-16
lines changed

infer-web.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -305,7 +305,11 @@ def click_train(exp_dir1,sr2,if_f0_3,spk_id5,save_epoch10,total_epoch11,batch_si
305305
print("write filelist done")
306306
#生成config#无需生成config
307307
# cmd = python_cmd + " train_nsf_sim_cache_sid_load_pretrain.py -e mi-test -sr 40k -f0 1 -bs 4 -g 0 -te 10 -se 5 -pg pretrained/f0G40k.pth -pd pretrained/f0D40k.pth -l 1 -c 0"
308-
cmd = python_cmd + " train_nsf_sim_cache_sid_load_pretrain.py -e %s -sr %s -f0 %s -bs %s -g %s -te %s -se %s -pg %s -pd %s -l %s -c %s" % (exp_dir1,sr2,1 if if_f0_3=="是"else 0,batch_size12,gpus16,total_epoch11,save_epoch10,pretrained_G14,pretrained_D15,1 if if_save_latest13=="是"else 0,1 if if_cache_gpu17=="是"else 0)
308+
print("use gpus:",gpus16)
309+
if gpus16:
310+
cmd = python_cmd + " train_nsf_sim_cache_sid_load_pretrain.py -e %s -sr %s -f0 %s -bs %s -g %s -te %s -se %s -pg %s -pd %s -l %s -c %s" % (exp_dir1,sr2,1 if if_f0_3=="是"else 0,batch_size12,gpus16,total_epoch11,save_epoch10,pretrained_G14,pretrained_D15,1 if if_save_latest13=="是"else 0,1 if if_cache_gpu17=="是"else 0)
311+
else:
312+
cmd = python_cmd + " train_nsf_sim_cache_sid_load_pretrain.py -e %s -sr %s -f0 %s -bs %s -te %s -se %s -pg %s -pd %s -l %s -c %s" % (exp_dir1,sr2,1 if if_f0_3=="是"else 0,batch_size12,total_epoch11,save_epoch10,pretrained_G14,pretrained_D15,1 if if_save_latest13=="是"else 0,1 if if_cache_gpu17=="是"else 0)
309313
print(cmd)
310314
p = Popen(cmd, shell=True, cwd=now_dir)
311315
p.wait()
@@ -398,7 +402,10 @@ def get_info_str(strr):
398402
opt.append("%s/%s.wav|%s/%s.npy|%s"%(gt_wavs_dir.replace("\\","\\\\"),name,co256_dir.replace("\\","\\\\"),name,spk_id5))
399403
with open("%s/filelist.txt"%exp_dir,"w")as f:f.write("\n".join(opt))
400404
yield get_info_str("write filelist done")
401-
cmd = python_cmd + " train_nsf_sim_cache_sid_load_pretrain.py -e %s -sr %s -f0 %s -bs %s -g %s -te %s -se %s -pg %s -pd %s -l %s -c %s" % (exp_dir1,sr2,1 if if_f0_3=="是"else 0,batch_size12,gpus16,total_epoch11,save_epoch10,pretrained_G14,pretrained_D15,1 if if_save_latest13=="是"else 0,1 if if_cache_gpu17=="是"else 0)
405+
if gpus16:
406+
cmd = python_cmd + " train_nsf_sim_cache_sid_load_pretrain.py -e %s -sr %s -f0 %s -bs %s -g %s -te %s -se %s -pg %s -pd %s -l %s -c %s" % (exp_dir1,sr2,1 if if_f0_3=="是"else 0,batch_size12,gpus16,total_epoch11,save_epoch10,pretrained_G14,pretrained_D15,1 if if_save_latest13=="是"else 0,1 if if_cache_gpu17=="是"else 0)
407+
else:
408+
cmd = python_cmd + " train_nsf_sim_cache_sid_load_pretrain.py -e %s -sr %s -f0 %s -bs %s -te %s -se %s -pg %s -pd %s -l %s -c %s" % (exp_dir1,sr2,1 if if_f0_3=="是"else 0,batch_size12,total_epoch11,save_epoch10,pretrained_G14,pretrained_D15,1 if if_save_latest13=="是"else 0,1 if if_cache_gpu17=="是"else 0)
402409
yield get_info_str(cmd)
403410
p = Popen(cmd, shell=True, cwd=now_dir)
404411
p.wait()

train_nsf_sim_cache_sid_load_pretrain.py

Lines changed: 20 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -34,9 +34,6 @@
3434

3535

3636
def main():
37-
"""Assume Single Node Multi GPUs Training Only"""
38-
assert torch.cuda.is_available(), "CPU training is not allowed."
39-
4037
# n_gpus = torch.cuda.device_count()
4138
os.environ["MASTER_ADDR"] = "localhost"
4239
os.environ["MASTER_PORT"] = "5555"
@@ -65,7 +62,7 @@ def run(rank, n_gpus, hps):
6562
backend="gloo", init_method="env://", world_size=n_gpus, rank=rank
6663
)
6764
torch.manual_seed(hps.train.seed)
68-
torch.cuda.set_device(rank)
65+
if torch.cuda.is_available(): torch.cuda.set_device(rank)
6966

7067
if (hps.if_f0 == 1):train_dataset = TextAudioLoaderMultiNSFsid(hps.data.training_files, hps.data)
7168
else:train_dataset = TextAudioLoader(hps.data.training_files, hps.data)
@@ -92,9 +89,13 @@ def run(rank, n_gpus, hps):
9289
persistent_workers=True,
9390
prefetch_factor=8,
9491
)
95-
if(hps.if_f0==1):net_g = SynthesizerTrnMs256NSFsid(hps.data.filter_length // 2 + 1,hps.train.segment_size // hps.data.hop_length,**hps.model,is_half=hps.train.fp16_run,sr=hps.sample_rate).cuda(rank)
96-
else:net_g = SynthesizerTrnMs256NSFsid_nono(hps.data.filter_length // 2 + 1,hps.train.segment_size // hps.data.hop_length,**hps.model,is_half=hps.train.fp16_run).cuda(rank)
97-
net_d = MultiPeriodDiscriminator(hps.model.use_spectral_norm).cuda(rank)
92+
if(hps.if_f0==1):
93+
net_g = SynthesizerTrnMs256NSFsid(hps.data.filter_length // 2 + 1,hps.train.segment_size // hps.data.hop_length,**hps.model,is_half=hps.train.fp16_run,sr=hps.sample_rate)
94+
else:
95+
net_g = SynthesizerTrnMs256NSFsid_nono(hps.data.filter_length // 2 + 1,hps.train.segment_size // hps.data.hop_length,**hps.model,is_half=hps.train.fp16_run)
96+
if torch.cuda.is_available(): net_g = net_g.cuda(rank)
97+
net_d = MultiPeriodDiscriminator(hps.model.use_spectral_norm)
98+
if torch.cuda.is_available(): net_d = net_d.cuda(rank)
9899
optim_g = torch.optim.AdamW(
99100
net_g.parameters(),
100101
hps.train.learning_rate,
@@ -109,8 +110,12 @@ def run(rank, n_gpus, hps):
109110
)
110111
# net_g = DDP(net_g, device_ids=[rank], find_unused_parameters=True)
111112
# net_d = DDP(net_d, device_ids=[rank], find_unused_parameters=True)
112-
net_g = DDP(net_g, device_ids=[rank])
113-
net_d = DDP(net_d, device_ids=[rank])
113+
if torch.cuda.is_available():
114+
net_g = DDP(net_g, device_ids=[rank])
115+
net_d = DDP(net_d, device_ids=[rank])
116+
else:
117+
net_g = DDP(net_g)
118+
net_d = DDP(net_d)
114119

115120
try:#如果能加载自动resume
116121
_, _, _, epoch_str = utils.load_checkpoint(utils.latest_checkpoint_path(hps.model_dir, "D_*.pth"), net_d, optim_d) # D多半加载没事
@@ -190,11 +195,12 @@ def train_and_evaluate(
190195
for batch_idx, info in enumerate(train_loader):
191196
if (hps.if_f0 == 1):phone,phone_lengths,pitch,pitchf,spec,spec_lengths,wave,wave_lengths,sid=info
192197
else:phone,phone_lengths,spec,spec_lengths,wave,wave_lengths,sid=info
193-
phone, phone_lengths = phone.cuda(rank, non_blocking=True),phone_lengths.cuda(rank, non_blocking=True )
194-
if (hps.if_f0 == 1):pitch,pitchf = pitch.cuda(rank, non_blocking=True),pitchf.cuda(rank, non_blocking=True)
195-
sid = sid.cuda(rank, non_blocking=True)
196-
spec, spec_lengths = spec.cuda(rank, non_blocking=True), spec_lengths.cuda(rank, non_blocking=True)
197-
wave, wave_lengths = wave.cuda(rank, non_blocking=True), wave_lengths.cuda(rank, non_blocking=True)
198+
if torch.cuda.is_available():
199+
phone, phone_lengths = phone.cuda(rank, non_blocking=True), phone_lengths.cuda(rank, non_blocking=True )
200+
if (hps.if_f0 == 1):pitch,pitchf = pitch.cuda(rank, non_blocking=True),pitchf.cuda(rank, non_blocking=True)
201+
sid = sid.cuda(rank, non_blocking=True)
202+
spec, spec_lengths = spec.cuda(rank, non_blocking=True), spec_lengths.cuda(rank, non_blocking=True)
203+
wave, wave_lengths = wave.cuda(rank, non_blocking=True), wave_lengths.cuda(rank, non_blocking=True)
198204
if(hps.if_cache_data_in_gpu==True):
199205
if (hps.if_f0 == 1):cache.append((batch_idx, (phone,phone_lengths,pitch,pitchf,spec,spec_lengths,wave,wave_lengths ,sid)))
200206
else:cache.append((batch_idx, (phone,phone_lengths,spec,spec_lengths,wave,wave_lengths ,sid)))

0 commit comments

Comments
 (0)