@@ -399,7 +399,7 @@ def test_model_freeze_unfreeze():
        assert param.requires_grad


-@pytest.mark.xfail(reason="FIXME(@carmocca): this test wasn't running and is now broken")
+# TODO: move to `test/models/test_restore.py`
@pytest.mark.parametrize("url_ckpt", [True, False])
def test_fit_ckpt_path_epoch_restored(monkeypatch, tmpdir, tmpdir_server, url_ckpt):
    """Verify resuming from checkpoint runs the right number of epochs."""
@@ -422,11 +422,12 @@ def on_load_checkpoint(self, _):
            self.num_on_load_checkpoint_called += 1

    model = TestModel()
+    max_epochs = 2
    trainer = Trainer(
-        max_epochs=2,
+        max_epochs=max_epochs,
        limit_train_batches=0.65,
        limit_val_batches=1,
-        callbacks=[ModelCheckpoint(dirpath=tmpdir, save_top_k=-1)],
+        callbacks=ModelCheckpoint(dirpath=tmpdir, save_top_k=-1),
        default_root_dir=tmpdir,
        val_check_interval=1.0,
        enable_progress_bar=False,
@@ -435,27 +436,25 @@ def on_load_checkpoint(self, _):
    )
    trainer.fit(model)

-    assert model.num_epochs_end_seen == 2
-    assert model.num_batches_seen == trainer.num_training_batches * 2
+    assert model.num_epochs_end_seen == max_epochs
+    assert model.num_batches_seen == trainer.num_training_batches * max_epochs == trainer.global_step
    assert model.num_on_load_checkpoint_called == 0

-    # Other checkpoints can be uncommented if/when resuming mid-epoch is supported
-    checkpoints = Path(trainer.checkpoint_callback.dirpath).glob("*.ckpt")
+    checkpoints = set(Path(trainer.checkpoint_callback.dirpath).glob("*.ckpt"))
    if url_ckpt:
        # transform local paths into url checkpoints
        ip, port = tmpdir_server
        checkpoints = [f"http://{ip}:{port}/" + ckpt.name for ckpt in checkpoints]

-    assert checkpoints
+    assert len(checkpoints) == max_epochs
    for ckpt in checkpoints:
-        next_model = TestModel()
+        model = TestModel()
        state = pl_load(ckpt)
-
        # Resume training
-        new_trainer = Trainer(default_root_dir=tmpdir, max_epochs=2)
-        new_trainer.fit(next_model, ckpt_path=ckpt)
-        assert state["global_step"] + next_model.num_batches_seen == trainer.num_training_batches * trainer.max_epochs
-        assert next_model.num_on_load_checkpoint_called == 1
+        trainer = Trainer(default_root_dir=tmpdir, max_epochs=2, enable_progress_bar=False)
+        trainer.fit(model, ckpt_path=ckpt)
+        assert state["global_step"] + model.num_batches_seen == trainer.global_step
+        assert model.num_on_load_checkpoint_called == 1


def test_trainer_max_steps_and_epochs(tmpdir):
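
For context, below is a minimal, self-contained sketch of the checkpoint-resume pattern this test exercises: keep one checkpoint per epoch with ModelCheckpoint(save_top_k=-1) (which is why the updated assert expects len(checkpoints) == max_epochs), then resume training by passing ckpt_path to Trainer.fit. The DemoModel class, the "checkpoints/" directory, and the max_epochs values are illustrative assumptions for this sketch and are not part of the diff; the test itself uses its own TestModel helper.

# Illustrative sketch only; DemoModel and the paths below are assumptions, not from the diff.
from pathlib import Path

import torch
from torch.utils.data import DataLoader, TensorDataset

from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.callbacks import ModelCheckpoint


class DemoModel(LightningModule):
    # A tiny stand-in model so the sketch runs end to end.
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        # Any scalar tensor works as a loss for this demo.
        return self.layer(batch[0]).sum()

    def train_dataloader(self):
        return DataLoader(TensorDataset(torch.randn(64, 32)), batch_size=8)

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)


if __name__ == "__main__":
    # Keep every epoch-end checkpoint instead of only the "best" one.
    checkpoint_cb = ModelCheckpoint(dirpath="checkpoints/", save_top_k=-1)
    trainer = Trainer(max_epochs=2, callbacks=checkpoint_cb, enable_progress_bar=False)
    trainer.fit(DemoModel())

    # Resuming restores epoch, global_step, optimizer and callback state
    # before training continues up to max_epochs.
    for ckpt in Path(checkpoint_cb.dirpath).glob("*.ckpt"):
        resumed_trainer = Trainer(max_epochs=3, enable_progress_bar=False)
        resumed_trainer.fit(DemoModel(), ckpt_path=str(ckpt))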