
Commit b373a76

Merge branch 'master' into bump/python_3.9+
2 parents cba3391 + 1f2d7a1 commit b373a76

9 files changed: +621 -62 lines changed


src/lightning/fabric/utilities/throughput.py

Lines changed: 8 additions & 0 deletions
@@ -347,6 +347,14 @@ def measure_flops(
         torch.int8: 389.9e12,
         "int4": 779.8e12,
     },
+    "rtx 4080 super": {
+        torch.float32: 52.2e12,
+        "tfloat32": 52.2e12,
+        torch.bfloat16: 52.2e12,
+        torch.float16: 52.2e12,
+        torch.int8: 417.6e12,
+        "int4": 835.2e12,
+    },
     "l4": {
         torch.float32: 30.3e12,
         "tfloat32": 60e12,

src/lightning/pytorch/demos/transformer.py

Lines changed: 2 additions & 2 deletions
@@ -88,7 +88,7 @@ def forward(self, x: Tensor) -> Tensor:
         # TODO: Could make this a `nn.Parameter` with `requires_grad=False`
         self.pe = self._init_pos_encoding(device=x.device)
 
-        x = x + self.pe[: x.size(0), :]
+        x = x + self.pe[:, : x.size(1)]
         return self.dropout(x)
 
     def _init_pos_encoding(self, device: torch.device) -> Tensor:
@@ -97,7 +97,7 @@ def _init_pos_encoding(self, device: torch.device) -> Tensor:
         div_term = torch.exp(torch.arange(0, self.dim, 2, device=device).float() * (-math.log(10000.0) / self.dim))
         pe[:, 0::2] = torch.sin(position * div_term)
         pe[:, 1::2] = torch.cos(position * div_term)
-        pe = pe.unsqueeze(0).transpose(0, 1)
+        pe = pe.unsqueeze(0)
         return pe
 
 
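This change makes the demo's positional encoding batch-first: `pe` keeps shape `(1, max_len, dim)` and is sliced along the sequence dimension, so it broadcasts over the batch. A minimal self-contained sketch of the same idea (the module name and sizes below are illustrative, not the Lightning demo class):

```python
import math

import torch
from torch import Tensor, nn


class BatchFirstPositionalEncoding(nn.Module):
    """Sinusoidal positional encoding for inputs of shape (batch, seq_len, dim)."""

    def __init__(self, dim: int, max_len: int = 5000, dropout: float = 0.1) -> None:
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        position = torch.arange(max_len).unsqueeze(1).float()
        div_term = torch.exp(torch.arange(0, dim, 2).float() * (-math.log(10000.0) / dim))
        pe = torch.zeros(max_len, dim)
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        # (1, max_len, dim): the leading singleton dimension broadcasts over the batch
        self.register_buffer("pe", pe.unsqueeze(0))

    def forward(self, x: Tensor) -> Tensor:
        # slice along the sequence dimension, not the batch dimension
        x = x + self.pe[:, : x.size(1)]
        return self.dropout(x)


x = torch.randn(8, 32, 64)  # (batch, seq_len, dim)
print(BatchFirstPositionalEncoding(64)(x).shape)  # torch.Size([8, 32, 64])
```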

src/lightning/pytorch/loops/evaluation_loop.py

Lines changed: 47 additions & 12 deletions
@@ -15,8 +15,8 @@
 import shutil
 import sys
 from collections import ChainMap, OrderedDict, defaultdict
-from collections.abc import Iterable, Iterator
-from typing import Any, Optional, Union
+from dataclasses import dataclass
+from typing import Any, DefaultDict, Iterable, Iterator, List, Optional, Tuple, Union
 
 from lightning_utilities.core.apply_func import apply_to_collection
 from torch import Tensor
@@ -46,6 +46,12 @@
 from lightning.pytorch.utilities.signature_utils import is_param_in_hook_signature
 
 
+@dataclass
+class RestartStage:
+    NONE = "none"
+    RESTARTED_MID_EVALUATION = "restarted_mid_evaluation"
+
+
 class _EvaluationLoop(_Loop):
     """Top-level loop where validation/testing starts."""
 
@@ -61,19 +67,20 @@ def __init__(
         self.verbose = verbose
         self.inference_mode = inference_mode
         self.batch_progress = _BatchProgress()  # across dataloaders
-        self._max_batches: list[Union[int, float]] = []
+        self._max_batches: List[Union[int, float]] = []
 
         self._results = _ResultCollection(training=False)
-        self._logged_outputs: list[_OUT_DICT] = []
+        self._logged_outputs: List[_OUT_DICT] = []
         self._has_run: bool = False
         self._trainer_fn = trainer_fn
         self._stage = stage
         self._data_source = _DataLoaderSource(None, f"{stage.dataloader_prefix}_dataloader")
         self._combined_loader: Optional[CombinedLoader] = None
         self._data_fetcher: Optional[_DataFetcher] = None
-        self._seen_batches_per_dataloader: defaultdict[int, int] = defaultdict(int)
+        self._seen_batches_per_dataloader: DefaultDict[int, int] = defaultdict(int)
         self._last_val_dl_reload_epoch = float("-inf")
         self._module_mode = _ModuleMode()
+        self._restart_stage = RestartStage.NONE
 
     @property
     def num_dataloaders(self) -> int:
@@ -83,7 +90,7 @@ def num_dataloaders(self) -> int:
         return len(combined_loader.flattened)
 
     @property
-    def max_batches(self) -> list[Union[int, float]]:
+    def max_batches(self) -> List[Union[int, float]]:
         """The max number of batches to run per dataloader."""
         max_batches = self._max_batches
         if not self.trainer.sanity_checking:
@@ -107,7 +114,7 @@ def _is_sequential(self) -> bool:
         return self._combined_loader._mode == "sequential"
 
     @_no_grad_context
-    def run(self) -> list[_OUT_DICT]:
+    def run(self) -> List[_OUT_DICT]:
         self.setup_data()
         if self.skip:
             return []
@@ -138,7 +145,7 @@ def run(self) -> list[_OUT_DICT]:
                 # this needs to wrap the `*_step` call too (not just `next`) for `dataloader_iter` support
                 break
             finally:
-                self._restarting = False
+                self.on_iteration_done()
         self._store_dataloader_outputs()
         return self.on_run_end()
 
@@ -198,6 +205,24 @@ def setup_data(self) -> None:
         # this depends on the data used, so reset it too
         self._seen_batches_per_dataloader = defaultdict(int)
 
+    @property
+    def restarted_mid_evaluation(self) -> bool:
+        return self._restart_stage == RestartStage.RESTARTED_MID_EVALUATION
+
+    def update_restart_stage(self) -> None:
+        if (
+            self.restarting
+            and self.batch_progress.total.started == self.batch_progress.total.ready
+            and self.batch_progress.total.processed == self.batch_progress.total.started - 1
+            and self.batch_progress.total.completed == self.batch_progress.total.processed
+        ):
+            self._restart_stage = RestartStage.RESTARTED_MID_EVALUATION
+        else:
+            self._restart_stage = RestartStage.NONE
+
+    def reset_restart_stage(self) -> None:
+        self._restart_stage = RestartStage.NONE
+
     def reset(self) -> None:
         """Resets the internal state of the loop."""
         trainer = self.trainer
@@ -237,6 +262,16 @@ def reset(self) -> None:
         data_fetcher._stop_profiler = self._on_after_fetch
         self._data_fetcher = data_fetcher
 
+    def increment_progress_to_evaluation_end(self) -> None:
+        self.setup_data()
+        if self.skip:
+            return
+        self.reset()
+        max_batch = int(max(self.max_batches))
+        if max_batch == -1:
+            return
+        self.batch_progress.increment_by(max_batch, True)
+
     def on_run_start(self) -> None:
         """Runs the ``_on_evaluation_model_eval``, ``_on_evaluation_start`` and ``_on_evaluation_epoch_start``
         hooks."""
@@ -245,7 +280,7 @@ def on_run_start(self) -> None:
         self._on_evaluation_start()
         self._on_evaluation_epoch_start()
 
-    def on_run_end(self) -> list[_OUT_DICT]:
+    def on_run_end(self) -> List[_OUT_DICT]:
         """Runs the ``_on_evaluation_epoch_end`` hook."""
         # if `done` returned True before any iterations were done, this won't have been called in `on_advance_end`
         self.trainer._logger_connector.epoch_end_reached()
@@ -473,7 +508,7 @@ def _verify_dataloader_idx_requirement(self) -> None:
         )
 
     @staticmethod
-    def _get_keys(data: dict) -> Iterable[tuple[str, ...]]:
+    def _get_keys(data: dict) -> Iterable[Tuple[str, ...]]:
         for k, v in data.items():
             if isinstance(v, dict):
                 for new_key in apply_to_collection(v, dict, _EvaluationLoop._get_keys):
@@ -492,7 +527,7 @@ def _find_value(data: dict, target: Iterable[str]) -> Optional[Any]:
             return _EvaluationLoop._find_value(result, rest)
 
     @staticmethod
-    def _print_results(results: list[_OUT_DICT], stage: str) -> None:
+    def _print_results(results: List[_OUT_DICT], stage: str) -> None:
         # remove the dl idx suffix
         results = [{k.split("/dataloader_idx_")[0]: v for k, v in result.items()} for result in results]
         metrics_paths = {k for keys in apply_to_collection(results, dict, _EvaluationLoop._get_keys) for k in keys}
@@ -509,7 +544,7 @@ def _print_results(results: list[_OUT_DICT], stage: str) -> None:
         term_size = shutil.get_terminal_size(fallback=(120, 30)).columns or 120
         max_length = int(min(max(len(max(metrics_strs, key=len)), len(max(headers, key=len)), 25), term_size / 2))
 
-        rows: list[list[Any]] = [[] for _ in metrics_paths]
+        rows: List[List[Any]] = [[] for _ in metrics_paths]
 
         for result in results:
             for metric, row in zip(metrics_paths, rows):
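The new `update_restart_stage` infers a mid-evaluation restart from the loop's total batch-progress counters: a checkpoint taken mid evaluation leaves exactly one batch started but not yet processed or completed. A standalone sketch of that check, assuming a simple stand-in for the counters (not Lightning's `_BatchProgress`):

```python
from dataclasses import dataclass


@dataclass
class Counters:
    """Stand-in for the `total` counters tracked by a batch-progress object."""
    ready: int = 0
    started: int = 0
    processed: int = 0
    completed: int = 0


def restarted_mid_evaluation(total: Counters, restarting: bool) -> bool:
    # Mirrors the condition in _EvaluationLoop.update_restart_stage: the last
    # ready batch was started but never processed or completed.
    return (
        restarting
        and total.started == total.ready
        and total.processed == total.started - 1
        and total.completed == total.processed
    )


# A checkpoint written after batch 5 was started but before it finished:
print(restarted_mid_evaluation(Counters(ready=5, started=5, processed=4, completed=4), True))   # True
# A clean end-of-evaluation checkpoint does not match:
print(restarted_mid_evaluation(Counters(ready=5, started=5, processed=5, completed=5), True))   # False
```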

src/lightning/pytorch/loops/fit_loop.py

Lines changed: 103 additions & 12 deletions
@@ -12,7 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
-from typing import Any, Optional, Union
+from dataclasses import dataclass
+from typing import Any, Dict, List, Optional, Union
 
 import torch
 from typing_extensions import override
@@ -45,6 +46,15 @@
 log = logging.getLogger(__name__)
 
 
+@dataclass
+class RestartStage:
+    NONE = "none"
+    RESTARTED_ON_EPOCH_START = "restarted_on_epoch_start"
+    RESTARTED_MID_EPOCH = "restarted_mid_epoch"
+    RESTARTED_ON_EPOCH_END = "restarted_on_epoch_end"
+    RESUMED_ON_EPOCH_END = "resumed_on_epoch_end"
+
+
 class _FitLoop(_Loop):
     """This loop is the top-level loop where training starts.
@@ -94,9 +104,10 @@ def __init__(
 
         self._data_source = _DataLoaderSource(None, "train_dataloader")
         self._combined_loader: Optional[CombinedLoader] = None
-        self._combined_loader_states_to_load: list[dict[str, Any]] = []
+        self._combined_loader_states_to_load: List[Dict[str, Any]] = []
         self._data_fetcher: Optional[_DataFetcher] = None
         self._last_train_dl_reload_epoch = float("-inf")
+        self._restart_stage = RestartStage.NONE
 
     @property
     def total_batch_idx(self) -> int:
@@ -204,9 +215,10 @@ def run(self) -> None:
                 self.on_advance_start()
                 self.advance()
                 self.on_advance_end()
-                self._restarting = False
             except StopIteration:
                 break
+            finally:
+                self.on_iteration_done()
         self._restarting = False
         self.on_run_end()
 
@@ -302,14 +314,92 @@ def setup_data(self) -> None:
                 category=PossibleUserWarning,
             )
 
+    @property
+    def restarted_on_epoch_start(self) -> bool:
+        return self._restart_stage == RestartStage.RESTARTED_ON_EPOCH_START
+
+    @property
+    def restarted_mid_epoch(self) -> bool:
+        return self._restart_stage == RestartStage.RESTARTED_MID_EPOCH
+
+    @property
+    def restarted_on_epoch_end(self) -> bool:
+        return self._restart_stage == RestartStage.RESTARTED_ON_EPOCH_END
+
+    @property
+    def resumed_on_epoch_end(self) -> bool:
+        # This case happens when restarting from last without validation at
+        # the end of epoch. In this case self.restarting is False.
+        return self._restart_stage == RestartStage.RESUMED_ON_EPOCH_END
+
+    def update_restart_stage(self) -> None:
+        if (
+            self.restarting
+            and self.epoch_progress.total.started == self.epoch_progress.total.ready - 1
+            and self.epoch_progress.total.processed == self.epoch_progress.total.started
+            and self.epoch_progress.total.completed == self.epoch_progress.total.processed
+        ):
+            self._restart_stage = RestartStage.RESTARTED_ON_EPOCH_START
+        elif (
+            self.restarting
+            and self.epoch_progress.total.started == self.epoch_progress.total.ready
+            and self.epoch_progress.total.processed == self.epoch_progress.total.started - 1
+            and self.epoch_progress.total.completed == self.epoch_progress.total.processed
+        ):
+            self._restart_stage = RestartStage.RESTARTED_MID_EPOCH
+        elif (
+            self.restarting
+            and self.epoch_progress.total.started == self.epoch_progress.total.ready
+            and self.epoch_progress.total.processed == self.epoch_progress.total.started
+            and self.epoch_progress.total.completed == self.epoch_progress.total.processed - 1
+        ):
+            self._restart_stage = RestartStage.RESTARTED_ON_EPOCH_END
+        elif (
+            self._loaded_from_state_dict
+            and self.epoch_progress.total.started == self.epoch_progress.total.ready
+            and self.epoch_progress.total.processed == self.epoch_progress.total.started
+            and self.epoch_progress.total.completed == self.epoch_progress.total.processed - 1
+        ):
+            self._restart_stage = RestartStage.RESUMED_ON_EPOCH_END
+        else:
+            self._restart_stage = RestartStage.NONE
+
+        self.epoch_loop.update_restart_stage()
+
+    def reset_restart_stage(self) -> None:
+        self._restart_stage = RestartStage.NONE
+
     def reset(self) -> None:
         """Resets the internal state of this loop."""
         assert self.trainer.model is not None
         torch.set_grad_enabled(True)
 
-        if self.restarting:
+        self.update_restart_stage()
+
+        if self.restarted_on_epoch_start:
             self.epoch_progress.reset_on_restart()
 
+        if self.resumed_on_epoch_end:
+            # when restarting from last without validation at end of epoch,
+            # self.restarting is False but it's still resuming
+            self.epoch_progress.increment_completed()
+
+        if (
+            self.epoch_loop.restarted_on_train_batch_end
+            and self.restarted_mid_epoch
+            and self.epoch_loop.batch_progress.is_last_batch
+        ):
+            self.epoch_progress.increment_processed()
+            self.epoch_progress.increment_completed()
+
+        if (
+            self.epoch_loop.restarted_on_train_batch_end
+            and self.epoch_loop.batch_progress.is_last_batch
+            and not self.restarted_mid_epoch
+            and not self.epoch_loop.val_loop.batch_progress.is_last_batch
+        ):
+            self.epoch_progress.increment_completed()
+
     def on_run_start(self) -> None:
         """Calls the ``on_train_start`` hook."""
         # update the current_epoch in-case of checkpoint reload
@@ -340,12 +430,14 @@ def on_advance_start(self) -> None:
         for i, dl in enumerate(self._combined_loader.flattened):
             _set_sampler_epoch(dl, self.epoch_progress.current.processed)
 
-        self.epoch_progress.increment_ready()
+        if not self.restarted_mid_epoch and not self.restarted_on_epoch_end:
+            if not self.restarted_on_epoch_start:
+                self.epoch_progress.increment_ready()
 
-        call._call_callback_hooks(trainer, "on_train_epoch_start")
-        call._call_lightning_module_hook(trainer, "on_train_epoch_start")
+            call._call_callback_hooks(trainer, "on_train_epoch_start")
+            call._call_lightning_module_hook(trainer, "on_train_epoch_start")
 
-        self.epoch_progress.increment_started()
+            self.epoch_progress.increment_started()
 
     def advance(self) -> None:
         """Runs one whole epoch."""
@@ -379,8 +471,7 @@ def on_advance_end(self) -> None:
 
         trainer._logger_connector.on_epoch_end()
 
-        if self.epoch_loop._num_ready_batches_reached():
-            # if we are restarting and the above condition holds, it's because we are reloading an epoch-end checkpoint.
+        if not self.restarting and self.epoch_loop._num_ready_batches_reached():
             # since metric-based schedulers require access to metrics and those are not currently saved in the
             # checkpoint, the plateau schedulers shouldn't be updated
             self.epoch_loop.update_lr_schedulers("epoch", update_plateau_schedulers=not self.restarting)
@@ -413,14 +504,14 @@ def teardown(self) -> None:
         self.epoch_loop.teardown()
 
     @override
-    def on_save_checkpoint(self) -> dict:
+    def on_save_checkpoint(self) -> Dict:
         state_dict = super().on_save_checkpoint()
         if self._combined_loader is not None and (loader_states := self._combined_loader._state_dicts()):
             state_dict["combined_loader"] = loader_states
         return state_dict
 
     @override
-    def on_load_checkpoint(self, state_dict: dict) -> None:
+    def on_load_checkpoint(self, state_dict: Dict) -> None:
         self._combined_loader_states_to_load = state_dict.get("combined_loader", [])
         super().on_load_checkpoint(state_dict)
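Both loops now finish every iteration through a `finally` block, so `on_iteration_done()` runs whether the iteration completed normally or raised `StopIteration`, and the restart bookkeeping is cleared exactly once per iteration. A minimal sketch of that control flow; the class below only mirrors the pattern and is not Lightning's `_Loop` API:

```python
class TinyLoop:
    """Illustrative loop skeleton following the try/except/finally pattern in the diff."""

    def __init__(self, n_iters: int) -> None:
        self.n_iters = n_iters
        self.i = 0
        self.restarting = True          # pretend we resumed from a checkpoint
        self.restart_stage = "restarted_mid_epoch"

    @property
    def done(self) -> bool:
        return self.i >= self.n_iters

    def advance(self) -> None:
        self.i += 1

    def on_iteration_done(self) -> None:
        # runs after every iteration, even one interrupted by StopIteration
        self.restarting = False
        self.restart_stage = "none"

    def run(self) -> None:
        while not self.done:
            try:
                self.advance()
            except StopIteration:
                break
            finally:
                self.on_iteration_done()


loop = TinyLoop(3)
loop.run()
print(loop.i, loop.restarting, loop.restart_stage)  # 3 False none
```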
