diff --git a/src/lightning/pytorch/loops/fit_loop.py b/src/lightning/pytorch/loops/fit_loop.py
index 8bb123939dc20..989fbb298e3dd 100644
--- a/src/lightning/pytorch/loops/fit_loop.py
+++ b/src/lightning/pytorch/loops/fit_loop.py
@@ -182,6 +182,11 @@ def done(self) -> bool:
             rank_zero_info(f"`Trainer.fit` stopped: `max_steps={self.max_steps!r}` reached.")
             return True
 
+        # Check early stopping before max_epochs to prioritize user-initiated stopping
+        if self.trainer.should_stop and self._can_stop_early:
+            rank_zero_debug("`Trainer.fit` stopped: `trainer.should_stop` was set.")
+            return True
+
         # `processed` is increased before `on_train_epoch_end`, the hook where checkpoints are typically saved.
         # we use it here because the checkpoint data won't have `completed` increased yet
         assert isinstance(self.max_epochs, int)
@@ -192,10 +197,6 @@ def done(self) -> bool:
             rank_zero_info(f"`Trainer.fit` stopped: `max_epochs={self.max_epochs!r}` reached.")
             return True
 
-        if self.trainer.should_stop and self._can_stop_early:
-            rank_zero_debug("`Trainer.fit` stopped: `trainer.should_stop` was set.")
-            return True
-
         return False
 
     @property
diff --git a/tests/tests_pytorch/callbacks/test_early_stopping.py b/tests/tests_pytorch/callbacks/test_early_stopping.py
index ff65e08c8c01a..0be089899737e 100644
--- a/tests/tests_pytorch/callbacks/test_early_stopping.py
+++ b/tests/tests_pytorch/callbacks/test_early_stopping.py
@@ -509,6 +509,51 @@ def test_early_stopping_log_info(log_rank_zero_only, world_size, global_rank, ex
         log_mock.assert_not_called()
 
 
+def test_early_stopping_message_priority_over_max_epochs(caplog):
+    """Test that early stopping message takes priority over max_epochs message when both conditions are met."""
+
+    class TestModel(BoringModel):
+        def __init__(self):
+            super().__init__()
+            self.validation_called = False
+
+        def validation_step(self, batch, batch_idx):
+            result = super().validation_step(batch, batch_idx)
+            self.log("monitor_metric", 1000.0 if not self.validation_called else 999.0)
+            self.validation_called = True
+            return result
+
+        def on_validation_end(self):
+            if self.current_epoch == 0:
+                self.trainer.should_stop = True
+
+    model = TestModel()
+    early_stopping = EarlyStopping(
+        monitor="monitor_metric",
+        mode="min",
+        patience=0,
+        verbose=True,
+    )
+
+    trainer = Trainer(
+        max_epochs=1,
+        callbacks=[early_stopping],
+        enable_progress_bar=False,
+        enable_model_summary=False,
+        logger=False,
+        limit_train_batches=1,
+        limit_val_batches=1,
+    )
+
+    with caplog.at_level(logging.DEBUG, logger="lightning.pytorch.utilities.rank_zero"):
+        trainer.fit(model)
+
+    assert "`Trainer.fit` stopped: `trainer.should_stop` was set." in caplog.text
+    assert trainer.should_stop is True
+    assert early_stopping.stopped_epoch >= 0
+    assert "`Trainer.fit` stopped: `max_epochs=1` reached." not in caplog.text
+
+
 class ModelWithHighLoss(BoringModel):
     def on_validation_epoch_end(self):
         self.log("val_loss", 10.0)
diff --git a/tests/tests_pytorch/loops/test_training_loop.py b/tests/tests_pytorch/loops/test_training_loop.py
index f5aaa18095fc5..a3eb76386b1a2 100644
--- a/tests/tests_pytorch/loops/test_training_loop.py
+++ b/tests/tests_pytorch/loops/test_training_loop.py
@@ -166,16 +166,19 @@ def test_fit_loop_done_log_messages(caplog):
     fit_loop.epoch_progress.current.processed = 3
     fit_loop.max_epochs = 3
     trainer.should_stop = True
-    assert fit_loop.done
-    assert "max_epochs=3` reached" in caplog.text
-    caplog.clear()
-    fit_loop.max_epochs = 5
-    fit_loop.epoch_loop.min_steps = 0
     with caplog.at_level(level=logging.DEBUG, logger="lightning.pytorch.utilities.rank_zero"):
         assert fit_loop.done
     assert "should_stop` was set" in caplog.text
+    caplog.clear()
+    trainer.should_stop = False
+    assert fit_loop.done
+    assert "max_epochs=3` reached" in caplog.text
+    caplog.clear()
+    fit_loop.max_epochs = 5
+
+    trainer.should_stop = True
 
     fit_loop.epoch_loop.min_steps = 100
     assert not fit_loop.done
 
 
@@ -183,7 +186,7 @@ def test_fit_loop_done_log_messages(caplog):
 @pytest.mark.parametrize(
     ("min_epochs", "min_steps", "current_epoch", "early_stop", "fit_loop_done", "raise_debug_msg"),
     [
-        (4, None, 100, True, True, False),
+        (4, None, 100, True, True, True),
         (4, None, 3, False, False, False),
         (4, 10, 3, False, False, False),
         (None, 10, 4, True, True, True),
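
Note (not part of the patch): the reordering in `fit_loop.py` simply moves the `trainer.should_stop` check ahead of the `max_epochs` check inside `done`, so when both conditions are true in the same epoch the "should_stop" reason is the one that gets logged. A minimal standalone sketch of that priority rule, using hypothetical flags (`should_stop`, `can_stop_early`, `reached_max_epochs`) rather than Lightning's real loop state:

# Illustrative sketch only; condenses the reordered `done` logic from the patch.
# The parameters are hypothetical inputs, not Lightning attributes.
def stop_reason(should_stop, can_stop_early, reached_max_epochs):
    # A user- or callback-initiated stop is reported first ...
    if should_stop and can_stop_early:
        return "`trainer.should_stop` was set."
    # ... and only afterwards the max_epochs limit.
    if reached_max_epochs:
        return "max_epochs reached."
    return None

# Both conditions true: the should_stop reason wins, matching the new test's expectation.
assert stop_reason(True, True, True) == "`trainer.should_stop` was set."
assert stop_reason(False, True, True) == "max_epochs reached."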