1111
1212from torchtnt .framework .callbacks .checkpointer_types import RestoreOptions
1313from torchtnt .framework .state import EntryPoint , State
14- from torchtnt .framework .unit import AppStateMixin , TEvalUnit , TTrainUnit
14+ from torchtnt .framework .unit import AppStateMixin , TEvalUnit , TPredictUnit , TTrainUnit
1515from torchtnt .utils .checkpoint import Phase
1616
1717from torchtnt .utils .stateful import Stateful
3131
3232
3333def _get_step_phase_mapping (
34- state : State , unit : Union [TTrainUnit , TEvalUnit ]
34+ state : State , unit : Union [TTrainUnit , TEvalUnit , TPredictUnit ]
3535) -> Dict [Phase , int ]:
3636 """
3737 Returns a mapping of phase to step, depending on the entrypoint.
@@ -47,9 +47,32 @@ def _get_step_phase_mapping(
4747 eval_unit = cast (TEvalUnit , unit )
4848 step_mapping [Phase .EVALUATE ] = eval_unit .eval_progress .num_steps_completed
4949
50+ if state .entry_point == EntryPoint .PREDICT :
51+ predict_unit = cast (TPredictUnit , unit )
52+ step_mapping [Phase .PREDICT ] = predict_unit .predict_progress .num_steps_completed
53+
5054 return step_mapping
5155
5256
57+ def _get_epoch (state : State , unit : Union [TTrainUnit , TEvalUnit , TPredictUnit ]) -> int :
58+ """
59+ Returns the epoch depending on the entrypoint. For FIT, it always returns the train epoch.
60+ """
61+ if state .entry_point in (EntryPoint .TRAIN , EntryPoint .FIT ):
62+ train_unit = cast (TTrainUnit , unit )
63+ return train_unit .train_progress .num_epochs_completed
64+
65+ elif state .entry_point == EntryPoint .PREDICT :
66+ predict_unit = cast (TPredictUnit , unit )
67+ return predict_unit .predict_progress .num_epochs_completed
68+
69+ elif state .entry_point == EntryPoint .EVALUATE :
70+ eval_unit = cast (TEvalUnit , unit )
71+ return eval_unit .eval_progress .num_epochs_completed
72+
73+ raise ValueError (f"Unknown entrypoint: { state .entry_point } " )
74+
75+
5376def _prepare_app_state (unit : AppStateMixin ) -> Dict [str , Any ]:
5477 """Join together all of the tracked stateful entities to simplify registration of snapshottable states, deals with FSDP case"""
5578 app_state = unit .app_state ()
0 commit comments