|
12 | 12 | # See the License for the specific language governing permissions and |
13 | 13 | # limitations under the License. |
14 | 14 | import logging |
15 | | -from typing import Any, Optional, Union |
| 15 | +from dataclasses import dataclass |
| 16 | +from typing import Any, Dict, List, Optional, Union |
16 | 17 |
|
17 | 18 | import torch |
18 | 19 | from typing_extensions import override |
|
45 | 46 | log = logging.getLogger(__name__) |
46 | 47 |
|
47 | 48 |
|
| 49 | +@dataclass |
| 50 | +class RestartStage: |
| 51 | + NONE = "none" |
| 52 | + RESTARTED_ON_EPOCH_START = "restarted_on_epoch_start" |
| 53 | + RESTARTED_MID_EPOCH = "restarted_mid_epoch" |
| 54 | + RESTARTED_ON_EPOCH_END = "restarted_on_epoch_end" |
| 55 | + RESUMED_ON_EPOCH_END = "resumed_on_epoch_end" |
| 56 | + |
| 57 | + |
48 | 58 | class _FitLoop(_Loop): |
49 | 59 | """This loop is the top-level loop where training starts. |
50 | 60 |
|
@@ -94,9 +104,10 @@ def __init__( |
94 | 104 |
|
95 | 105 | self._data_source = _DataLoaderSource(None, "train_dataloader") |
96 | 106 | self._combined_loader: Optional[CombinedLoader] = None |
97 | | - self._combined_loader_states_to_load: list[dict[str, Any]] = [] |
| 107 | + self._combined_loader_states_to_load: List[Dict[str, Any]] = [] |
98 | 108 | self._data_fetcher: Optional[_DataFetcher] = None |
99 | 109 | self._last_train_dl_reload_epoch = float("-inf") |
| 110 | + self._restart_stage = RestartStage.NONE |
100 | 111 |
|
101 | 112 | @property |
102 | 113 | def total_batch_idx(self) -> int: |
@@ -204,9 +215,10 @@ def run(self) -> None: |
204 | 215 | self.on_advance_start() |
205 | 216 | self.advance() |
206 | 217 | self.on_advance_end() |
207 | | - self._restarting = False |
208 | 218 | except StopIteration: |
209 | 219 | break |
| 220 | + finally: |
| 221 | + self.on_iteration_done() |
210 | 222 | self._restarting = False |
211 | 223 | self.on_run_end() |
212 | 224 |
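The `finally` clause added above ensures `on_iteration_done` runs both when an iteration finishes normally and when `StopIteration` ends the loop. A minimal sketch of the resulting control flow (the `while not self.done:` header is assumed from the surrounding code and is not visible in this hunk; bodies elided):

```python
# Hedged sketch of run() after this change.
while not self.done:  # assumed loop condition, not shown in the hunk
    try:
        self.on_advance_start()
        self.advance()
        self.on_advance_end()
    except StopIteration:
        break
    finally:
        self.on_iteration_done()  # runs on normal and StopIteration exits alike
self._restarting = False
self.on_run_end()
```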
|
@@ -302,14 +314,92 @@ def setup_data(self) -> None: |
302 | 314 | category=PossibleUserWarning, |
303 | 315 | ) |
304 | 316 |
|
| 317 | + @property |
| 318 | + def restarted_on_epoch_start(self) -> bool: |
| 319 | + return self._restart_stage == RestartStage.RESTARTED_ON_EPOCH_START |
| 320 | + |
| 321 | + @property |
| 322 | + def restarted_mid_epoch(self) -> bool: |
| 323 | + return self._restart_stage == RestartStage.RESTARTED_MID_EPOCH |
| 324 | + |
| 325 | + @property |
| 326 | + def restarted_on_epoch_end(self) -> bool: |
| 327 | + return self._restart_stage == RestartStage.RESTARTED_ON_EPOCH_END |
| 328 | + |
| 329 | + @property |
| 330 | + def resumed_on_epoch_end(self) -> bool: |
| 331 | + # This happens when resuming from a "last" checkpoint saved at the end of
| 332 | + # an epoch without a validation run; in that case self.restarting is False.
| 333 | + return self._restart_stage == RestartStage.RESUMED_ON_EPOCH_END |
| 334 | + |
| 335 | + def update_restart_stage(self) -> None: |
| 336 | + if ( |
| 337 | + self.restarting |
| 338 | + and self.epoch_progress.total.started == self.epoch_progress.total.ready - 1 |
| 339 | + and self.epoch_progress.total.processed == self.epoch_progress.total.started |
| 340 | + and self.epoch_progress.total.completed == self.epoch_progress.total.processed |
| 341 | + ): |
| 342 | + self._restart_stage = RestartStage.RESTARTED_ON_EPOCH_START |
| 343 | + elif ( |
| 344 | + self.restarting |
| 345 | + and self.epoch_progress.total.started == self.epoch_progress.total.ready |
| 346 | + and self.epoch_progress.total.processed == self.epoch_progress.total.started - 1 |
| 347 | + and self.epoch_progress.total.completed == self.epoch_progress.total.processed |
| 348 | + ): |
| 349 | + self._restart_stage = RestartStage.RESTARTED_MID_EPOCH |
| 350 | + elif ( |
| 351 | + self.restarting |
| 352 | + and self.epoch_progress.total.started == self.epoch_progress.total.ready |
| 353 | + and self.epoch_progress.total.processed == self.epoch_progress.total.started |
| 354 | + and self.epoch_progress.total.completed == self.epoch_progress.total.processed - 1 |
| 355 | + ): |
| 356 | + self._restart_stage = RestartStage.RESTARTED_ON_EPOCH_END |
| 357 | + elif ( |
| 358 | + self._loaded_from_state_dict |
| 359 | + and self.epoch_progress.total.started == self.epoch_progress.total.ready |
| 360 | + and self.epoch_progress.total.processed == self.epoch_progress.total.started |
| 361 | + and self.epoch_progress.total.completed == self.epoch_progress.total.processed - 1 |
| 362 | + ): |
| 363 | + self._restart_stage = RestartStage.RESUMED_ON_EPOCH_END |
| 364 | + else: |
| 365 | + self._restart_stage = RestartStage.NONE |
| 366 | + |
| 367 | + self.epoch_loop.update_restart_stage() |
| 368 | + |
| 369 | + def reset_restart_stage(self) -> None: |
| 370 | + self._restart_stage = RestartStage.NONE |
| 371 | + |
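To make the counter arithmetic in `update_restart_stage` concrete, here is a small illustrative mapping (values are made up, not from the diff): the epoch counters advance in the order ready → started → processed → completed, so the stage can be read off from which counter lags behind.

```python
# Illustrative (ready, started, processed, completed) values of epoch_progress.total
# and the stage update_restart_stage() would select when self.restarting is True:
examples = {
    (3, 2, 2, 2): RestartStage.RESTARTED_ON_EPOCH_START,  # epoch was ready but never started
    (3, 3, 2, 2): RestartStage.RESTARTED_MID_EPOCH,       # started, not fully processed
    (3, 3, 3, 2): RestartStage.RESTARTED_ON_EPOCH_END,    # processed, completion not recorded
    (3, 3, 3, 3): RestartStage.NONE,                      # nothing pending
}
# With self.restarting False but _loaded_from_state_dict True, the (3, 3, 3, 2)
# pattern maps to RESUMED_ON_EPOCH_END instead.
```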
305 | 372 | def reset(self) -> None: |
306 | 373 | """Resets the internal state of this loop.""" |
307 | 374 | assert self.trainer.model is not None |
308 | 375 | torch.set_grad_enabled(True) |
309 | 376 |
|
310 | | - if self.restarting: |
| 377 | + self.update_restart_stage() |
| 378 | + |
| 379 | + if self.restarted_on_epoch_start: |
311 | 380 | self.epoch_progress.reset_on_restart() |
312 | 381 |
|
| 382 | + if self.resumed_on_epoch_end: |
| 383 | + # when resuming from a "last" checkpoint saved at the end of an epoch
| 384 | + # without validation, self.restarting is False but we are still resuming
| 385 | + self.epoch_progress.increment_completed() |
| 386 | + |
| 387 | + if ( |
| 388 | + self.epoch_loop.restarted_on_train_batch_end |
| 389 | + and self.restarted_mid_epoch |
| 390 | + and self.epoch_loop.batch_progress.is_last_batch |
| 391 | + ): |
| 392 | + self.epoch_progress.increment_processed() |
| 393 | + self.epoch_progress.increment_completed() |
| 394 | + |
| 395 | + if ( |
| 396 | + self.epoch_loop.restarted_on_train_batch_end |
| 397 | + and self.epoch_loop.batch_progress.is_last_batch |
| 398 | + and not self.restarted_mid_epoch |
| 399 | + and not self.epoch_loop.val_loop.batch_progress.is_last_batch |
| 400 | + ): |
| 401 | + self.epoch_progress.increment_completed() |
| 402 | + |
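The increments above catch the epoch counters up when the checkpoint was written by `on_train_batch_end` of the last batch, before `on_train_epoch_end` ran. A small illustrative sketch of the `resumed_on_epoch_end` case, using a hypothetical stand-in for `epoch_progress.total` (the `_Totals` class below is not part of the diff):

```python
from dataclasses import dataclass

# Hypothetical stand-in for epoch_progress.total, only to show the arithmetic.
@dataclass
class _Totals:
    ready: int = 3
    started: int = 3
    processed: int = 3
    completed: int = 2  # on_train_epoch_end never ran before the checkpoint was saved

totals = _Totals()
totals.completed += 1  # what the resumed_on_epoch_end branch of reset() records
assert (totals.ready, totals.started, totals.processed, totals.completed) == (3, 3, 3, 3)
```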
313 | 403 | def on_run_start(self) -> None: |
314 | 404 | """Calls the ``on_train_start`` hook.""" |
315 | 405 | # update the current_epoch in case of a checkpoint reload |
@@ -340,12 +430,14 @@ def on_advance_start(self) -> None: |
340 | 430 | for i, dl in enumerate(self._combined_loader.flattened): |
341 | 431 | _set_sampler_epoch(dl, self.epoch_progress.current.processed) |
342 | 432 |
|
343 | | - self.epoch_progress.increment_ready() |
| 433 | + if not self.restarted_mid_epoch and not self.restarted_on_epoch_end: |
| 434 | + if not self.restarted_on_epoch_start: |
| 435 | + self.epoch_progress.increment_ready() |
344 | 436 |
|
345 | | - call._call_callback_hooks(trainer, "on_train_epoch_start") |
346 | | - call._call_lightning_module_hook(trainer, "on_train_epoch_start") |
| 437 | + call._call_callback_hooks(trainer, "on_train_epoch_start") |
| 438 | + call._call_lightning_module_hook(trainer, "on_train_epoch_start") |
347 | 439 |
|
348 | | - self.epoch_progress.increment_started() |
| 440 | + self.epoch_progress.increment_started() |
349 | 441 |
|
350 | 442 | def advance(self) -> None: |
351 | 443 | """Runs one whole epoch.""" |
@@ -379,8 +471,7 @@ def on_advance_end(self) -> None: |
379 | 471 |
|
380 | 472 | trainer._logger_connector.on_epoch_end() |
381 | 473 |
|
382 | | - if self.epoch_loop._num_ready_batches_reached(): |
383 | | - # if we are restarting and the above condition holds, it's because we are reloading an epoch-end checkpoint. |
| 474 | + if not self.restarting and self.epoch_loop._num_ready_batches_reached(): |
384 | 475 | # since metric-based schedulers require access to metrics and those are not currently saved in the |
385 | 476 | # checkpoint, the plateau schedulers shouldn't be updated |
386 | 477 | self.epoch_loop.update_lr_schedulers("epoch", update_plateau_schedulers=not self.restarting) |
@@ -413,14 +504,14 @@ def teardown(self) -> None: |
413 | 504 | self.epoch_loop.teardown() |
414 | 505 |
|
415 | 506 | @override |
416 | | - def on_save_checkpoint(self) -> dict: |
| 507 | + def on_save_checkpoint(self) -> Dict: |
417 | 508 | state_dict = super().on_save_checkpoint() |
418 | 509 | if self._combined_loader is not None and (loader_states := self._combined_loader._state_dicts()): |
419 | 510 | state_dict["combined_loader"] = loader_states |
420 | 511 | return state_dict |
421 | 512 |
|
422 | 513 | @override |
423 | | - def on_load_checkpoint(self, state_dict: dict) -> None: |
| 514 | + def on_load_checkpoint(self, state_dict: Dict) -> None: |
424 | 515 | self._combined_loader_states_to_load = state_dict.get("combined_loader", []) |
425 | 516 | super().on_load_checkpoint(state_dict) |
426 | 517 |
|
|