1010
1111from torchtnt .framework .callback import Callback
1212from torchtnt .framework .state import State
13- from torchtnt .framework .unit import AppStateMixin , TTrainUnit
13+ from torchtnt .framework .unit import AppStateMixin , TEvalUnit , TTrainUnit
1414from torchtnt .utils .distributed import get_global_rank , sync_bool
1515from torchtnt .utils .early_stop_checker import EarlyStopChecker
1616
@@ -23,6 +23,7 @@ class EarlyStopping(Callback):
2323 monitored_attr: The attribute to monitor on the unit. Must be a float or tensor attribute on the unit.
2424 early_stop_checker: a :class:`~torchtnt.utils.early_stop_checker.EarlyStopChecker` to use for checking whether to stop early.
2525 interval: The interval to check the monitored attribute. Must be one of "step" or "epoch".
26+ phase: The phase to check the monitored attribute. Must be one of "train" or "eval".
2627
2728 Note:
2829 If doing distributed training, this callback checks the metric value only on rank 0
@@ -33,29 +34,49 @@ def __init__(
3334 monitored_attr : str ,
3435 early_stop_checker : EarlyStopChecker ,
3536 interval : Literal ["step" , "epoch" ] = "epoch" ,
37+ phase : Literal ["train" , "eval" ] = "train" ,
3638 interval_freq : int = 1 ,
3739 ) -> None :
3840 self ._monitored_attr = monitored_attr
3941 self ._esc = early_stop_checker
4042 self ._interval = interval
4143 self ._interval_freq = interval_freq
44+ self ._phase = phase
4245
4346 self ._rank : int = get_global_rank ()
4447
def on_train_step_end(self, state: State, unit: TTrainUnit) -> None:
    """Run the early-stop check at the end of a train step, if configured for it.

    The check is delegated to ``self._maybe_stop`` only when this callback
    monitors the "train" phase at "step" granularity and the number of
    completed train steps is a multiple of ``self._interval_freq``.

    Args:
        state: the current run state, forwarded unchanged to ``_maybe_stop``.
        unit: the train unit; its ``train_progress.num_steps_completed``
            drives the frequency check.
    """
    # Configuration checks come first so the progress counter is only read
    # when the callback is actually set up for train/step checks.
    if (
        self._phase == "train"
        and self._interval == "step"
        and unit.train_progress.num_steps_completed % self._interval_freq == 0
    ):
        self._maybe_stop(state, unit)
5155
def on_train_epoch_end(self, state: State, unit: TTrainUnit) -> None:
    """Run the early-stop check at the end of a train epoch, if configured for it.

    The check is delegated to ``self._maybe_stop`` only when this callback
    monitors the "train" phase at "epoch" granularity and the number of
    completed train epochs is a multiple of ``self._interval_freq``.

    Args:
        state: the current run state, forwarded unchanged to ``_maybe_stop``.
        unit: the train unit; its ``train_progress.num_epochs_completed``
            drives the frequency check.
    """
    # Configuration checks come first so the progress counter is only read
    # when the callback is actually set up for train/epoch checks.
    if (
        self._phase == "train"
        and self._interval == "epoch"
        and unit.train_progress.num_epochs_completed % self._interval_freq == 0
    ):
        self._maybe_stop(state, unit)
5863
def on_eval_step_end(self, state: State, unit: TEvalUnit) -> None:
    """Run the early-stop check at the end of an eval step, if configured for it.

    The check is delegated to ``self._maybe_stop`` only when this callback
    monitors the "eval" phase at "step" granularity and the number of
    completed eval steps is a multiple of ``self._interval_freq``.

    Args:
        state: the current run state, forwarded unchanged to ``_maybe_stop``.
        unit: the eval unit; its ``eval_progress.num_steps_completed``
            drives the frequency check.
    """
    # Configuration checks come first so the progress counter is only read
    # when the callback is actually set up for eval/step checks.
    if (
        self._phase == "eval"
        and self._interval == "step"
        and unit.eval_progress.num_steps_completed % self._interval_freq == 0
    ):
        self._maybe_stop(state, unit)
71+
def on_eval_epoch_end(self, state: State, unit: TEvalUnit) -> None:
    """Run the early-stop check at the end of an eval epoch, if configured for it.

    The check is delegated to ``self._maybe_stop`` only when this callback
    monitors the "eval" phase at "epoch" granularity and the number of
    completed eval epochs is a multiple of ``self._interval_freq``.

    Args:
        state: the current run state, forwarded unchanged to ``_maybe_stop``.
        unit: the eval unit; its ``eval_progress.num_epochs_completed``
            drives the frequency check.
    """
    # Configuration checks come first so the progress counter is only read
    # when the callback is actually set up for eval/epoch checks.
    if (
        self._phase == "eval"
        and self._interval == "epoch"
        and unit.eval_progress.num_epochs_completed % self._interval_freq == 0
    ):
        self._maybe_stop(state, unit)
79+
5980 def _maybe_stop (self , state : State , unit : AppStateMixin ) -> None :
6081 """
6182 Checks whether to stop early based on the monitored attribute.
0 commit comments