Skip to content

Commit c7e8436

Browse files
committed
added single gpu train test
1 parent afa4548 commit c7e8436

File tree

1 file changed

+28
-0
lines changed

1 file changed

+28
-0
lines changed

tests/test_models.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,34 @@
2727
# TESTS
2828
# ------------------------------------------------------------------------
2929

30+
def test_amp_single_gpu():
31+
"""
32+
Make sure DDP + AMP work
33+
:return:
34+
"""
35+
if not torch.cuda.is_available():
36+
warnings.warn('test_amp_gpu_ddp cannot run.'
37+
'Rerun on a GPU node to run this test')
38+
return
39+
if not torch.cuda.device_count() > 1:
40+
warnings.warn('test_amp_gpu_ddp cannot run.'
41+
'Rerun on a node with 2+ GPUs to run this test')
42+
return
43+
44+
hparams = get_hparams()
45+
model = LightningTestModel(hparams)
46+
47+
trainer_options = dict(
48+
progress_bar=True,
49+
max_nb_epochs=1,
50+
gpus=[0],
51+
distributed_backend='dp',
52+
use_amp=True
53+
)
54+
55+
run_gpu_model_test(trainer_options, model, hparams)
56+
57+
3058
def test_cpu_restore_training():
3159
"""
3260
Verify continue training session on CPU

0 commit comments

Comments
 (0)