Commit 4c2ed8a

modify gpu allocation
1 parent 59158ed commit 4c2ed8a

2 files changed (+116, -113 lines)


src/loader.py

Lines changed: 40 additions & 39 deletions
@@ -42,15 +42,15 @@ def prepare_train_eval(local_rank, gpus_per_node, world_size, run_name, train_co
     prev_ada_p, step, best_step, best_fid, best_fid_checkpoint_path, mu, sigma, inception_model = None, 0, 0, None, None, None, None, None

     if cfgs.distributed_data_parallel:
-        rank = cfgs.nr*(gpus_per_node) + local_rank
-        print("Use GPU: {} for training.".format(rank))
-        setup(rank, world_size)
-        torch.cuda.set_device(rank)
+        global_rank = cfgs.nr*(gpus_per_node) + local_rank
+        print("Use GPU: {} for training.".format(global_rank))
+        setup(global_rank, world_size)
+        torch.cuda.set_device(local_rank)
     else:
-        rank = local_rank
+        global_rank = local_rank

-    writer = SummaryWriter(log_dir=join('./logs', run_name)) if rank == 0 else None
-    if rank == 0:
+    writer = SummaryWriter(log_dir=join('./logs', run_name)) if global_rank == 0 else None
+    if global_rank == 0:
         logger = make_logger(run_name, None)
         logger.info('Run name : {run_name}'.format(run_name=run_name))
         logger.info(train_config)
@@ -59,19 +59,19 @@ def prepare_train_eval(local_rank, gpus_per_node, world_size, run_name, train_co
         logger = None

     ##### load dataset #####
-    if rank == 0: logger.info('Load train datasets...')
+    if global_rank == 0: logger.info('Load train datasets...')
     train_dataset = LoadDataset(cfgs.dataset_name, cfgs.data_path, train=True, download=True, resize_size=cfgs.img_size,
                                 hdf5_path=hdf5_path_train, random_flip=cfgs.random_flip_preprocessing)
     if cfgs.reduce_train_dataset < 1.0:
         num_train = int(cfgs.reduce_train_dataset*len(train_dataset))
         train_dataset, _ = torch.utils.data.random_split(train_dataset, [num_train, len(train_dataset) - num_train])
-    if rank == 0: logger.info('Train dataset size : {dataset_size}'.format(dataset_size=len(train_dataset)))
+    if global_rank == 0: logger.info('Train dataset size : {dataset_size}'.format(dataset_size=len(train_dataset)))

-    if rank == 0: logger.info('Load {mode} datasets...'.format(mode=cfgs.eval_type))
+    if global_rank == 0: logger.info('Load {mode} datasets...'.format(mode=cfgs.eval_type))
     eval_mode = True if cfgs.eval_type == 'train' else False
     eval_dataset = LoadDataset(cfgs.dataset_name, cfgs.data_path, train=eval_mode, download=True, resize_size=cfgs.img_size,
                                hdf5_path=None, random_flip=False)
-    if rank == 0: logger.info('Eval dataset size : {dataset_size}'.format(dataset_size=len(eval_dataset)))
+    if global_rank == 0: logger.info('Eval dataset size : {dataset_size}'.format(dataset_size=len(eval_dataset)))

     if cfgs.distributed_data_parallel:
         train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
@@ -84,31 +84,31 @@ def prepare_train_eval(local_rank, gpus_per_node, world_size, run_name, train_co
     eval_dataloader = DataLoader(eval_dataset, batch_size=cfgs.batch_size, shuffle=False, pin_memory=True, num_workers=cfgs.num_workers, drop_last=False)

     ##### build model #####
-    if rank == 0: logger.info('Build model...')
+    if global_rank == 0: logger.info('Build model...')
     module = __import__('models.{architecture}'.format(architecture=cfgs.architecture), fromlist=['something'])
-    if rank == 0: logger.info('Modules are located on models.{architecture}.'.format(architecture=cfgs.architecture))
+    if global_rank == 0: logger.info('Modules are located on models.{architecture}.'.format(architecture=cfgs.architecture))
     Gen = module.Generator(cfgs.z_dim, cfgs.shared_dim, cfgs.img_size, cfgs.g_conv_dim, cfgs.g_spectral_norm, cfgs.attention,
                            cfgs.attention_after_nth_gen_block, cfgs.activation_fn, cfgs.conditional_strategy, cfgs.num_classes,
-                           cfgs.g_init, cfgs.G_depth, cfgs.mixed_precision).to(rank)
+                           cfgs.g_init, cfgs.G_depth, cfgs.mixed_precision).to(local_rank)

     Dis = module.Discriminator(cfgs.img_size, cfgs.d_conv_dim, cfgs.d_spectral_norm, cfgs.attention, cfgs.attention_after_nth_dis_block,
                                cfgs.activation_fn, cfgs.conditional_strategy, cfgs.hypersphere_dim, cfgs.num_classes, cfgs.nonlinear_embed,
-                               cfgs.normalize_embed, cfgs.d_init, cfgs.D_depth, cfgs.mixed_precision).to(rank)
+                               cfgs.normalize_embed, cfgs.d_init, cfgs.D_depth, cfgs.mixed_precision).to(local_rank)

     if cfgs.ema:
-        if rank == 0: logger.info('Prepare EMA for G with decay of {}.'.format(cfgs.ema_decay))
+        if global_rank == 0: logger.info('Prepare EMA for G with decay of {}.'.format(cfgs.ema_decay))
         Gen_copy = module.Generator(cfgs.z_dim, cfgs.shared_dim, cfgs.img_size, cfgs.g_conv_dim, cfgs.g_spectral_norm, cfgs.attention,
                                     cfgs.attention_after_nth_gen_block, cfgs.activation_fn, cfgs.conditional_strategy, cfgs.num_classes,
-                                    initialize=False, G_depth=cfgs.G_depth, mixed_precision=cfgs.mixed_precision).to(rank)
+                                    initialize=False, G_depth=cfgs.G_depth, mixed_precision=cfgs.mixed_precision).to(local_rank)
         Gen_ema = ema(Gen, Gen_copy, cfgs.ema_decay, cfgs.ema_start)
     else:
         Gen_copy, Gen_ema = None, None

-    if rank == 0: logger.info(count_parameters(Gen))
-    if rank == 0: logger.info(Gen)
+    if global_rank == 0: logger.info(count_parameters(Gen))
+    if global_rank == 0: logger.info(Gen)

-    if rank == 0: logger.info(count_parameters(Dis))
-    if rank == 0: logger.info(Dis)
+    if global_rank == 0: logger.info(count_parameters(Dis))
+    if global_rank == 0: logger.info(Dis)


     ### define loss functions and optimizers
@@ -144,18 +144,18 @@ def prepare_train_eval(local_rank, gpus_per_node, world_size, run_name, train_co
         Gen, G_optimizer, trained_seed, run_name, step, prev_ada_p = load_checkpoint(Gen, G_optimizer, g_checkpoint_dir)
         Dis, D_optimizer, trained_seed, run_name, step, prev_ada_p, best_step, best_fid, best_fid_checkpoint_path =\
             load_checkpoint(Dis, D_optimizer, d_checkpoint_dir, metric=True)
-        if rank == 0: logger = make_logger(run_name, None)
+        if global_rank == 0: logger = make_logger(run_name, None)
         if cfgs.ema:
             g_ema_checkpoint_dir = glob.glob(join(checkpoint_dir, "model=G_ema-{when}-weights-step*.pth".format(when=when)))[0]
             Gen_copy = load_checkpoint(Gen_copy, None, g_ema_checkpoint_dir, ema=True)
             Gen_ema.source, Gen_ema.target = Gen, Gen_copy

-        writer = SummaryWriter(log_dir=join('./logs', run_name)) if rank ==0 else None
+        writer = SummaryWriter(log_dir=join('./logs', run_name)) if global_rank == 0 else None
         if cfgs.train_configs['train']:
             assert cfgs.seed == trained_seed, "Seed for sampling random numbers should be same!"

-        if rank == 0: logger.info('Generator checkpoint is {}'.format(g_checkpoint_dir))
-        if rank == 0: logger.info('Discriminator checkpoint is {}'.format(d_checkpoint_dir))
+        if global_rank == 0: logger.info('Generator checkpoint is {}'.format(g_checkpoint_dir))
+        if global_rank == 0: logger.info('Discriminator checkpoint is {}'.format(d_checkpoint_dir))
         if cfgs.freeze_layers > -1 :
             prev_ada_p, step, best_step, best_fid, best_fid_checkpoint_path = None, 0, 0, None, None

@@ -170,30 +170,30 @@ def prepare_train_eval(local_rank, gpus_per_node, world_size, run_name, train_co
                 if cfgs.ema:
                     Gen_copy = torch.nn.SyncBatchNorm.convert_sync_batchnorm(Gen_copy, process_group)

-            Gen = DDP(Gen, device_ids=[rank])
-            Dis = DDP(Dis, device_ids=[rank])
+            Gen = DDP(Gen, device_ids=[local_rank])
+            Dis = DDP(Dis, device_ids=[local_rank])
             if cfgs.ema:
-                Gen_copy = DDP(Gen_copy, device_ids=[rank])
+                Gen_copy = DDP(Gen_copy, device_ids=[local_rank])
         else:
-            Gen = DataParallel(Gen, output_device=rank)
-            Dis = DataParallel(Dis, output_device=rank)
+            Gen = DataParallel(Gen, output_device=local_rank)
+            Dis = DataParallel(Dis, output_device=local_rank)
             if cfgs.ema:
-                Gen_copy = DataParallel(Gen_copy, output_device=rank)
+                Gen_copy = DataParallel(Gen_copy, output_device=local_rank)

             if cfgs.synchronized_bn:
-                Gen = convert_model(Gen).to(rank)
-                Dis = convert_model(Dis).to(rank)
+                Gen = convert_model(Gen).to(local_rank)
+                Dis = convert_model(Dis).to(local_rank)
                 if cfgs.ema:
-                    Gen_copy = convert_model(Gen_copy).to(rank)
+                    Gen_copy = convert_model(Gen_copy).to(local_rank)

     ##### load the inception network and prepare first/secend moments for calculating FID #####
     if cfgs.eval:
-        inception_model = InceptionV3().to(rank)
+        inception_model = InceptionV3().to(local_rank)
         if world_size > 1 and cfgs.distributed_data_parallel:
             toggle_grad(inception_model, on=True)
-            inception_model = DDP(inception_model, device_ids=[rank], broadcast_buffers=False, find_unused_parameters=True)
+            inception_model = DDP(inception_model, device_ids=[local_rank], broadcast_buffers=False, find_unused_parameters=True)
         elif world_size > 1 and cfgs.distributed_data_parallel is False:
-            inception_model = DataParallel(inception_model, output_device=rank)
+            inception_model = DataParallel(inception_model, output_device=local_rank)
         else:
             pass

@@ -204,7 +204,7 @@ def prepare_train_eval(local_rank, gpus_per_node, world_size, run_name, train_co
                                               splits=1,
                                               run_name=run_name,
                                               logger=logger,
-                                              device=rank)
+                                              device=local_rank)

     worker = make_worker(
         cfgs=cfgs,
@@ -227,7 +227,8 @@ def prepare_train_eval(local_rank, gpus_per_node, world_size, run_name, train_co
         G_loss=G_loss[cfgs.adv_loss],
         D_loss=D_loss[cfgs.adv_loss],
         prev_ada_p=prev_ada_p,
-        rank=rank,
+        global_rank=global_rank,
+        local_rank=local_rank,
         bn_stat_OnTheFly=cfgs.bn_stat_OnTheFly,
         checkpoint_dir=checkpoint_dir,
         mu=mu,
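For context on what the diff above changes: the process's global rank (cfgs.nr*gpus_per_node + local_rank) identifies the process across all nodes and is used for process-group setup and rank-0-only logging, while the local rank (the GPU index on the current node) is now used for torch.cuda.set_device, .to(...), and DDP device_ids. A minimal sketch of this pattern in plain PyTorch follows; the helper name, argument names, and environment defaults here are illustrative assumptions, not this repository's code.

# Minimal sketch (not the repository's code): how global vs. local rank are used.
import os
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

def setup_worker(local_rank, node_rank, gpus_per_node, world_size, model):
    # Global rank: unique id of this process across every node.
    # It drives process-group initialization and "only rank 0 logs" checks.
    global_rank = node_rank * gpus_per_node + local_rank
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")  # assumed single-machine defaults
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group("nccl", rank=global_rank, world_size=world_size)

    # Local rank: index of the GPU on *this* node.
    # It drives device selection and module placement.
    torch.cuda.set_device(local_rank)
    model = DDP(model.to(local_rank), device_ids=[local_rank])

    if global_rank == 0:
        print("rank 0 owns logging and checkpointing")
    return model, global_rank

Using the global rank for device placement, as the removed lines did, only happens to work on a single node: on any node other than the first, cfgs.nr*gpus_per_node + local_rank exceeds the number of GPUs on that node, so torch.cuda.set_device and .to(rank) would target a nonexistent device. That is presumably what this commit addresses.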
