Commit 17d50a2

committed
add testing
1 parent 46d8fa2 commit 17d50a2


1 file changed: +189 -0 lines changed


tests/tests_pytorch/callbacks/test_early_stopping.py

Lines changed: 189 additions & 0 deletions
@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import copy
 import logging
 import math
 import os
@@ -25,6 +26,7 @@
 
 from lightning.pytorch import Trainer, seed_everything
 from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint
+from lightning.pytorch.callbacks.early_stopping import EarlyStoppingReason
 from lightning.pytorch.demos.boring_classes import BoringModel
 from lightning.pytorch.utilities.exceptions import MisconfigurationException
 from tests_pytorch.helpers.datamodules import ClassifDataModule
@@ -505,3 +507,190 @@ def test_early_stopping_log_info(log_rank_zero_only, world_size, global_rank, ex
         log_mock.assert_called_once_with(expected_log)
     else:
         log_mock.assert_not_called()
+
+
+class ModelWithHighLoss(BoringModel):
+    def on_validation_epoch_end(self):
+        self.log("val_loss", 10.0)
+
+
+class ModelWithDecreasingLoss(BoringModel):
+    def __init__(self):
+        super().__init__()
+        self.epoch_losses = [5.0, 3.0, 1.0, 0.5]
+
+    def on_validation_epoch_end(self):
+        loss = self.epoch_losses[self.current_epoch] if self.current_epoch < len(self.epoch_losses) else 0.1
+        self.log("val_loss", loss)
+
+
+class ModelWithIncreasingLoss(BoringModel):
+    def __init__(self):
+        super().__init__()
+        self.epoch_losses = [1.0, 2.0, 5.0, 10.0]
+
+    def on_validation_epoch_end(self):
+        loss = self.epoch_losses[self.current_epoch] if self.current_epoch < len(self.epoch_losses) else 15.0
+        self.log("val_loss", loss)
+
+
+class ModelWithNaNLoss(BoringModel):
+    def __init__(self):
+        super().__init__()
+        self.epoch_losses = [1.0, 0.5, float("nan")]
+
+    def on_validation_epoch_end(self):
+        loss = self.epoch_losses[self.current_epoch] if self.current_epoch < len(self.epoch_losses) else float("nan")
+        self.log("val_loss", loss)
+
+
+class ModelWithImprovingLoss(BoringModel):
+    def __init__(self):
+        super().__init__()
+        self.epoch_losses = [5.0, 4.0, 3.0, 2.0, 1.0]
+
+    def on_validation_epoch_end(self):
+        loss = self.epoch_losses[self.current_epoch] if self.current_epoch < len(self.epoch_losses) else 0.1
+        self.log("val_loss", loss)
+
+
+@pytest.mark.parametrize(
+    (
+        "model_cls",
+        "early_stopping_kwargs",
+        "trainer_kwargs",
+        "expected_reason",
+        "reason_message_substr",
+        "should_stop",
+        "state_dict_override",
+    ),
+    [
+        # Patience exhausted
+        (
+            ModelWithHighLoss,
+            {"monitor": "val_loss", "patience": 2, "verbose": True},
+            {"max_epochs": 10, "enable_progress_bar": False},
+            EarlyStoppingReason.PATIENCE_EXHAUSTED,
+            "did not improve",
+            True,
+            None,
+        ),
+        # Stopping threshold
+        (
+            ModelWithDecreasingLoss,
+            {"monitor": "val_loss", "stopping_threshold": 0.6, "mode": "min", "verbose": True},
+            {"max_epochs": 10, "enable_progress_bar": False},
+            EarlyStoppingReason.STOPPING_THRESHOLD,
+            "Stopping threshold reached",
+            True,
+            None,
+        ),
+        # Divergence threshold
+        (
+            ModelWithIncreasingLoss,
+            {"monitor": "val_loss", "divergence_threshold": 8.0, "mode": "min", "verbose": True},
+            {"max_epochs": 10, "enable_progress_bar": False},
+            EarlyStoppingReason.DIVERGENCE_THRESHOLD,
+            "Divergence threshold reached",
+            True,
+            None,
+        ),
+        # Non-finite metric
+        (
+            ModelWithNaNLoss,
+            {"monitor": "val_loss", "check_finite": True, "verbose": True},
+            {"max_epochs": 10, "enable_progress_bar": False},
+            EarlyStoppingReason.NON_FINITE_METRIC,
+            "is not finite",
+            True,
+            None,
+        ),
+        # Not stopped (normal completion)
+        (
+            ModelWithImprovingLoss,
+            {"monitor": "val_loss", "patience": 3, "verbose": True},
+            {"max_epochs": 3, "enable_progress_bar": False},
+            EarlyStoppingReason.NOT_STOPPED,
+            None,
+            False,
+            None,
+        ),
+        # State persistence
+        (
+            None,
+            {"monitor": "val_loss", "patience": 3},
+            {},
+            EarlyStoppingReason.PATIENCE_EXHAUSTED,
+            "Test message",
+            None,
+            {"stopping_reason": EarlyStoppingReason.PATIENCE_EXHAUSTED, "stopping_reason_message": "Test message"},
+        ),
+        # Backward compatibility (old state dict)
+        (
+            None,
+            {"monitor": "val_loss", "patience": 3},
+            {},
+            EarlyStoppingReason.NOT_STOPPED,
+            None,
+            None,
+            {
+                "wait_count": 2,
+                "stopped_epoch": 5,
+                "best_score": torch.tensor(0.5),
+                "patience": 3,
+            },
+        ),
+    ],
+)
+def test_early_stopping_reasons(
+    tmp_path,
+    model_cls,
+    early_stopping_kwargs,
+    trainer_kwargs,
+    expected_reason,
+    reason_message_substr,
+    should_stop,
+    state_dict_override,
+):
+    """Test all early stopping reasons in a single parametrized test."""
+    if state_dict_override is not None:
+        early_stopping = EarlyStopping(**early_stopping_kwargs)
+        if "stopping_reason" in state_dict_override:
+            # State persistence test
+            early_stopping.stopping_reason = state_dict_override["stopping_reason"]
+            early_stopping.stopping_reason_message = state_dict_override["stopping_reason_message"]
+            state_dict = early_stopping.state_dict()
+            new_early_stopping = EarlyStopping(**early_stopping_kwargs)
+            new_early_stopping.load_state_dict(state_dict)
+            assert new_early_stopping.stopping_reason == expected_reason
+            assert new_early_stopping.stopping_reason_message == reason_message_substr
+        else:
+            # Backward compatibility test
+            early_stopping.load_state_dict(copy.deepcopy(state_dict_override))
+            assert early_stopping.stopping_reason == expected_reason
+            assert early_stopping.stopping_reason_message is None
+            assert early_stopping.wait_count == state_dict_override["wait_count"]
+            assert early_stopping.stopped_epoch == state_dict_override["stopped_epoch"]
+        return
+
+    # All other tests
+    model = model_cls()
+    early_stopping = EarlyStopping(**early_stopping_kwargs)
+    trainer = Trainer(
+        default_root_dir=tmp_path,
+        callbacks=[early_stopping],
+        **trainer_kwargs,
+    )
+    trainer.fit(model)
+
+    assert early_stopping.stopping_reason == expected_reason
+    if reason_message_substr is not None:
+        assert early_stopping.stopping_reason_message is not None
+        assert reason_message_substr in early_stopping.stopping_reason_message
+    else:
+        assert early_stopping.stopping_reason_message is None
+    if should_stop is not None:
+        if should_stop:
+            assert early_stopping.stopped_epoch > 0
+        else:
+            assert early_stopping.stopped_epoch == 0
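
For context, a minimal sketch of the usage pattern these tests assert: after trainer.fit, the callback exposes why (or whether) it stopped training. The sketch is not part of the diff above and assumes the EarlyStoppingReason enum and the stopping_reason / stopping_reason_message attributes behave exactly as asserted by the tests; ConstantLossModel is a hypothetical stand-in for the ModelWithHighLoss helper.

# Minimal sketch, not part of the diff above; assumes the EarlyStoppingReason API asserted by the tests.
from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import EarlyStopping
from lightning.pytorch.callbacks.early_stopping import EarlyStoppingReason
from lightning.pytorch.demos.boring_classes import BoringModel


class ConstantLossModel(BoringModel):
    # Hypothetical helper mirroring ModelWithHighLoss: a flat val_loss so patience runs out.
    def on_validation_epoch_end(self):
        self.log("val_loss", 10.0)


early_stopping = EarlyStopping(monitor="val_loss", patience=2, verbose=True)
trainer = Trainer(max_epochs=10, callbacks=[early_stopping], enable_progress_bar=False)
trainer.fit(ConstantLossModel())

# The callback records why (or whether) training was stopped early.
if early_stopping.stopping_reason == EarlyStoppingReason.NOT_STOPPED:
    print("Training ran to completion.")
else:
    print(f"Stopped early: {early_stopping.stopping_reason} - {early_stopping.stopping_reason_message}")

# The state-persistence case above suggests the reason survives a state_dict round trip.
restored = EarlyStopping(monitor="val_loss", patience=2)
restored.load_state_dict(early_stopping.state_dict())
assert restored.stopping_reason == early_stopping.stopping_reason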

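To run just the new test locally, an invocation along the lines of "python -m pytest tests/tests_pytorch/callbacks/test_early_stopping.py -k test_early_stopping_reasons" should select it.
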
0 commit comments