
Commit 9576dd2

added load on CPU first (#221)
* added load on CPU first
* added print logs
* changed close order
1 parent 90353ac commit 9576dd2

4 files changed: +207 / -51 lines changed

pytorch_lightning/trainer/trainer.py

Lines changed: 23 additions & 9 deletions
@@ -170,6 +170,7 @@ def __init__(self,
 
         # allow int, string and gpu list
         self.data_parallel_device_ids = self.__parse_gpu_ids(gpus)
+        self.root_gpu = self.__set_root_gpu(self.data_parallel_device_ids)
 
         # distributed backend choice
         self.use_ddp = False
@@ -270,6 +271,17 @@ def __parse_gpu_ids(self, gpus):
 
         return gpus
 
+    def __set_root_gpu(self, gpus):
+        if gpus is None:
+            return None
+
+        # set root gpu
+        root_gpu = 0
+        if type(gpus) is list:
+            root_gpu = gpus[0]
+
+        return root_gpu
+
     @property
     def num_gpus(self):
         gpus = self.data_parallel_device_ids
@@ -701,10 +713,7 @@ def __single_gpu_train(self, model):
         # allow for lr schedulers as well
         self.optimizers, self.lr_schedulers = self.init_optimizers(model.configure_optimizers())
 
-        root_gpu = 0
-        if type(self.data_parallel_device_ids) is list:
-            root_gpu = self.data_parallel_device_ids[0]
-        model.cuda(root_gpu)
+        model.cuda(self.root_gpu)
 
         if self.use_amp:
             # An example
@@ -721,10 +730,7 @@ def __dp_train(self, model):
         # allow for lr schedulers as well
         self.optimizers, self.lr_schedulers = self.init_optimizers(model.configure_optimizers())
 
-        root_gpu = 0
-        if type(self.data_parallel_device_ids) is list:
-            root_gpu = self.data_parallel_device_ids[0]
-        model.cuda(root_gpu)
+        model.cuda(self.root_gpu)
 
         # check for this bug (amp + dp + !01 doesn't work)
         # https://github.com/NVIDIA/apex/issues/227
@@ -736,7 +742,12 @@ def __dp_train(self, model):
             """
             raise MisconfigurationException(m)
 
-        model = LightningDataParallel(model, device_ids=self.data_parallel_device_ids)
+        # create list of device ids
+        device_ids = self.data_parallel_device_ids
+        if type(device_ids) is int:
+            device_ids = list(range(device_ids))
+
+        model = LightningDataParallel(model, device_ids=device_ids)
 
         self.__run_pretrain_routine(model)
 
@@ -787,6 +798,9 @@ def ddp_train(self, gpu_nb, model):
         torch.cuda.set_device(gpu_nb)
         model.cuda(gpu_nb)
 
+        # override root GPU
+        self.root_gpu = gpu_nb
+
         # AMP
         # run through amp wrapper before going to distributed DP
         if self.use_amp:
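Note: the new device-selection logic above can be read in isolation. The sketch below is an illustration only, not code from this commit; resolve_root_gpu and resolve_device_ids are hypothetical names that mirror what __set_root_gpu and the device_ids expansion in __dp_train do.

def resolve_root_gpu(gpus):
    """Return the device id that should host the model, or None for CPU-only runs."""
    if gpus is None:
        return None
    if isinstance(gpus, list):
        # an explicit list of device ids: the first entry is the root device
        return gpus[0]
    # an int is a device count, so training uses devices 0..gpus-1 and 0 is the root
    return 0


def resolve_device_ids(gpus):
    """Expand an int device count into the explicit list DataParallel expects."""
    if isinstance(gpus, int):
        return list(range(gpus))
    return gpus


assert resolve_root_gpu(None) is None
assert resolve_root_gpu([2, 3]) == 2
assert resolve_root_gpu(4) == 0
assert resolve_device_ids(2) == [0, 1]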

pytorch_lightning/trainer/trainer_io.py

Lines changed: 33 additions & 18 deletions
@@ -1,6 +1,7 @@
 import os
 import re
 import signal
+import pdb
 from subprocess import call
 
 import torch
@@ -78,7 +79,7 @@ def register_slurm_signal_handlers(self):
         except Exception as e:
             pass
 
-        if on_slurm and self.proc_rank == 0:
+        if on_slurm:
             print('set slurm handle signals')
             signal.signal(signal.SIGUSR1, self.sig_handler)
             signal.signal(signal.SIGTERM, self.term_handler)
@@ -103,6 +104,9 @@ def sig_handler(self, signum, frame):
         else:
             print('requeue failed...')
 
+        # close experiment to avoid issues
+        self.experiment.close()
+
     def term_handler(self, signum, frame):
         # save
         print("bypassing sigterm")
@@ -118,19 +122,22 @@ def save_checkpoint(self, filepath):
 
     def restore(self, checkpoint_path, on_gpu):
 
-        if on_gpu:
-            checkpoint = torch.load(checkpoint_path)
-        else:
-            checkpoint = torch.load(checkpoint_path, map_location=lambda storage, loc: storage)
-
-        # load training state (affects trainer only)
-        self.restore_training_state(checkpoint)
+        # if on_gpu:
+        #     checkpoint = torch.load(checkpoint_path)
+        # else:
+        # load on CPU first
+        checkpoint = torch.load(checkpoint_path, map_location=lambda storage, loc: storage)
 
         # load model state
         model = self.__get_model()
 
         # load the state_dict on the model automatically
         model.load_state_dict(checkpoint['state_dict'])
+        if on_gpu:
+            model.cuda(self.root_gpu)
+
+        # load training state (affects trainer only)
+        self.restore_training_state(checkpoint)
 
     def dump_checkpoint(self):
 
@@ -210,6 +217,14 @@ def restore_training_state(self, checkpoint):
         for optimizer, opt_state in zip(self.optimizers, optimizer_states):
             optimizer.load_state_dict(opt_state)
 
+            # move optimizer to GPU 1 weight at a time
+            # avoids OOM
+            if self.root_gpu is not None:
+                for state in optimizer.state.values():
+                    for k, v in state.items():
+                        if isinstance(v, torch.Tensor):
+                            state[k] = v.cuda(self.root_gpu)
+
         # restore the lr schedulers
         lr_schedulers = checkpoint['lr_schedulers']
         for scheduler, lrs_state in zip(self.lr_schedulers, lr_schedulers):
@@ -225,9 +240,6 @@ def hpc_save(self, folderpath, experiment):
         # save exp to make sure we get all the metrics
         experiment.save()
 
-        # close experiment to avoid issues
-        experiment.close()
-
         ckpt_number = self.max_ckpt_in_folder(folderpath) + 1
 
         if not os.path.exists(folderpath):
@@ -248,23 +260,26 @@ def hpc_save(self, folderpath, experiment):
     def hpc_load(self, folderpath, on_gpu):
         filepath = '{}/hpc_ckpt_{}.ckpt'.format(folderpath, self.max_ckpt_in_folder(folderpath))
 
-        if on_gpu:
-            checkpoint = torch.load(filepath)
-        else:
-            checkpoint = torch.load(filepath, map_location=lambda storage, loc: storage)
-
-        # load training state (affects trainer only)
-        self.restore_training_state(checkpoint)
+        # load on CPU first
+        checkpoint = torch.load(filepath, map_location=lambda storage, loc: storage)
 
         # load model state
         model = self.__get_model()
 
         # load the state_dict on the model automatically
        model.load_state_dict(checkpoint['state_dict'])
 
+        if self.root_gpu is not None:
+            model.cuda(self.root_gpu)
+
+        # load training state (affects trainer only)
+        self.restore_training_state(checkpoint)
+
         # call model hook
         model.on_hpc_load(checkpoint)
 
+        print(f'restored hpc model from: {filepath}')
+
     def max_ckpt_in_folder(self, path, name_key='ckpt_'):
         files = os.listdir(path)
         files = [x for x in files if name_key in x]
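For context, the load-on-CPU-first pattern above can be reduced to a few lines. The sketch below is illustrative only, not the library's implementation: restore_cpu_first is a hypothetical helper and the 'optimizer_state' key is a placeholder for however the checkpoint stores optimizer state.

import torch


def restore_cpu_first(model, optimizer, checkpoint_path, root_gpu=None):
    # map every tensor to CPU first, so a checkpoint written on one GPU
    # can be restored in a process that owns a different device (or no GPU)
    checkpoint = torch.load(checkpoint_path, map_location=lambda storage, loc: storage)

    # restore weights on CPU, then move the model to the root GPU if there is one
    model.load_state_dict(checkpoint['state_dict'])
    if root_gpu is not None:
        model.cuda(root_gpu)

    # restore optimizer state, then move its tensors one at a time
    # instead of loading a second full copy straight onto the GPU (avoids OOM)
    optimizer.load_state_dict(checkpoint['optimizer_state'])
    if root_gpu is not None:
        for state in optimizer.state.values():
            for k, v in state.items():
                if isinstance(v, torch.Tensor):
                    state[k] = v.cuda(root_gpu)
    return checkpoint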

tests/debug.py

Lines changed: 55 additions & 17 deletions
@@ -214,10 +214,20 @@ def get_hparams(continue_training=False, hpc_exp_number=0):
 
 
 def main():
-    """Verify test() on fitted model"""
+    """
+    Make sure DDP + AMP continue training correctly
+    :return:
+    """
     hparams = get_hparams()
     model = LightningTestModel(hparams)
 
+    trainer_options = dict(
+        show_progress_bar=True,
+        max_nb_epochs=4,
+        gpus=2,
+        distributed_backend='dp',
+    )
+
     save_dir = init_save_dir()
 
     # exp file to get meta
@@ -228,31 +238,59 @@ def main():
     # exp file to get weights
     checkpoint = ModelCheckpoint(save_dir)
 
-    trainer_options = dict(
-        show_progress_bar=False,
-        max_nb_epochs=1,
-        train_percent_check=0.4,
-        val_percent_check=0.2,
-        checkpoint_callback=checkpoint,
-        experiment=exp,
-        gpus=[0, 1],
-        distributed_backend='ddp'
-    )
+    # add these to the trainer options
+    trainer_options['experiment'] = exp
+    trainer_options['checkpoint_callback'] = checkpoint
 
     # fit model
     trainer = Trainer(**trainer_options)
+    trainer.is_slurm_managing_tasks = True
     result = trainer.fit(model)
 
+    # track epoch before saving
+    real_global_epoch = trainer.current_epoch
+
     # correct result and ok accuracy
-    assert result == 1, 'training failed to complete'
-    pretrained_model = load_model(exp, save_dir, on_gpu=True, module_class=LightningTestModel)
+    assert result == 1, 'amp + dp model failed to complete'
 
+    # ---------------------------
+    # HPC LOAD/SAVE
+    # ---------------------------
+    # save
+    trainer.hpc_save(save_dir, exp)
+
+    # init new trainer
+    new_exp = get_exp(False, version=exp.version)
+    trainer_options['experiment'] = new_exp
+    trainer_options['checkpoint_callback'] = ModelCheckpoint(save_dir)
+    trainer_options['train_percent_check'] = 0.2
+    trainer_options['val_percent_check'] = 0.2
+    trainer_options['max_nb_epochs'] = 1
     new_trainer = Trainer(**trainer_options)
-    new_trainer.test(pretrained_model)
 
-    # test we have good test accuracy
-    assert_ok_test_acc(new_trainer)
-    # clear_save_dir()
+    # set the epoch start hook so we can predict before the model does the full training
+    def assert_good_acc():
+        assert trainer.current_epoch == real_global_epoch and trainer.current_epoch > 0
+
+        # if model and state loaded correctly, predictions will be good even though we
+        # haven't trained with the new loaded model
+        dp_model = new_trainer.model
+        dp_model.eval()
+
+        _ = [run_prediction(dataloader, dp_model, dp=True) for dataloader in trainer.val_dataloader]
+
+    # new model
+    model = LightningTestModel(hparams)
+    model.on_sanity_check_start = assert_good_acc
+
+    # fit new model which should load hpc weights
+    new_trainer.fit(model)
+
+    # test freeze on gpu
+    model.freeze()
+    model.unfreeze()
+
+    clear_save_dir()
 
 
 if __name__ == '__main__':
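Aside: the test relies on overriding a model hook (on_sanity_check_start) so assertions run before the resumed training does any new work. A minimal, self-contained sketch of that trick, with hypothetical TinyModel and FakeTrainer stand-ins:

class TinyModel:
    def on_sanity_check_start(self):
        # default hook: does nothing
        pass


class FakeTrainer:
    current_epoch = 4


def attach_restore_check(model, trainer, expected_epoch):
    def _check():
        # runs before training resumes; fails loudly if state was not restored
        assert trainer.current_epoch == expected_epoch and trainer.current_epoch > 0
    model.on_sanity_check_start = _check


model = TinyModel()
attach_restore_check(model, FakeTrainer(), expected_epoch=4)
model.on_sanity_check_start()  # passes only because current_epoch matches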
