Skip to content

Commit aa7245d

Browse files
Load fix (#74)
* skip weight load without callback * added simple cpu test * fixed pep
1 parent a8abeaf commit aa7245d

File tree

2 files changed

+37
-0
lines changed

2 files changed

+37
-0
lines changed

pytorch_lightning/models/trainer.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -259,6 +259,11 @@ def restore_state_if_existing_checkpoint(self):
259259
last_epoch = -1
260260
last_ckpt_name = None
261261

262+
# do nothing if there's not dir or callback
263+
no_ckpt_callback = self.checkpoint_callback is None
264+
if no_ckpt_callback or not os.path.exists(self.checkpoint_callback.filepath):
265+
return
266+
262267
# find last epoch
263268
checkpoints = os.listdir(self.checkpoint_callback.filepath)
264269
for name in checkpoints:

tests/test_models.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,38 @@
2626
# ------------------------------------------------------------------------
2727
# TESTS
2828
# ------------------------------------------------------------------------
29+
def test_simple_cpu():
30+
"""
31+
Verify continue training session on CPU
32+
:return:
33+
"""
34+
hparams = get_hparams()
35+
model = LightningTestModel(hparams)
36+
37+
save_dir = init_save_dir()
38+
39+
# exp file to get meta
40+
test_exp_version = 10
41+
exp = get_exp(False, version=test_exp_version)
42+
exp.argparse(hparams)
43+
exp.save()
44+
45+
trainer_options = dict(
46+
max_nb_epochs=1,
47+
val_percent_check=0.1,
48+
train_percent_check=0.1,
49+
experiment=exp,
50+
)
51+
52+
# fit model
53+
trainer = Trainer(**trainer_options)
54+
result = trainer.fit(model)
55+
56+
# traning complete
57+
assert result == 1, 'amp + ddp model failed to complete'
58+
59+
clear_save_dir()
60+
2961

3062
def test_amp_single_gpu():
3163
"""

0 commit comments

Comments
 (0)