34
34
35
35
36
36
def main ():
37
- """Assume Single Node Multi GPUs Training Only"""
38
- assert torch .cuda .is_available (), "CPU training is not allowed."
39
-
40
37
# n_gpus = torch.cuda.device_count()
41
38
os .environ ["MASTER_ADDR" ] = "localhost"
42
39
os .environ ["MASTER_PORT" ] = "5555"
@@ -65,7 +62,7 @@ def run(rank, n_gpus, hps):
65
62
backend = "gloo" , init_method = "env://" , world_size = n_gpus , rank = rank
66
63
)
67
64
torch .manual_seed (hps .train .seed )
68
- torch .cuda .set_device (rank )
65
+ if torch . cuda . is_available (): torch .cuda .set_device (rank )
69
66
70
67
if (hps .if_f0 == 1 ):train_dataset = TextAudioLoaderMultiNSFsid (hps .data .training_files , hps .data )
71
68
else :train_dataset = TextAudioLoader (hps .data .training_files , hps .data )
@@ -92,9 +89,13 @@ def run(rank, n_gpus, hps):
92
89
persistent_workers = True ,
93
90
prefetch_factor = 8 ,
94
91
)
95
- if (hps .if_f0 == 1 ):net_g = SynthesizerTrnMs256NSFsid (hps .data .filter_length // 2 + 1 ,hps .train .segment_size // hps .data .hop_length ,** hps .model ,is_half = hps .train .fp16_run ,sr = hps .sample_rate ).cuda (rank )
96
- else :net_g = SynthesizerTrnMs256NSFsid_nono (hps .data .filter_length // 2 + 1 ,hps .train .segment_size // hps .data .hop_length ,** hps .model ,is_half = hps .train .fp16_run ).cuda (rank )
97
- net_d = MultiPeriodDiscriminator (hps .model .use_spectral_norm ).cuda (rank )
92
+ if (hps .if_f0 == 1 ):
93
+ net_g = SynthesizerTrnMs256NSFsid (hps .data .filter_length // 2 + 1 ,hps .train .segment_size // hps .data .hop_length ,** hps .model ,is_half = hps .train .fp16_run ,sr = hps .sample_rate )
94
+ else :
95
+ net_g = SynthesizerTrnMs256NSFsid_nono (hps .data .filter_length // 2 + 1 ,hps .train .segment_size // hps .data .hop_length ,** hps .model ,is_half = hps .train .fp16_run )
96
+ if torch .cuda .is_available (): net_g = net_g .cuda (rank )
97
+ net_d = MultiPeriodDiscriminator (hps .model .use_spectral_norm )
98
+ if torch .cuda .is_available (): net_d = net_d .cuda (rank )
98
99
optim_g = torch .optim .AdamW (
99
100
net_g .parameters (),
100
101
hps .train .learning_rate ,
@@ -109,8 +110,12 @@ def run(rank, n_gpus, hps):
109
110
)
110
111
# net_g = DDP(net_g, device_ids=[rank], find_unused_parameters=True)
111
112
# net_d = DDP(net_d, device_ids=[rank], find_unused_parameters=True)
112
- net_g = DDP (net_g , device_ids = [rank ])
113
- net_d = DDP (net_d , device_ids = [rank ])
113
+ if torch .cuda .is_available ():
114
+ net_g = DDP (net_g , device_ids = [rank ])
115
+ net_d = DDP (net_d , device_ids = [rank ])
116
+ else :
117
+ net_g = DDP (net_g )
118
+ net_d = DDP (net_d )
114
119
115
120
try :#如果能加载自动resume
116
121
_ , _ , _ , epoch_str = utils .load_checkpoint (utils .latest_checkpoint_path (hps .model_dir , "D_*.pth" ), net_d , optim_d ) # D多半加载没事
@@ -190,11 +195,12 @@ def train_and_evaluate(
190
195
for batch_idx , info in enumerate (train_loader ):
191
196
if (hps .if_f0 == 1 ):phone ,phone_lengths ,pitch ,pitchf ,spec ,spec_lengths ,wave ,wave_lengths ,sid = info
192
197
else :phone ,phone_lengths ,spec ,spec_lengths ,wave ,wave_lengths ,sid = info
193
- phone , phone_lengths = phone .cuda (rank , non_blocking = True ),phone_lengths .cuda (rank , non_blocking = True )
194
- if (hps .if_f0 == 1 ):pitch ,pitchf = pitch .cuda (rank , non_blocking = True ),pitchf .cuda (rank , non_blocking = True )
195
- sid = sid .cuda (rank , non_blocking = True )
196
- spec , spec_lengths = spec .cuda (rank , non_blocking = True ), spec_lengths .cuda (rank , non_blocking = True )
197
- wave , wave_lengths = wave .cuda (rank , non_blocking = True ), wave_lengths .cuda (rank , non_blocking = True )
198
+ if torch .cuda .is_available ():
199
+ phone , phone_lengths = phone .cuda (rank , non_blocking = True ), phone_lengths .cuda (rank , non_blocking = True )
200
+ if (hps .if_f0 == 1 ):pitch ,pitchf = pitch .cuda (rank , non_blocking = True ),pitchf .cuda (rank , non_blocking = True )
201
+ sid = sid .cuda (rank , non_blocking = True )
202
+ spec , spec_lengths = spec .cuda (rank , non_blocking = True ), spec_lengths .cuda (rank , non_blocking = True )
203
+ wave , wave_lengths = wave .cuda (rank , non_blocking = True ), wave_lengths .cuda (rank , non_blocking = True )
198
204
if (hps .if_cache_data_in_gpu == True ):
199
205
if (hps .if_f0 == 1 ):cache .append ((batch_idx , (phone ,phone_lengths ,pitch ,pitchf ,spec ,spec_lengths ,wave ,wave_lengths ,sid )))
200
206
else :cache .append ((batch_idx , (phone ,phone_lengths ,spec ,spec_lengths ,wave ,wave_lengths ,sid )))
0 commit comments