1515import torch .distributed as dist
1616from pyre_extensions import none_throws
1717from torchtnt .framework .callback import Callback
18- from torchtnt .framework .callbacks ._checkpoint_utils import _get_step_phase_mapping
18+ from torchtnt .framework .callbacks ._checkpoint_utils import (
19+ _get_epoch ,
20+ _get_step_phase_mapping ,
21+ )
1922from torchtnt .framework .callbacks .checkpointer_types import RestoreOptions
20- from torchtnt .framework .state import EntryPoint , State
21- from torchtnt .framework .unit import AppStateMixin , TEvalUnit , TTrainData , TTrainUnit
23+ from torchtnt .framework .state import ActivePhase , EntryPoint , State
24+ from torchtnt .framework .unit import (
25+ AppStateMixin ,
26+ TEvalUnit ,
27+ TPredictUnit ,
28+ TTrainData ,
29+ TTrainUnit ,
30+ )
2231from torchtnt .utils .checkpoint import (
2332 BestCheckpointConfig ,
2433 CheckpointManager ,
@@ -51,8 +60,11 @@ class BaseCheckpointer(Callback, metaclass=abc.ABCMeta):
5160 save_every_n_train_steps: Frequency of steps with which to save checkpoints during the train epoch. If None, no intra-epoch checkpoints are generated.
5261 save_every_n_epochs: Frequency of epochs with which to save checkpoints during training. If None, no end-of-epoch checkpoints are generated.
5362 save_every_n_eval_epochs: Frequency of evaluation epochs with which to save checkpoints during training. Use this if wanting to save checkpoints after every eval epoch during fit.
54- keep_last_n_checkpoints: Number of most recent checkpoints to keep. If None, all checkpoints are kept. If an excess of existing checkpoints are present, the oldest ones will be deleted to clean the difference. If best checkpoint config is enabled, this param will manage the top n checkpoints instead.
55- best_checkpoint_config: Configuration for saving the best checkpoint based on a monitored metric. The metric is read off the attribute of the unit prior to checkpoint.
63+ save_every_n_eval_steps: Frequency of evaluation steps with which to save checkpoints during training. Use this if wanting to save checkpoints during evaluate.
64+ save_every_n_predict_steps: Frequency of prediction steps with which to save checkpoints during training. Use this if wanting to save checkpoints during using predict entrypoint.
65+ keep_last_n_checkpoints: Number of most recent checkpoints to keep. If None, all checkpoints are kept. If an excess of existing checkpoints are present, the oldest ones will be deleted
66+ to clean the difference. If best checkpoint config is enabled, this param will manage the top n checkpoints instead. Only supported for train or fit entrypoints.
67+ best_checkpoint_config: Configuration for saving the best checkpoint based on a monitored metric. The metric is read off the attribute of the unit prior to checkpoint. This param is ignored if not in train or fit entrypoints.
5668 process_group: The process group on which the ranks will communicate on. If the process group is not gloo-based, a new gloo-based process group will be created.
5769
5870 Note:
@@ -78,6 +90,8 @@ def __init__(
7890 save_every_n_train_steps : Optional [int ] = None ,
7991 save_every_n_epochs : Optional [int ] = None ,
8092 save_every_n_eval_epochs : Optional [int ] = None ,
93+ save_every_n_eval_steps : Optional [int ] = None ,
94+ save_every_n_predict_steps : Optional [int ] = None ,
8195 keep_last_n_checkpoints : Optional [int ] = None ,
8296 best_checkpoint_config : Optional [BestCheckpointConfig ] = None ,
8397 process_group : Optional [dist .ProcessGroup ] = None ,
@@ -90,12 +104,23 @@ def __init__(
90104 raise ValueError (
91105 f"Invalid value passed for save_every_n_epochs. Expected to receive either None or positive number, but received { save_every_n_epochs } "
92106 )
107+ if save_every_n_eval_steps is not None and save_every_n_eval_steps <= 0 :
108+ raise ValueError (
109+ f"Invalid value passed for save_every_n_eval_steps. Expected to receive either None or positive number, but received { save_every_n_eval_steps } "
110+ )
111+ if save_every_n_eval_epochs is not None and save_every_n_eval_epochs <= 0 :
112+ raise ValueError (
113+ f"Invalid value passed for save_every_n_eval_epochs. Expected to receive either None or positive number, but received { save_every_n_eval_epochs } "
114+ )
115+ if save_every_n_predict_steps is not None and save_every_n_predict_steps <= 0 :
116+ raise ValueError (
117+ f"Invalid value passed for save_every_n_predict_steps. Expected to receive either None or positive number, but received { save_every_n_predict_steps } "
118+ )
93119 if keep_last_n_checkpoints is not None and keep_last_n_checkpoints <= 0 :
94120 raise ValueError (
95121 f"Invalid value passed for keep_last_n_checkpoints. Expected to receive either None or positive number, but received { keep_last_n_checkpoints } "
96122 )
97123
98- self ._best_checkpoint_config = best_checkpoint_config
99124 if best_checkpoint_config and best_checkpoint_config .mode not in {"min" , "max" }:
100125 raise ValueError (
101126 f"Invalid value passed for best_checkpoint_config.mode. Expected to receive 'min' or 'max', but received { best_checkpoint_config .mode } "
@@ -104,7 +129,10 @@ def __init__(
104129 self ._save_every_n_train_steps = save_every_n_train_steps
105130 self ._save_every_n_epochs = save_every_n_epochs
106131 self ._save_every_n_eval_epochs = save_every_n_eval_epochs
132+ self ._save_every_n_eval_steps = save_every_n_eval_steps
133+ self ._save_every_n_predict_steps = save_every_n_predict_steps
107134 self ._keep_last_n_checkpoints = keep_last_n_checkpoints
135+ self ._best_checkpoint_config = best_checkpoint_config
108136
109137 self ._process_group : Optional [dist .ProcessGroup ] = None
110138 self ._setup_gloo_pg (process_group )
@@ -147,7 +175,7 @@ def dirpath(self) -> str:
147175 return self ._checkpoint_manager .dirpath
148176
149177 def _generate_checkpoint_and_upkeep (
150- self , state : State , unit : Union [TTrainUnit , TEvalUnit ], hook : str
178+ self , state : State , unit : Union [TTrainUnit , TEvalUnit , TPredictUnit ], hook : str
151179 ) -> bool :
152180 """
153181 Implementation for saving checkpoint while taking care of checkpoint
@@ -162,11 +190,16 @@ def _generate_checkpoint_and_upkeep(
162190 True if checkpoint was successfully saved. False otherwise.
163191 """
164192 # 1) generate checkpoint name
165- epoch = cast ( TTrainUnit , unit ). train_progress . num_epochs_completed
193+ epoch = _get_epoch ( state , unit )
166194 step_mapping = _get_step_phase_mapping (state , unit )
167195
196+ # 1.1) append metric data only for train checkpoints, if best_checkpoint_config is defined
168197 metric_data : Optional [MetricData ] = None
169- if metric_value := self ._get_tracked_metric_value (unit ):
198+ if (
199+ self ._best_checkpoint_config
200+ and state .active_phase == ActivePhase .TRAIN
201+ and (metric_value := self ._get_tracked_metric_value (cast (TTrainUnit , unit )))
202+ ):
170203 metric_data = MetricData (
171204 name = none_throws (self ._best_checkpoint_config ).monitored_metric ,
172205 value = metric_value ,
@@ -179,7 +212,8 @@ def _generate_checkpoint_and_upkeep(
179212 process_group = self ._process_group ,
180213 )
181214
182- # 2) Determine if we should save checkpoint
215+ # 2) Determine if we should save checkpoint. This is a no-op for eval and predict entrypoints
216+ # since neither best_checkpoint_config nor keep_last_n_checkpoints are supported.
183217 if not self ._checkpoint_manager .should_save_checkpoint (checkpoint_path ):
184218 return False
185219
@@ -222,9 +256,7 @@ def _generate_checkpoint_and_upkeep(
222256
223257 return True
224258
225- def _get_tracked_metric_value (
226- self , unit : Union [TTrainUnit , TEvalUnit ]
227- ) -> Optional [float ]:
259+ def _get_tracked_metric_value (self , unit : TTrainUnit ) -> Optional [float ]:
228260 """
229261 If the checkpointer has a tracked metric, look the value in the unit using reflection, and cast to float.
230262
@@ -271,33 +303,80 @@ def on_train_start(self, state: State, unit: TTrainUnit) -> None:
271303
272304 def on_train_step_end (self , state : State , unit : TTrainUnit ) -> None :
273305 num_steps_completed = unit .train_progress .num_steps_completed
274- save_every_n_train_steps = self ._save_every_n_train_steps
275306 if (
276- save_every_n_train_steps is None
277- or num_steps_completed % save_every_n_train_steps != 0
307+ not self . _save_every_n_train_steps
308+ or num_steps_completed % self . _save_every_n_train_steps != 0
278309 ):
279310 return
280311
281312 self ._generate_checkpoint_and_upkeep (state , unit , hook = "on_train_step_end" )
282313
283314 def on_train_epoch_end (self , state : State , unit : TTrainUnit ) -> None :
284315 epoch = unit .train_progress .num_epochs_completed
285- save_every_n_epochs = self ._save_every_n_epochs
286- if save_every_n_epochs is None or epoch % save_every_n_epochs != 0 :
316+ if not self ._save_every_n_epochs or epoch % self ._save_every_n_epochs != 0 :
287317 return
288318
289319 self ._generate_checkpoint_and_upkeep (state , unit , hook = "on_train_epoch_end" )
290320
321+ def on_train_end (self , state : State , unit : TTrainUnit ) -> None :
322+ self ._generate_checkpoint_and_upkeep (state , unit , hook = "on_train_end" )
323+
324+ def on_eval_start (self , state : State , unit : TEvalUnit ) -> None :
325+ if state .entry_point == EntryPoint .EVALUATE :
326+ self ._disable_ckpt_optimality_tracking ()
327+
328+ def on_eval_step_end (self , state : State , unit : TEvalUnit ) -> None :
329+ num_steps_completed = unit .eval_progress .num_steps_completed
330+ if (
331+ not self ._save_every_n_eval_steps
332+ or num_steps_completed % self ._save_every_n_eval_steps != 0
333+ ):
334+ return
335+
336+ self ._generate_checkpoint_and_upkeep (state , unit , hook = "on_eval_step_end" )
337+
291338 def on_eval_epoch_end (self , state : State , unit : TEvalUnit ) -> None :
292339 epoch = unit .eval_progress .num_epochs_completed
293- save_every_n_eval_epochs = self ._save_every_n_eval_epochs
294- if save_every_n_eval_epochs is None or epoch % save_every_n_eval_epochs != 0 :
340+ if (
341+ not self ._save_every_n_eval_epochs
342+ or epoch % self ._save_every_n_eval_epochs != 0
343+ ):
295344 return
296345
297346 self ._generate_checkpoint_and_upkeep (state , unit , hook = "on_eval_epoch_end" )
298347
299- def on_train_end (self , state : State , unit : TTrainUnit ) -> None :
300- self ._generate_checkpoint_and_upkeep (state , unit , hook = "on_train_end" )
348+ def on_predict_start (self , state : State , unit : TPredictUnit ) -> None :
349+ self ._disable_ckpt_optimality_tracking ()
350+
351+ def on_predict_step_end (self , state : State , unit : TPredictUnit ) -> None :
352+ num_steps_completed = unit .predict_progress .num_steps_completed
353+ if (
354+ not self ._save_every_n_predict_steps
355+ or num_steps_completed % self ._save_every_n_predict_steps != 0
356+ ):
357+ return
358+
359+ self ._generate_checkpoint_and_upkeep (state , unit , hook = "on_predict_step_end" )
360+
361+ def _disable_ckpt_optimality_tracking (self ) -> None :
362+ """
363+ Disables checkpoint optimality tracking. This means that best_checkpoint and keep_last_n_checkpoints
364+ will not be used. This is useful for eval and predict entrypoints, since checkpoints do not include
365+ model parameters.
366+ """
367+ if self ._best_checkpoint_config :
368+ logger .warning (
369+ "Disabling best_checkpoint_config, since it is not supported for eval or predict entrypoints."
370+ )
371+ self ._best_checkpoint_config = None
372+ self ._checkpoint_manager ._best_checkpoint_config = None
373+
374+ if self ._keep_last_n_checkpoints :
375+ logger .warning (
376+ "Disabling keep_last_n_checkpoints, since is not supported for eval or predict entrypoints."
377+ )
378+ self ._keep_last_n_checkpoints = None
379+ self ._checkpoint_manager ._keep_last_n_checkpoints = None
301380
302381 @abc .abstractmethod
303382 def _checkpoint_impl (
0 commit comments