|
12 | 12 | # See the License for the specific language governing permissions and |
13 | 13 | # limitations under the License. |
14 | 14 | import logging |
| 15 | +from dataclasses import dataclass |
15 | 16 | from typing import Any, Dict, List, Optional, Union |
16 | 17 |
|
17 | 18 | import torch |
|
45 | 46 | log = logging.getLogger(__name__) |
46 | 47 |
|
47 | 48 |
|
class RestartStage:
    """String constants classifying how the fit loop is restarting/resuming.

    ``update_restart_stage`` stores one of these values in
    ``_restart_stage`` so the loop can tell whether a checkpoint restart
    landed at the start of an epoch, mid-epoch, or at the end of an epoch.

    NOTE: this is deliberately a plain constants namespace. The previous
    ``@dataclass`` decorator was a no-op here — none of the attributes are
    annotated, so the dataclass machinery generated no fields — and only
    suggested an instance type that is never constructed.
    """

    NONE = "none"
    RESTARTED_ON_EPOCH_START = "restarted_on_epoch_start"
    RESTARTED_MID_EPOCH = "restarted_mid_epoch"
    RESTARTED_ON_EPOCH_END = "restarted_on_epoch_end"
    RESUMED_ON_EPOCH_END = "resumed_on_epoch_end"
