@@ -49,8 +49,8 @@ def prepare_train_eval(local_rank, gpus_per_node, world_size, run_name, train_co
     else:
         global_rank = local_rank
 
-    writer = SummaryWriter(log_dir=join('./logs', run_name)) if global_rank == 0 else None
-    if global_rank == 0:
+    writer = SummaryWriter(log_dir=join('./logs', run_name)) if local_rank == 0 else None
+    if local_rank == 0:
         logger = make_logger(run_name, None)
         logger.info('Run name : {run_name}'.format(run_name=run_name))
         logger.info(train_config)
@@ -59,19 +59,19 @@ def prepare_train_eval(local_rank, gpus_per_node, world_size, run_name, train_co
         logger = None
 
     ##### load dataset #####
-    if global_rank == 0: logger.info('Load train datasets...')
+    if local_rank == 0: logger.info('Load train datasets...')
     train_dataset = LoadDataset(cfgs.dataset_name, cfgs.data_path, train=True, download=True, resize_size=cfgs.img_size,
                                 hdf5_path=hdf5_path_train, random_flip=cfgs.random_flip_preprocessing)
     if cfgs.reduce_train_dataset < 1.0:
         num_train = int(cfgs.reduce_train_dataset * len(train_dataset))
         train_dataset, _ = torch.utils.data.random_split(train_dataset, [num_train, len(train_dataset) - num_train])
-    if global_rank == 0: logger.info('Train dataset size : {dataset_size}'.format(dataset_size=len(train_dataset)))
+    if local_rank == 0: logger.info('Train dataset size : {dataset_size}'.format(dataset_size=len(train_dataset)))
 
-    if global_rank == 0: logger.info('Load {mode} datasets...'.format(mode=cfgs.eval_type))
+    if local_rank == 0: logger.info('Load {mode} datasets...'.format(mode=cfgs.eval_type))
     eval_mode = True if cfgs.eval_type == 'train' else False
     eval_dataset = LoadDataset(cfgs.dataset_name, cfgs.data_path, train=eval_mode, download=True, resize_size=cfgs.img_size,
                                hdf5_path=None, random_flip=False)
-    if global_rank == 0: logger.info('Eval dataset size : {dataset_size}'.format(dataset_size=len(eval_dataset)))
+    if local_rank == 0: logger.info('Eval dataset size : {dataset_size}'.format(dataset_size=len(eval_dataset)))
 
     if cfgs.distributed_data_parallel:
         train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
@@ -84,9 +84,9 @@ def prepare_train_eval(local_rank, gpus_per_node, world_size, run_name, train_co
     eval_dataloader = DataLoader(eval_dataset, batch_size=cfgs.batch_size, shuffle=False, pin_memory=True, num_workers=cfgs.num_workers, drop_last=False)
 
     ##### build model #####
-    if global_rank == 0: logger.info('Build model...')
+    if local_rank == 0: logger.info('Build model...')
     module = __import__('models.{architecture}'.format(architecture=cfgs.architecture), fromlist=['something'])
-    if global_rank == 0: logger.info('Modules are located on models.{architecture}.'.format(architecture=cfgs.architecture))
+    if local_rank == 0: logger.info('Modules are located on models.{architecture}.'.format(architecture=cfgs.architecture))
     Gen = module.Generator(cfgs.z_dim, cfgs.shared_dim, cfgs.img_size, cfgs.g_conv_dim, cfgs.g_spectral_norm, cfgs.attention,
                            cfgs.attention_after_nth_gen_block, cfgs.activation_fn, cfgs.conditional_strategy, cfgs.num_classes,
                            cfgs.g_init, cfgs.G_depth, cfgs.mixed_precision).to(local_rank)
@@ -96,19 +96,19 @@ def prepare_train_eval(local_rank, gpus_per_node, world_size, run_name, train_co
                            cfgs.normalize_embed, cfgs.d_init, cfgs.D_depth, cfgs.mixed_precision).to(local_rank)
 
     if cfgs.ema:
-        if global_rank == 0: logger.info('Prepare EMA for G with decay of {}.'.format(cfgs.ema_decay))
+        if local_rank == 0: logger.info('Prepare EMA for G with decay of {}.'.format(cfgs.ema_decay))
         Gen_copy = module.Generator(cfgs.z_dim, cfgs.shared_dim, cfgs.img_size, cfgs.g_conv_dim, cfgs.g_spectral_norm, cfgs.attention,
                                     cfgs.attention_after_nth_gen_block, cfgs.activation_fn, cfgs.conditional_strategy, cfgs.num_classes,
                                     initialize=False, G_depth=cfgs.G_depth, mixed_precision=cfgs.mixed_precision).to(local_rank)
         Gen_ema = ema(Gen, Gen_copy, cfgs.ema_decay, cfgs.ema_start)
     else:
         Gen_copy, Gen_ema = None, None
 
-    if global_rank == 0: logger.info(count_parameters(Gen))
-    if global_rank == 0: logger.info(Gen)
+    if local_rank == 0: logger.info(count_parameters(Gen))
+    if local_rank == 0: logger.info(Gen)
 
-    if global_rank == 0: logger.info(count_parameters(Dis))
-    if global_rank == 0: logger.info(Dis)
+    if local_rank == 0: logger.info(count_parameters(Dis))
+    if local_rank == 0: logger.info(Dis)
 
 
     ### define loss functions and optimizers
@@ -144,7 +144,7 @@ def prepare_train_eval(local_rank, gpus_per_node, world_size, run_name, train_co
         Gen, G_optimizer, trained_seed, run_name, step, prev_ada_p = load_checkpoint(Gen, G_optimizer, g_checkpoint_dir)
         Dis, D_optimizer, trained_seed, run_name, step, prev_ada_p, best_step, best_fid, best_fid_checkpoint_path = \
             load_checkpoint(Dis, D_optimizer, d_checkpoint_dir, metric=True)
-        if global_rank == 0: logger = make_logger(run_name, None)
+        if local_rank == 0: logger = make_logger(run_name, None)
         if cfgs.ema:
            g_ema_checkpoint_dir = glob.glob(join(checkpoint_dir, "model=G_ema-{when}-weights-step*.pth".format(when=when)))[0]
            Gen_copy = load_checkpoint(Gen_copy, None, g_ema_checkpoint_dir, ema=True)
@@ -154,8 +154,8 @@ def prepare_train_eval(local_rank, gpus_per_node, world_size, run_name, train_co
         if cfgs.train_configs['train']:
             assert cfgs.seed == trained_seed, "Seed for sampling random numbers should be same!"
 
-        if global_rank == 0: logger.info('Generator checkpoint is {}'.format(g_checkpoint_dir))
-        if global_rank == 0: logger.info('Discriminator checkpoint is {}'.format(d_checkpoint_dir))
+        if local_rank == 0: logger.info('Generator checkpoint is {}'.format(g_checkpoint_dir))
+        if local_rank == 0: logger.info('Discriminator checkpoint is {}'.format(d_checkpoint_dir))
         if cfgs.freeze_layers > -1:
             prev_ada_p, step, best_step, best_fid, best_fid_checkpoint_path = None, 0, 0, None, None
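For context on the rank guards in the hunks above, the sketch below illustrates how a process's global rank is conventionally derived in PyTorch DDP from its node index, the GPUs per node, and its local rank; the helper name and the example node/GPU counts are assumptions for illustration, not part of this commit, and the single-node case matches the `global_rank = local_rank` branch in the first hunk.

# Minimal sketch (hypothetical helper, not part of the repository): in multi-node
# DDP the global rank is conventionally node_rank * gpus_per_node + local_rank;
# on a single node the two values coincide.
def compute_global_rank(node_rank: int, gpus_per_node: int, local_rank: int) -> int:
    return node_rank * gpus_per_node + local_rank

# Single-node run (node_rank=0): local_rank and global_rank are identical, so a
# `local_rank == 0` guard and a `global_rank == 0` guard select the same process.
assert compute_global_rank(0, 4, 0) == 0

# Two nodes with 4 GPUs each: local_rank 0 on the second node is global rank 4,
# so a `local_rank == 0` guard fires once per node, while a `global_rank == 0`
# guard fires once per job.
assert compute_global_rank(1, 4, 0) == 4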