1919from  lightning_utilities .test .warning  import  no_warning_call 
2020
2121from  lightning .fabric .utilities .warnings  import  PossibleUserWarning 
22- from  lightning .pytorch .callbacks  import  ModelCheckpoint 
22+ from  lightning .pytorch .callbacks  import  EarlyStopping ,  ModelCheckpoint 
2323from  lightning .pytorch .demos .boring_classes  import  BoringModel 
2424from  lightning .pytorch .trainer .trainer  import  Trainer 
2525
@@ -92,7 +92,16 @@ def test_should_stop_triggers_validation_once(min_epochs, min_steps, val_count,
9292    (min_epochs/steps is satisfied). 
9393
9494    """ 
95-     model  =  BoringModel ()
95+ 
96+     class  NewBoring (BoringModel ):
97+         def  training_step (self , batch , batch_idx ):
98+             self .log ("loss" , self .step (batch ))
99+             return  {"loss" : self .step (batch )}
100+ 
101+     model  =  NewBoring ()
102+     # create a stopping condition with a high threshold so it triggers immediately 
103+     # check the condition before validation so the count is unaffected 
104+     stopping  =  EarlyStopping (monitor = "loss" , check_on_train_epoch_end = True , stopping_threshold = 100 )
96105    trainer  =  Trainer (
97106        default_root_dir = tmp_path ,
98107        num_sanity_val_steps = 0 ,
@@ -103,8 +112,8 @@ def test_should_stop_triggers_validation_once(min_epochs, min_steps, val_count,
103112        min_steps = min_steps ,
104113        enable_model_summary = False ,
105114        enable_checkpointing = False ,
115+         callbacks = [stopping ],
106116    )
107-     trainer .should_stop  =  True   # Request to stop before min_epochs/min_steps are reached 
108117    trainer .fit_loop .epoch_loop .val_loop .run  =  Mock ()
109118    trainer .fit (model )
110119    assert  trainer .fit_loop .epoch_loop .val_loop .run .call_count  ==  val_count 
0 commit comments