48 | 58 | class _FitLoop(_Loop): |
49 | 59 | """This loop is the top-level loop where training starts. |
50 | 60 |
|
@@ -97,6 +107,7 @@ def __init__( |
97 | 107 | self._combined_loader_states_to_load: List[Dict[str, Any]] = [] |
98 | 108 | self._data_fetcher: Optional[_DataFetcher] = None |
99 | 109 | self._last_train_dl_reload_epoch = float("-inf") |
| 110 | + self._restart_stage = RestartStage.NONE |
100 | 111 |
|
101 | 112 | @property |
102 | 113 | def total_batch_idx(self) -> int: |
@@ -204,9 +215,10 @@ def run(self) -> None: |
204 | 215 | self.on_advance_start() |
205 | 216 | self.advance() |
206 | 217 | self.on_advance_end() |
207 | | - self._restarting = False |
208 | 218 | except StopIteration: |
209 | 219 | break |
| 220 | + finally: |
| 221 | + self.on_iteration_done() |
210 | 222 | self._restarting = False |
211 | 223 | self.on_run_end() |
212 | 224 |
|
@@ -303,67 +315,89 @@ def setup_data(self) -> None: |
303 | 315 | ) |
304 | 316 |
|
305 | 317 | @property |
306 | | - def restarting_on_epoch_start(self) -> bool: |
307 | | - return ( |
| 318 | + def restarted_on_epoch_start(self) -> bool: |
| 319 | + return self._restart_stage == RestartStage.RESTARTED_ON_EPOCH_START |
| 320 | + |
| 321 | + @property |
| 322 | + def restarted_mid_epoch(self) -> bool: |
| 323 | + return self._restart_stage == RestartStage.RESTARTED_MID_EPOCH |
| 324 | + |
| 325 | + @property |
| 326 | + def restarted_on_epoch_end(self) -> bool: |
| 327 | + return self._restart_stage == RestartStage.RESTARTED_ON_EPOCH_END |
| 328 | + |
| 329 | + @property |
| 330 | + def resumed_on_epoch_end(self) -> bool: |
| 331 | + # This case happens when restarting from last without validation at |
| 332 | + # the end of epoch. In this case self.restarting is False. |
| 333 | + return self._restart_stage == RestartStage.RESUMED_ON_EPOCH_END |
| 334 | + |
    def update_restart_stage(self) -> None:
        """Classify how this loop is being restarted/resumed.

        Compares the ``epoch_progress.total`` counters (``ready`` ->
        ``started`` -> ``processed`` -> ``completed``) to decide where in the
        epoch lifecycle the checkpoint was taken, stores the verdict in
        ``self._restart_stage``, and asks the child ``epoch_loop`` to do the
        same for its own state.
        """
        if (
            self.restarting
            # ready is one ahead of started: the epoch was announced but never started
            and self.epoch_progress.total.started == self.epoch_progress.total.ready - 1
            and self.epoch_progress.total.processed == self.epoch_progress.total.started
            and self.epoch_progress.total.completed == self.epoch_progress.total.processed
        ):
            self._restart_stage = RestartStage.RESTARTED_ON_EPOCH_START
        elif (
            self.restarting
            and self.epoch_progress.total.started == self.epoch_progress.total.ready
            # started but not fully processed: interrupted in the middle of the epoch
            and self.epoch_progress.total.processed == self.epoch_progress.total.started - 1
            and self.epoch_progress.total.completed == self.epoch_progress.total.processed
        ):
            self._restart_stage = RestartStage.RESTARTED_MID_EPOCH
        elif (
            self.restarting
            and self.epoch_progress.total.started == self.epoch_progress.total.ready
            and self.epoch_progress.total.processed == self.epoch_progress.total.started
            # processed but not completed: interrupted right at the end of the epoch
            and self.epoch_progress.total.completed == self.epoch_progress.total.processed - 1
        ):
            self._restart_stage = RestartStage.RESTARTED_ON_EPOCH_END
        elif (
            # same counter pattern as RESTARTED_ON_EPOCH_END, but self.restarting is
            # False: resuming from "last" without validation at the end of the epoch
            self._loaded_from_state_dict
            and self.epoch_progress.total.started == self.epoch_progress.total.ready
            and self.epoch_progress.total.processed == self.epoch_progress.total.started
            and self.epoch_progress.total.completed == self.epoch_progress.total.processed - 1
        ):
            self._restart_stage = RestartStage.RESUMED_ON_EPOCH_END
        else:
            self._restart_stage = RestartStage.NONE

        # let the inner epoch loop classify its own restart position as well
        self.epoch_loop.update_restart_stage()
| 368 | + |
    def reset_restart_stage(self) -> None:
        """Clear the restart classification once the loop has consumed it."""
        self._restart_stage = RestartStage.NONE
340 | 371 |
|
    def reset(self) -> None:
        """Resets the internal state of this loop.

        Re-enables gradients, classifies the restart position via
        ``update_restart_stage``, and patches up the epoch progress counters
        so they are consistent with where the checkpoint was taken.
        """
        assert self.trainer.model is not None
        torch.set_grad_enabled(True)

        self.update_restart_stage()

        if self.restarted_on_epoch_start:
            self.epoch_progress.reset_on_restart()

        if self.resumed_on_epoch_end:
            # when restarting from last without validation at end of epoch,
            # self.restarting is False but it's still resuming
            self.epoch_progress.increment_completed()

        if (
            # checkpoint was taken on train_batch_end of the last batch mid-epoch:
            # the epoch-level processed/completed counters were never advanced,
            # so advance them here before resuming
            self.epoch_loop.restarted_on_train_batch_end
            and self.restarted_mid_epoch
            and self.epoch_loop.batch_progress.is_last_batch
        ):
            self.epoch_progress.increment_processed()
            self.epoch_progress.increment_completed()

        if (
            # checkpoint on train_batch_end of the last batch, not mid-epoch, and
            # validation had not finished its last batch: only the epoch-level
            # completed counter is missing — TODO confirm against val_loop semantics
            self.epoch_loop.restarted_on_train_batch_end
            and self.epoch_loop.batch_progress.is_last_batch
            and not self.restarted_mid_epoch
            and not self.epoch_loop.val_loop.batch_progress.is_last_batch
        ):
            self.epoch_progress.increment_completed()
|
369 | 403 | def on_run_start(self) -> None: |
@@ -396,8 +430,8 @@ def on_advance_start(self) -> None: |
396 | 430 | for i, dl in enumerate(self._combined_loader.flattened): |
397 | 431 | _set_sampler_epoch(dl, self.epoch_progress.current.processed) |
398 | 432 |
|
399 | | - if not self.restarting_mid_epoch and not self.restarting_on_epoch_end: |
400 | | - if not self.restarting_on_epoch_start: |
| 433 | + if not self.restarted_mid_epoch and not self.restarted_on_epoch_end: |
| 434 | + if not self.restarted_on_epoch_start: |
401 | 435 | self.epoch_progress.increment_ready() |
402 | 436 |
|
403 | 437 | call._call_callback_hooks(trainer, "on_train_epoch_start") |
|
0 commit comments