Skip to content

Commit 40c682e

Browse files
Reset trainer variable should_stop when fit is called (#19177)
--------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 96df8c4 commit 40c682e

File tree

2 files changed

+13
-3
lines changed

2 files changed

+13
-3
lines changed

src/lightning/pytorch/trainer/trainer.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -542,6 +542,7 @@ def fit(
542542
self.state.fn = TrainerFn.FITTING
543543
self.state.status = TrainerStatus.RUNNING
544544
self.training = True
545+
self.should_stop = False
545546
call._call_and_handle_interrupt(
546547
self, self._fit_impl, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path
547548
)

tests/tests_pytorch/loops/test_training_epoch_loop.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
from lightning_utilities.test.warning import no_warning_call
2020

2121
from lightning.fabric.utilities.warnings import PossibleUserWarning
22-
from lightning.pytorch.callbacks import ModelCheckpoint
22+
from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint
2323
from lightning.pytorch.demos.boring_classes import BoringModel
2424
from lightning.pytorch.trainer.trainer import Trainer
2525

@@ -92,7 +92,16 @@ def test_should_stop_triggers_validation_once(min_epochs, min_steps, val_count,
9292
(min_epochs/steps is satisfied).
9393
9494
"""
95-
model = BoringModel()
95+
96+
class NewBoring(BoringModel):
97+
def training_step(self, batch, batch_idx):
98+
self.log("loss", self.step(batch))
99+
return {"loss": self.step(batch)}
100+
101+
model = NewBoring()
102+
# create a stopping condition with a high threshold so it triggers immediately
103+
# check the condition before validation so the count is unaffected
104+
stopping = EarlyStopping(monitor="loss", check_on_train_epoch_end=True, stopping_threshold=100)
96105
trainer = Trainer(
97106
default_root_dir=tmp_path,
98107
num_sanity_val_steps=0,
@@ -103,8 +112,8 @@ def test_should_stop_triggers_validation_once(min_epochs, min_steps, val_count,
103112
min_steps=min_steps,
104113
enable_model_summary=False,
105114
enable_checkpointing=False,
115+
callbacks=[stopping],
106116
)
107-
trainer.should_stop = True # Request to stop before min_epochs/min_steps are reached
108117
trainer.fit_loop.epoch_loop.val_loop.run = Mock()
109118
trainer.fit(model)
110119
assert trainer.fit_loop.epoch_loop.val_loop.run.call_count == val_count

0 commit comments

Comments
 (0)