File tree Expand file tree Collapse file tree 2 files changed +20
-0
lines changed Expand file tree Collapse file tree 2 files changed +20
-0
lines changed Original file line number Diff line number Diff line change @@ -53,6 +53,16 @@ def test_active_phase_into_phase(self) -> None:
5353 predict_phase = ActivePhase .PREDICT
5454 self .assertEqual (predict_phase .into_phase (), Phase .PREDICT )
5555
56+ def test_active_phase_str (self ) -> None :
57+ active_phase = ActivePhase .TRAIN
58+ self .assertEqual (str (active_phase ), "train" )
59+
60+ eval_phase = ActivePhase .EVALUATE
61+ self .assertEqual (str (eval_phase ), "eval" )
62+
63+ predict_phase = ActivePhase .PREDICT
64+ self .assertEqual (str (predict_phase ), "predict" )
65+
5666 def test_set_evaluate_every_n_steps_or_epochs (self ) -> None :
5767 state = PhaseState (dataloader = [], evaluate_every_n_steps = 2 )
5868 state .evaluate_every_n_steps = None
Original file line number Diff line number Diff line change @@ -74,6 +74,16 @@ def into_phase(self) -> Phase:
7474 else :
7575 raise AssertionError ("Should match an ActivePhase" )
7676
77+ def __str__ (self ) -> str :
78+ if self == ActivePhase .TRAIN :
79+ return "train"
80+ elif self == ActivePhase .EVALUATE :
81+ return "eval"
82+ elif self == ActivePhase .PREDICT :
83+ return "predict"
84+ else :
85+ raise AssertionError ("Should match an ActivePhase" )
86+
7787
7888class PhaseState (Generic [TData , TStepOutput ]):
7989 """State for each phase (train, eval, predict).
You can’t perform that action at this time.
0 commit comments