@@ -42,15 +42,15 @@ def prepare_train_eval(local_rank, gpus_per_node, world_size, run_name, train_co
     prev_ada_p, step, best_step, best_fid, best_fid_checkpoint_path, mu, sigma, inception_model = None, 0, 0, None, None, None, None, None
 
     if cfgs.distributed_data_parallel:
-        rank = cfgs.nr * (gpus_per_node) + local_rank
-        print("Use GPU: {} for training.".format(rank))
-        setup(rank, world_size)
-        torch.cuda.set_device(rank)
+        global_rank = cfgs.nr * (gpus_per_node) + local_rank
+        print("Use GPU: {} for training.".format(global_rank))
+        setup(global_rank, world_size)
+        torch.cuda.set_device(local_rank)
     else:
-        rank = local_rank
+        global_rank = local_rank
 
-    writer = SummaryWriter(log_dir=join('./logs', run_name)) if rank == 0 else None
-    if rank == 0:
+    writer = SummaryWriter(log_dir=join('./logs', run_name)) if global_rank == 0 else None
+    if global_rank == 0:
         logger = make_logger(run_name, None)
         logger.info('Run name : {run_name}'.format(run_name=run_name))
         logger.info(train_config)
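
The hunk above is the heart of the change: the process-group rank (`global_rank`) is now derived from the node index and used for `setup()` and rank-0-only logging, while CUDA device selection switches to `local_rank`, the GPU index on the current node. A minimal, hedged sketch of that convention follows; the `setup()` helper here is a hypothetical stand-in for the repo's own, and the single-node rendezvous defaults (`MASTER_ADDR`, `MASTER_PORT`, one node) are assumptions for illustration only.

```python
import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp

def setup(global_rank, world_size, backend='nccl'):
    # Hypothetical stand-in for the repo's setup(): join the process group
    # under the node-global rank. Single-node rendezvous defaults are assumed.
    os.environ.setdefault('MASTER_ADDR', 'localhost')
    os.environ.setdefault('MASTER_PORT', '12355')
    dist.init_process_group(backend, rank=global_rank, world_size=world_size)

def demo_worker(local_rank, node_rank, gpus_per_node, world_size):
    # local_rank indexes GPUs on this node (what mp.spawn passes in);
    # global_rank identifies the process across every node in the job.
    global_rank = node_rank * gpus_per_node + local_rank
    setup(global_rank, world_size)
    torch.cuda.set_device(local_rank)  # device placement uses the local index
    dist.destroy_process_group()

if __name__ == '__main__':
    gpus_per_node, node_rank, num_nodes = torch.cuda.device_count(), 0, 1
    world_size = gpus_per_node * num_nodes
    mp.spawn(demo_worker, args=(node_rank, gpus_per_node, world_size), nprocs=gpus_per_node)
```
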
@@ -59,19 +59,19 @@ def prepare_train_eval(local_rank, gpus_per_node, world_size, run_name, train_co
         logger = None
 
     ##### load dataset #####
-    if rank == 0: logger.info('Load train datasets...')
+    if global_rank == 0: logger.info('Load train datasets...')
     train_dataset = LoadDataset(cfgs.dataset_name, cfgs.data_path, train=True, download=True, resize_size=cfgs.img_size,
                                 hdf5_path=hdf5_path_train, random_flip=cfgs.random_flip_preprocessing)
     if cfgs.reduce_train_dataset < 1.0:
         num_train = int(cfgs.reduce_train_dataset * len(train_dataset))
         train_dataset, _ = torch.utils.data.random_split(train_dataset, [num_train, len(train_dataset) - num_train])
-    if rank == 0: logger.info('Train dataset size : {dataset_size}'.format(dataset_size=len(train_dataset)))
+    if global_rank == 0: logger.info('Train dataset size : {dataset_size}'.format(dataset_size=len(train_dataset)))
 
-    if rank == 0: logger.info('Load {mode} datasets...'.format(mode=cfgs.eval_type))
+    if global_rank == 0: logger.info('Load {mode} datasets...'.format(mode=cfgs.eval_type))
     eval_mode = True if cfgs.eval_type == 'train' else False
     eval_dataset = LoadDataset(cfgs.dataset_name, cfgs.data_path, train=eval_mode, download=True, resize_size=cfgs.img_size,
                                hdf5_path=None, random_flip=False)
-    if rank == 0: logger.info('Eval dataset size : {dataset_size}'.format(dataset_size=len(eval_dataset)))
+    if global_rank == 0: logger.info('Eval dataset size : {dataset_size}'.format(dataset_size=len(eval_dataset)))
 
     if cfgs.distributed_data_parallel:
         train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
@@ -84,31 +84,31 @@ def prepare_train_eval(local_rank, gpus_per_node, world_size, run_name, train_co
     eval_dataloader = DataLoader(eval_dataset, batch_size=cfgs.batch_size, shuffle=False, pin_memory=True, num_workers=cfgs.num_workers, drop_last=False)
 
     ##### build model #####
-    if rank == 0: logger.info('Build model...')
+    if global_rank == 0: logger.info('Build model...')
     module = __import__('models.{architecture}'.format(architecture=cfgs.architecture), fromlist=['something'])
-    if rank == 0: logger.info('Modules are located on models.{architecture}.'.format(architecture=cfgs.architecture))
+    if global_rank == 0: logger.info('Modules are located on models.{architecture}.'.format(architecture=cfgs.architecture))
     Gen = module.Generator(cfgs.z_dim, cfgs.shared_dim, cfgs.img_size, cfgs.g_conv_dim, cfgs.g_spectral_norm, cfgs.attention,
                            cfgs.attention_after_nth_gen_block, cfgs.activation_fn, cfgs.conditional_strategy, cfgs.num_classes,
-                           cfgs.g_init, cfgs.G_depth, cfgs.mixed_precision).to(rank)
+                           cfgs.g_init, cfgs.G_depth, cfgs.mixed_precision).to(local_rank)
 
     Dis = module.Discriminator(cfgs.img_size, cfgs.d_conv_dim, cfgs.d_spectral_norm, cfgs.attention, cfgs.attention_after_nth_dis_block,
                                cfgs.activation_fn, cfgs.conditional_strategy, cfgs.hypersphere_dim, cfgs.num_classes, cfgs.nonlinear_embed,
-                               cfgs.normalize_embed, cfgs.d_init, cfgs.D_depth, cfgs.mixed_precision).to(rank)
+                               cfgs.normalize_embed, cfgs.d_init, cfgs.D_depth, cfgs.mixed_precision).to(local_rank)
 
     if cfgs.ema:
-        if rank == 0: logger.info('Prepare EMA for G with decay of {}.'.format(cfgs.ema_decay))
+        if global_rank == 0: logger.info('Prepare EMA for G with decay of {}.'.format(cfgs.ema_decay))
         Gen_copy = module.Generator(cfgs.z_dim, cfgs.shared_dim, cfgs.img_size, cfgs.g_conv_dim, cfgs.g_spectral_norm, cfgs.attention,
                                     cfgs.attention_after_nth_gen_block, cfgs.activation_fn, cfgs.conditional_strategy, cfgs.num_classes,
-                                    initialize=False, G_depth=cfgs.G_depth, mixed_precision=cfgs.mixed_precision).to(rank)
+                                    initialize=False, G_depth=cfgs.G_depth, mixed_precision=cfgs.mixed_precision).to(local_rank)
         Gen_ema = ema(Gen, Gen_copy, cfgs.ema_decay, cfgs.ema_start)
     else:
         Gen_copy, Gen_ema = None, None
 
-    if rank == 0: logger.info(count_parameters(Gen))
-    if rank == 0: logger.info(Gen)
+    if global_rank == 0: logger.info(count_parameters(Gen))
+    if global_rank == 0: logger.info(Gen)
 
-    if rank == 0: logger.info(count_parameters(Dis))
-    if rank == 0: logger.info(Dis)
+    if global_rank == 0: logger.info(count_parameters(Dis))
+    if global_rank == 0: logger.info(Dis)
 
 
     ### define loss functions and optimizers
@@ -144,18 +144,18 @@ def prepare_train_eval(local_rank, gpus_per_node, world_size, run_name, train_co
         Gen, G_optimizer, trained_seed, run_name, step, prev_ada_p = load_checkpoint(Gen, G_optimizer, g_checkpoint_dir)
         Dis, D_optimizer, trained_seed, run_name, step, prev_ada_p, best_step, best_fid, best_fid_checkpoint_path = \
             load_checkpoint(Dis, D_optimizer, d_checkpoint_dir, metric=True)
-        if rank == 0: logger = make_logger(run_name, None)
+        if global_rank == 0: logger = make_logger(run_name, None)
         if cfgs.ema:
             g_ema_checkpoint_dir = glob.glob(join(checkpoint_dir, "model=G_ema-{when}-weights-step*.pth".format(when=when)))[0]
             Gen_copy = load_checkpoint(Gen_copy, None, g_ema_checkpoint_dir, ema=True)
             Gen_ema.source, Gen_ema.target = Gen, Gen_copy
 
-        writer = SummaryWriter(log_dir=join('./logs', run_name)) if rank == 0 else None
+        writer = SummaryWriter(log_dir=join('./logs', run_name)) if global_rank == 0 else None
         if cfgs.train_configs['train']:
             assert cfgs.seed == trained_seed, "Seed for sampling random numbers should be same!"
 
-        if rank == 0: logger.info('Generator checkpoint is {}'.format(g_checkpoint_dir))
-        if rank == 0: logger.info('Discriminator checkpoint is {}'.format(d_checkpoint_dir))
+        if global_rank == 0: logger.info('Generator checkpoint is {}'.format(g_checkpoint_dir))
+        if global_rank == 0: logger.info('Discriminator checkpoint is {}'.format(d_checkpoint_dir))
         if cfgs.freeze_layers > -1:
             prev_ada_p, step, best_step, best_fid, best_fid_checkpoint_path = None, 0, 0, None, None
 
@@ -170,30 +170,30 @@ def prepare_train_eval(local_rank, gpus_per_node, world_size, run_name, train_co
         if cfgs.ema:
             Gen_copy = torch.nn.SyncBatchNorm.convert_sync_batchnorm(Gen_copy, process_group)
 
-        Gen = DDP(Gen, device_ids=[rank])
-        Dis = DDP(Dis, device_ids=[rank])
+        Gen = DDP(Gen, device_ids=[local_rank])
+        Dis = DDP(Dis, device_ids=[local_rank])
         if cfgs.ema:
-            Gen_copy = DDP(Gen_copy, device_ids=[rank])
+            Gen_copy = DDP(Gen_copy, device_ids=[local_rank])
     else:
-        Gen = DataParallel(Gen, output_device=rank)
-        Dis = DataParallel(Dis, output_device=rank)
+        Gen = DataParallel(Gen, output_device=local_rank)
+        Dis = DataParallel(Dis, output_device=local_rank)
         if cfgs.ema:
-            Gen_copy = DataParallel(Gen_copy, output_device=rank)
+            Gen_copy = DataParallel(Gen_copy, output_device=local_rank)
 
         if cfgs.synchronized_bn:
-            Gen = convert_model(Gen).to(rank)
-            Dis = convert_model(Dis).to(rank)
+            Gen = convert_model(Gen).to(local_rank)
+            Dis = convert_model(Dis).to(local_rank)
             if cfgs.ema:
-                Gen_copy = convert_model(Gen_copy).to(rank)
+                Gen_copy = convert_model(Gen_copy).to(local_rank)
 
     ##### load the inception network and prepare first/second moments for calculating FID #####
     if cfgs.eval:
-        inception_model = InceptionV3().to(rank)
+        inception_model = InceptionV3().to(local_rank)
         if world_size > 1 and cfgs.distributed_data_parallel:
             toggle_grad(inception_model, on=True)
-            inception_model = DDP(inception_model, device_ids=[rank], broadcast_buffers=False, find_unused_parameters=True)
+            inception_model = DDP(inception_model, device_ids=[local_rank], broadcast_buffers=False, find_unused_parameters=True)
         elif world_size > 1 and cfgs.distributed_data_parallel is False:
-            inception_model = DataParallel(inception_model, output_device=rank)
+            inception_model = DataParallel(inception_model, output_device=local_rank)
         else:
             pass
 
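
The device-placement changes above follow the same rule: modules are wrapped with the node-local GPU index rather than the global rank, because on a multi-node job the global rank can exceed the number of GPUs available on one node. A short sketch of that wrapping, under the assumption that a process group is already initialized (`wrap_for_ddp` is an illustrative helper, not part of the repo):

```python
import torch
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP

def wrap_for_ddp(model: nn.Module, local_rank: int) -> nn.Module:
    # CUDA device indices must come from the local rank; the global rank is
    # only a process identifier and may not map to a valid device id here.
    device = torch.device('cuda', local_rank)
    return DDP(model.to(device), device_ids=[local_rank], output_device=local_rank)
```
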
@@ -204,7 +204,7 @@ def prepare_train_eval(local_rank, gpus_per_node, world_size, run_name, train_co
                                               splits=1,
                                               run_name=run_name,
                                               logger=logger,
-                                              device=rank)
+                                              device=local_rank)
 
     worker = make_worker(
         cfgs=cfgs,
@@ -227,7 +227,8 @@ def prepare_train_eval(local_rank, gpus_per_node, world_size, run_name, train_co
         G_loss=G_loss[cfgs.adv_loss],
         D_loss=D_loss[cfgs.adv_loss],
         prev_ada_p=prev_ada_p,
-        rank=rank,
+        global_rank=global_rank,
+        local_rank=local_rank,
         bn_stat_OnTheFly=cfgs.bn_stat_OnTheFly,
         checkpoint_dir=checkpoint_dir,
         mu=mu,