Commit 6a11284

Merge pull request #66 from williamFalcon/no_back
No back
2 parents b198435 + f3dea81 commit 6a11284

File tree

2 files changed: +71 -1 lines changed


pytorch_lightning/models/trainer.py

Lines changed: 43 additions & 1 deletion
@@ -105,7 +105,7 @@ def __init__(self,
         :param log_save_interval:
         :param add_log_row_interval:
         :param distributed_backend:
-            'np' to use DistributedParallel, 'dp' to use DistributedDataParallel
+            'do' to use DistributedParallel, 'dp' to use DistributedDataParallel, 'n' to use none
         :param use_amp:
         :param print_nan_grads:
         :param print_weights_summary:
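
For context, this is how a user would request the new single-GPU code path, based on the Trainer options used in the test added below (a sketch only; the import path and the full Trainer signature at this commit are assumptions):

from pytorch_lightning.models.trainer import Trainer  # assumed import path

# Passing exactly one GPU id triggers the new single_gpu branch added in
# this commit; the dp/ddp flags are cleared automatically in that case.
trainer = Trainer(
    progress_bar=True,
    max_nb_epochs=1,
    gpus=[0],
    use_amp=True,
)
# trainer.fit(model)  # dispatches to Trainer.__single_gpu_train(model)
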
@@ -147,6 +147,7 @@ def __init__(self,
         self.node_rank = 0
         self.use_ddp = False
         self.use_dp = False
+        self.single_gpu = False

         # training bookeeping
         self.total_batch_nb = 0
@@ -194,6 +195,12 @@ def __init__(self,
                 'To silence this warning set distributed_backend=ddp'
                 warnings.warn(w)

+        # remove dp and ddp when requesting single gpu
+        if self.data_parallel_device_ids is not None and len(self.data_parallel_device_ids) == 1:
+            self.use_ddp = False
+            self.use_dp = False
+            self.single_gpu = True
+
         # extract SLURM flag vars
         # whenever we have the correct number of tasks, we let slurm manage processes
         # otherwise we launch the required number of processes
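
To make the guard above concrete, here is the same selection logic as a standalone sketch (the helper name and signature are hypothetical, not part of this commit):

def resolve_backend(device_ids, use_ddp, use_dp):
    # With exactly one requested device id, both data-parallel flags are
    # cleared and the dedicated single-GPU path is used instead.
    single_gpu = False
    if device_ids is not None and len(device_ids) == 1:
        use_ddp = False
        use_dp = False
        single_gpu = True
    return use_ddp, use_dp, single_gpu

assert resolve_backend([0], use_ddp=False, use_dp=True) == (False, False, True)
assert resolve_backend([0, 1], use_ddp=False, use_dp=True) == (False, True, False)
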
@@ -385,6 +392,13 @@ def validate(self, model, dataloader, max_batches):
                 output = model(data_batch, batch_i)
                 output = reduce_distributed_output(output, len(self.data_parallel_device_ids))

+            elif self.single_gpu:
+                gpu_id = self.data_parallel_device_ids[0]
+                for i, x in enumerate(data_batch):
+                    if isinstance(x, torch.Tensor):
+                        data_batch[i] = x.cuda(gpu_id)
+                output = model.validation_step(data_batch, batch_i)
+
             else:
                 output = model.validation_step(data_batch, batch_i)
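
The transfer loop above only moves top-level torch.Tensor elements of the batch list onto the chosen device; anything else (ints, strings, nested containers) is left untouched. A standalone illustration (the helper name is hypothetical):

import torch

def move_batch_to_gpu(data_batch, gpu_id):
    # Mirrors the loop in validate(): replace top-level tensors in place,
    # leave non-tensor items (and tensors nested inside lists/dicts) alone.
    for i, x in enumerate(data_batch):
        if isinstance(x, torch.Tensor):
            data_batch[i] = x.cuda(gpu_id)
    return data_batch

# e.g. a (features, labels, batch_index) batch:
# move_batch_to_gpu([torch.randn(4, 3), torch.zeros(4, dtype=torch.long), 7], gpu_id=0)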

@@ -463,6 +477,9 @@ def fit(self, model):
         elif self.use_dp:
             self.__dp_train(model)

+        elif self.single_gpu:
+            self.__single_gpu_train(model)
+
         # ON CPU
         else:
             # run through amp wrapper
@@ -482,6 +499,24 @@ def fit(self, model):
         # used for testing or when we need to know that training succeeded
         return 1

+    def __single_gpu_train(self, model):
+        # CHOOSE OPTIMIZER
+        # allow for lr schedulers as well
+        self.optimizers = model.configure_optimizers()
+        if len(self.optimizers) == 2:
+            self.optimizers, self.lr_schedulers = self.optimizers
+
+        model.cuda(self.data_parallel_device_ids[0])
+
+        if self.use_amp:
+            # An example
+            model, optimizers = amp.initialize(
+                model, self.optimizers, opt_level=self.amp_level,
+            )
+            self.optimizers = optimizers
+
+        self.__run_pretrain_routine(model)
+
     def __dp_train(self, model):

         # CHOOSE OPTIMIZER
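
The AMP wrapping in __single_gpu_train relies on NVIDIA apex. Shown in isolation below as a sketch (assumes apex is installed; the model and optimizer are placeholders, and the opt_level would come from self.amp_level in the trainer):

import torch
from apex import amp  # NVIDIA apex must be installed

model = torch.nn.Linear(10, 2).cuda(0)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# amp.initialize returns the patched model and optimizer(s); both must be
# used afterwards in place of the originals.
model, optimizer = amp.initialize(model, optimizer, opt_level='O2')
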
@@ -814,6 +849,13 @@ def __run_tng_batch(self, data_batch, batch_nb):
         elif self.use_dp:
             output = self.model(data_batch, batch_nb)
             output = reduce_distributed_output(output, len(self.data_parallel_device_ids))
+        elif self.single_gpu:
+            gpu_id = self.data_parallel_device_ids[0]
+            for i, x in enumerate(data_batch):
+                if isinstance(x, torch.Tensor):
+                    data_batch[i] = x.cuda(gpu_id)
+            output = self.model.training_step(data_batch, batch_nb)
+
         else:
             output = self.model.training_step(data_batch, batch_nb)

tests/test_models.py

Lines changed: 28 additions & 0 deletions
@@ -27,6 +27,34 @@
 # TESTS
 # ------------------------------------------------------------------------

+def test_amp_single_gpu():
+    """
+    Make sure DDP + AMP work
+    :return:
+    """
+    if not torch.cuda.is_available():
+        warnings.warn('test_amp_gpu_ddp cannot run.'
+                      'Rerun on a GPU node to run this test')
+        return
+    if not torch.cuda.device_count() > 1:
+        warnings.warn('test_amp_gpu_ddp cannot run.'
+                      'Rerun on a node with 2+ GPUs to run this test')
+        return
+
+    hparams = get_hparams()
+    model = LightningTestModel(hparams)
+
+    trainer_options = dict(
+        progress_bar=True,
+        max_nb_epochs=1,
+        gpus=[0],
+        distributed_backend='dp',
+        use_amp=True
+    )
+
+    run_gpu_model_test(trainer_options, model, hparams)
+
+
 def test_cpu_restore_training():
     """
     Verify continue training session on CPU
