Skip to content

Commit 10d18d9

Browse files
committed
test: Check lr is updated at the end of epoch
When `on_train_batch_start` returns -1, the rest of the epoch is skipped. The lr update should still happen at the end of the epoch. This commit tests that `is_last_batch` has been set correctly, and that the lr has been updated at the end of each epoch.
1 parent 60a7cd3 commit 10d18d9

File tree

1 file changed

+25
-0
lines changed

1 file changed

+25
-0
lines changed

tests/tests_pytorch/loops/test_training_loop.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,8 @@ def on_train_batch_start(self, batch, batch_idx):
111111
assert trainer.fit_loop.batch_idx == batch_idx_
112112
assert trainer.global_step == batch_idx_ * max_epochs
113113

114+
assert trainer.is_last_batch
115+
114116

115117
def test_should_stop_mid_epoch(tmp_path):
116118
"""Test that training correctly stops mid epoch and that validation is still called at the right time."""
@@ -305,3 +307,26 @@ def test_eval_mode_warning(tmp_path, warn):
305307
w for w in warning_list if issubclass(w.category, PossibleUserWarning) and "eval mode" in str(w.message)
306308
]
307309
assert len(eval_warnings) == 0, "Expected no eval mode warnings"
310+
311+
312+
@pytest.mark.parametrize(("max_epochs", "batch_idx_"), [(2, 5), (3, 8)])
def test_lr_updated_on_train_batch_start_returns_minus_one(tmp_path, max_epochs, batch_idx_):
    """Test that when the rest of the epoch is skipped, due to on_train_batch_start returning -1, the learning rate is
    still updated when it should, at the end of the epoch."""

    class TestModel(BoringModel):
        def on_train_batch_start(self, batch, batch_idx):
            # Returning -1 from this hook aborts the remainder of the current epoch.
            if batch_idx == batch_idx_:
                return -1
            return super().on_train_batch_start(batch, batch_idx)

    model = TestModel()
    # Assumed initial lr of BoringModel's optimizer, decayed by gamma=0.1 once
    # per epoch by its scheduler — TODO confirm against BoringModel.configure_optimizers.
    init_lr = 0.1
    trainer = Trainer(default_root_dir=tmp_path, limit_train_batches=10, max_epochs=max_epochs)
    trainer.fit(model)

    adjusted_lr = [pg["lr"] for pg in trainer.optimizers[0].param_groups]

    assert len(trainer.lr_scheduler_configs) == 1
    # Every param group must have received the same adjusted lr.
    assert all(a == adjusted_lr[0] for a in adjusted_lr)
    # The scheduler stepped once per epoch even though each epoch was cut short.
    # Compare with a tolerance: the scheduler multiplies the lr by gamma once per
    # epoch, and repeated float multiplication need not bit-match a single
    # `init_lr * 0.1**max_epochs` power — exact `==` here is float-fragile.
    assert adjusted_lr[0] == pytest.approx(init_lr * 0.1**max_epochs)

0 commit comments

Comments
 (0)