19
19
from lightning_utilities .test .warning import no_warning_call
20
20
21
21
from lightning .fabric .utilities .warnings import PossibleUserWarning
22
- from lightning .pytorch .callbacks import ModelCheckpoint
22
+ from lightning .pytorch .callbacks import EarlyStopping , ModelCheckpoint
23
23
from lightning .pytorch .demos .boring_classes import BoringModel
24
24
from lightning .pytorch .trainer .trainer import Trainer
25
25
@@ -92,7 +92,16 @@ def test_should_stop_triggers_validation_once(min_epochs, min_steps, val_count,
92
92
(min_epochs/steps is satisfied).
93
93
94
94
"""
95
- model = BoringModel ()
95
+
96
+ class NewBoring (BoringModel ):
97
+ def training_step (self , batch , batch_idx ):
98
+ self .log ("loss" , self .step (batch ))
99
+ return {"loss" : self .step (batch )}
100
+
101
+ model = NewBoring ()
102
+ # create a stopping condition with a high threshold so it triggers immediately
103
+ # check the condition before validation so the count is unaffected
104
+ stopping = EarlyStopping (monitor = "loss" , check_on_train_epoch_end = True , stopping_threshold = 100 )
96
105
trainer = Trainer (
97
106
default_root_dir = tmp_path ,
98
107
num_sanity_val_steps = 0 ,
@@ -103,8 +112,8 @@ def test_should_stop_triggers_validation_once(min_epochs, min_steps, val_count,
103
112
min_steps = min_steps ,
104
113
enable_model_summary = False ,
105
114
enable_checkpointing = False ,
115
+ callbacks = [stopping ],
106
116
)
107
- trainer .should_stop = True # Request to stop before min_epochs/min_steps are reached
108
117
trainer .fit_loop .epoch_loop .val_loop .run = Mock ()
109
118
trainer .fit (model )
110
119
assert trainer .fit_loop .epoch_loop .val_loop .run .call_count == val_count
0 commit comments