diff --git a/src/lightning/pytorch/trainer/trainer.py b/src/lightning/pytorch/trainer/trainer.py index 6436fc54b7bed..291daba7b4dc8 100644 --- a/src/lightning/pytorch/trainer/trainer.py +++ b/src/lightning/pytorch/trainer/trainer.py @@ -541,6 +541,7 @@ def fit( self.state.fn = TrainerFn.FITTING self.state.status = TrainerStatus.RUNNING self.training = True + self.should_stop = False call._call_and_handle_interrupt( self, self._fit_impl, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path ) diff --git a/tests/tests_pytorch/loops/test_training_epoch_loop.py b/tests/tests_pytorch/loops/test_training_epoch_loop.py index 16ed3842e3a96..ffb0044cfdc54 100644 --- a/tests/tests_pytorch/loops/test_training_epoch_loop.py +++ b/tests/tests_pytorch/loops/test_training_epoch_loop.py @@ -17,7 +17,7 @@ import pytest import torch from lightning.fabric.utilities.warnings import PossibleUserWarning -from lightning.pytorch.callbacks import ModelCheckpoint +from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint from lightning.pytorch.demos.boring_classes import BoringModel from lightning.pytorch.trainer.trainer import Trainer from lightning_utilities.test.warning import no_warning_call @@ -90,7 +90,16 @@ def test_should_stop_triggers_validation_once(min_epochs, min_steps, val_count, (min_epochs/steps is satisfied). """ - model = BoringModel() + + class NewBoring(BoringModel): + def training_step(self, batch, batch_idx): + self.log("loss", self.step(batch)) + return {"loss": self.step(batch)} + + model = NewBoring() + # create a stopping condition with a high threshold so it triggers immediately + # check the condition before validation so the count is unaffected + stopping = EarlyStopping(monitor="loss", check_on_train_epoch_end=True, stopping_threshold=100) trainer = Trainer( default_root_dir=tmp_path, num_sanity_val_steps=0, @@ -101,8 +110,8 @@ def test_should_stop_triggers_validation_once(min_epochs, min_steps, val_count, min_steps=min_steps, enable_model_summary=False, enable_checkpointing=False, + callbacks=[stopping], ) - trainer.should_stop = True # Request to stop before min_epochs/min_steps are reached trainer.fit_loop.epoch_loop.val_loop.run = Mock() trainer.fit(model) assert trainer.fit_loop.epoch_loop.val_loop.run.call_count == val_count