 import logging
 import math
 from datetime import timedelta
-from typing import Any, cast, Iterable, List, Literal, Optional, Union
+from typing import Any, cast, Dict, Iterable, List, Literal, Optional, Union
 
 import fsspec
 
@@ -39,6 +39,7 @@
     Phase,
 )
 from torchtnt.utils.distributed import get_world_size, PGWrapper
+from torchtnt.utils.event_handlers import log_interval
 from torchtnt.utils.rank_zero_log import rank_zero_info, rank_zero_warn
 
 logger: logging.Logger = logging.getLogger(__name__)
@@ -201,70 +202,85 @@ def _generate_checkpoint_and_upkeep(
         Returns:
             True if checkpoint was successfully saved. False otherwise.
         """
-        # 1) generate checkpoint name
-        epoch = _get_epoch(state, unit)
-        step_mapping = _get_step_phase_mapping(state, unit)
-
-        # 1.1) append metric data only if best_checkpoint_config is defined
-        metric_data: Optional[MetricData] = None
-        if self._best_checkpoint_config and (
-            metric_value := self._get_tracked_metric_value(cast(TTrainUnit, unit))
-        ):
-            metric_data = MetricData(
-                name=none_throws(self._best_checkpoint_config).monitored_metric,
-                value=metric_value,
-            )
-
-        checkpoint_path = self._checkpoint_manager.generate_checkpoint_path(
-            epoch,
-            step_mapping,
-            metric_data,
-            process_group=self._process_group,
-        )
-
-        # 2) Determine if we should save checkpoint. This is a no-op for eval and predict entrypoints
-        # since neither best_checkpoint_config nor keep_last_n_checkpoints are supported.
-        if not self._checkpoint_manager.should_save_checkpoint(checkpoint_path):
-            return False
+        log_interval_metadata: Dict[str, str] = {
+            "category": "checkpointing",
+            "active_phase": str(state.active_phase),
+            "hook": hook,
+            "epoch": str(_get_epoch(state, unit)),
+            "step": str(
+                _get_step_phase_mapping(state, unit).get(
+                    state.active_phase.into_phase(), 0
+                )
+            ),
+        }
 
-        if hook == "on_train_end":
-            # 2.1) Make sure that last checkpoint does not already exist
-            if self._checkpoint_manager.does_checkpoint_exist(
-                checkpoint_path, self._process_group
+        with log_interval(
+            "_generate_checkpoint_and_upkeep", metadata=log_interval_metadata
+        ):
+            # 1) generate checkpoint name
+            epoch = _get_epoch(state, unit)
+            step_mapping = _get_step_phase_mapping(state, unit)
+
+            # 1.1) append metric data only if best_checkpoint_config is defined
+            metric_data: Optional[MetricData] = None
+            if self._best_checkpoint_config and (
+                metric_value := self._get_tracked_metric_value(cast(TTrainUnit, unit))
             ):
-                rank_zero_warn(
-                    "Final checkpoint already exists, skipping.", logger=logger
+                metric_data = MetricData(
+                    name=none_throws(self._best_checkpoint_config).monitored_metric,
+                    value=metric_value,
                 )
+
+            checkpoint_path = self._checkpoint_manager.generate_checkpoint_path(
+                epoch,
+                step_mapping,
+                metric_data,
+                process_group=self._process_group,
+            )
+
+            # 2) Determine if we should save checkpoint. This is a no-op for eval and predict entrypoints
+            # since neither best_checkpoint_config nor keep_last_n_checkpoints are supported.
+            if not self._checkpoint_manager.should_save_checkpoint(checkpoint_path):
                 return False
 
-            # 2.2) If doing fit without eval checkpointing, only consider training progress when
-            # checking if last checkpoint exists.
-            if (
-                state.entry_point == EntryPoint.FIT
-                and self._save_every_n_eval_epochs is None
-                and self._checkpoint_manager._ckpt_paths
-                and self._checkpoint_manager._ckpt_paths[-1].step[Phase.TRAIN]
-                == cast(TTrainUnit, unit).train_progress.num_steps_completed
+            if hook == "on_train_end":
+                # 2.1) Make sure that last checkpoint does not already exist
+                if self._checkpoint_manager.does_checkpoint_exist(
+                    checkpoint_path, self._process_group
+                ):
+                    rank_zero_warn(
+                        "Final checkpoint already exists, skipping.", logger=logger
+                    )
+                    return False
+
+                # 2.2) If doing fit without eval checkpointing, only consider training progress when
+                # checking if last checkpoint exists.
+                if (
+                    state.entry_point == EntryPoint.FIT
+                    and self._save_every_n_eval_epochs is None
+                    and self._checkpoint_manager._ckpt_paths
+                    and self._checkpoint_manager._ckpt_paths[-1].step[Phase.TRAIN]
+                    == cast(TTrainUnit, unit).train_progress.num_steps_completed
+                ):
+                    rank_zero_info(
+                        "Omitting final checkpoint since train progress is unchanged, and eval checkpointing is not configured.",
+                        logger=logger,
+                    )
+                    return False
+
+            # 3) try to save checkpoint
+            if not self._checkpoint_impl(
+                state, unit, checkpoint_id=checkpoint_path.path, hook=hook
             ):
-                rank_zero_info(
-                    "Omitting final checkpoint since train progress is unchanged, and eval checkpointing is not configured.",
-                    logger=logger,
-                )
                 return False
 
-        # 3) try to save checkpoint
-        if not self._checkpoint_impl(
-            state, unit, checkpoint_id=checkpoint_path.path, hook=hook
-        ):
-            return False
+            # 4) track checkpoint and clean up surplus if needed
+            self._checkpoint_manager.append_checkpoint(checkpoint_path)
 
-        # 4) track checkpoint and clean up surplus if needed
-        self._checkpoint_manager.append_checkpoint(checkpoint_path)
+            # 5) invoke on_checkpoint_save callback on the unit since checkpoint was saved successfully
+            unit.on_checkpoint_save(state, checkpoint_id=checkpoint_path.path)
 
-        # 5) invoke on_checkpoint_save callback on the unit since checkpoint was saved successfully
-        unit.on_checkpoint_save(state, checkpoint_id=checkpoint_path.path)
-
-        return True
+            return True
 
     def _get_tracked_metric_value(self, unit: TTrainUnit) -> Optional[float]:
         """
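The net effect of the change is that the entire checkpoint-and-upkeep step now runs inside a log_interval context manager, so its duration is recorded together with string metadata (category, active phase, hook, epoch, step). Below is a minimal sketch of the same pattern outside the callback. The checkpoint writer and the metadata values are hypothetical, made up for illustration; the only API assumed is log_interval used exactly as in the diff, as a context manager taking a name and a Dict[str, str] metadata argument.

from typing import Dict

import torch
from torchtnt.utils.event_handlers import log_interval


def write_checkpoint(path: str) -> None:
    # Hypothetical stand-in for the callback's actual checkpoint save.
    torch.save({"weights": torch.zeros(2, 2)}, path)


# String-valued metadata, mirroring the Dict[str, str] built in the diff.
metadata: Dict[str, str] = {
    "category": "checkpointing",
    "epoch": str(3),
    "step": str(1200),
}

# Everything inside the `with` block is recorded as one named interval, just as
# the diff wraps checkpoint naming, saving, and upkeep in a single interval.
with log_interval("write_checkpoint", metadata=metadata):
    write_checkpoint("/tmp/epoch_3_step_1200.pt")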