@@ -9,8 +9,10 @@
 
 from typing import Any, cast, Dict, Union
 
+from pyre_extensions import none_throws
+
 from torchtnt.framework.callbacks.checkpointer_types import RestoreOptions
-from torchtnt.framework.state import EntryPoint, State
+from torchtnt.framework.state import ActivePhase, EntryPoint, State
 from torchtnt.framework.unit import AppStateMixin, TEvalUnit, TPredictUnit, TTrainUnit
 from torchtnt.utils.checkpoint import Phase
 
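Note on the new import above: `none_throws` from `pyre_extensions` is an existing helper that unwraps an `Optional[T]` to `T`, raising if the value is actually `None`. A minimal usage sketch (the variable names here are illustrative, not from this PR):

```python
from typing import Optional

from pyre_extensions import none_throws

maybe_steps: Optional[int] = 3
steps: int = none_throws(maybe_steps)  # returns 3; raises if maybe_steps were None
```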
@@ -123,13 +125,27 @@ def _prepare_app_state_for_checkpoint(
         remove_lr_schedulers=True,
     )
 
+    if not intra_epoch:
+        return app_state
+
     # for intra-epoch checkpointing, include dataloader state of the current phase
-    phase_dl = state.active_phase_state().dataloader
-    if intra_epoch and isinstance(phase_dl, Stateful):
-        dataloader_state_key = _PHASE_DL_STATE_KEY_MAPPING[
-            state.active_phase.into_phase()
-        ]
-        app_state[dataloader_state_key] = phase_dl
+    active_dataloaders = {state.active_phase: state.active_phase_state().dataloader}
+
+    # Special case for FIT where eval is executed every n steps. We also need to save
+    # the train dataloader state. In this case, the train epoch wouldn't be incremented yet.
+    if (
+        state.entry_point == EntryPoint.FIT
+        and state.active_phase == ActivePhase.EVALUATE
+        and cast(TTrainUnit, unit).train_progress.num_steps_completed_in_epoch != 0
+    ):
+        active_dataloaders[ActivePhase.TRAIN] = none_throws(
+            state.train_state
+        ).dataloader
+
+    for active_phase, dl in active_dataloaders.items():
+        if isinstance(dl, Stateful):
+            dl_key = _PHASE_DL_STATE_KEY_MAPPING[active_phase.into_phase()]
+            app_state[dl_key] = dl
 
     return app_state
 
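To make the intent concrete, here is a small, self-contained sketch of why the train dataloader state must also be captured when an intra-epoch eval checkpoint is taken during FIT. This is not torchtnt's implementation: the `Stateful` protocol shape, the key names `train_dataloader`/`eval_dataloader`, and the toy dataloader are illustrative assumptions.

```python
from typing import Any, Dict, Protocol, runtime_checkable


@runtime_checkable
class Stateful(Protocol):
    """Anything that can save and restore its own state."""

    def state_dict(self) -> Dict[str, Any]: ...

    def load_state_dict(self, state: Dict[str, Any]) -> None: ...


class CountingDataloader:
    """Toy dataloader that only tracks how many batches it has yielded."""

    def __init__(self) -> None:
        self.batches_yielded = 0

    def state_dict(self) -> Dict[str, Any]:
        return {"batches_yielded": self.batches_yielded}

    def load_state_dict(self, state: Dict[str, Any]) -> None:
        self.batches_yielded = state["batches_yielded"]


def collect_dataloader_states(
    train_dl: object, eval_dl: object, mid_train_epoch: bool
) -> Dict[str, Any]:
    """Simplified mirror of the PR's logic: an eval triggered every n steps during
    FIT must also persist the train dataloader, or training cannot resume mid-epoch."""
    app_state: Dict[str, Any] = {}
    # the active phase is eval; the key names are hypothetical stand-ins
    active = {"eval_dataloader": eval_dl}
    if mid_train_epoch:
        # analogous to num_steps_completed_in_epoch != 0: capture the train side too
        active["train_dataloader"] = train_dl
    for key, dl in active.items():
        if isinstance(dl, Stateful):
            app_state[key] = dl.state_dict()
    return app_state


train_dl, eval_dl = CountingDataloader(), CountingDataloader()
train_dl.batches_yielded = 7  # eval fired after 7 train batches in this epoch
print(collect_dataloader_states(train_dl, eval_dl, mid_train_epoch=True))
# -> {'eval_dataloader': {'batches_yielded': 0}, 'train_dataloader': {'batches_yielded': 7}}
```

Keying each phase's dataloader state separately means a later restore can hand every phase its own state back, regardless of which phase was active when the checkpoint was taken.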