
Commit 2b81c20

Merge branch 'master' into nitpick/add-make-command
2 parents: 53b93c9 + 89dbc55 (commit 2b81c20)

13 files changed (+112, -12 lines)

docs/source-pytorch/accelerators/accelerator_prepare.rst

Lines changed: 1 addition & 1 deletion
@@ -78,7 +78,7 @@ Synchronize validation and test logging
 ***************************************

 When running in distributed mode, we have to ensure that the validation and test step logging calls are synchronized across processes.
-This is done by adding ``sync_dist=True`` to all ``self.log`` calls in the validation and test step.
+This is done by adding ``sync_dist=True`` to all ``self.log`` calls in the validation and test step. This will automatically average values across all processes.
 This ensures that each GPU worker has the same behaviour when tracking model checkpoints, which is important for later downstream tasks such as testing the best checkpoint across all workers.
 The ``sync_dist`` option can also be used in logging calls during the step methods, but be aware that this can lead to significant communication overhead and slow down your training.
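
As context for this doc change (not part of the commit): a minimal sketch of such a logging call inside a hypothetical LightningModule; the model and metric name are made up for illustration.

    import torch
    from torch.nn import functional as F
    from lightning.pytorch import LightningModule

    class LitModel(LightningModule):  # hypothetical module, for illustration only
        def __init__(self):
            super().__init__()
            self.layer = torch.nn.Linear(32, 2)

        def validation_step(self, batch, batch_idx):
            x, y = batch
            loss = F.cross_entropy(self.layer(x), y)
            # sync_dist=True averages the value across processes before logging,
            # so every rank tracks the same number for checkpointing.
            self.log("val_loss", loss, sync_dist=True)
            return loss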

docs/source-pytorch/common/checkpointing_basic.rst

Lines changed: 1 addition & 1 deletion
@@ -111,7 +111,7 @@ The LightningModule also has access to the Hyperparameters

 .. code-block:: python

     model = MyLightningModule.load_from_checkpoint("/path/to/checkpoint.ckpt")
-    print(model.learning_rate)
+    print(model.hparams.learning_rate)

 ----
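
For context (not part of the commit): `model.hparams` only holds values captured via `save_hyperparameters()`, so the documented call assumes a module set up roughly like this sketch (names are illustrative):

    from lightning.pytorch import LightningModule

    class MyLightningModule(LightningModule):  # illustrative sketch
        def __init__(self, learning_rate: float = 1e-3):
            super().__init__()
            # Stores the init arguments under self.hparams and writes them into the
            # checkpoint, which is what makes `model.hparams.learning_rate` available
            # after `load_from_checkpoint`.
            self.save_hyperparameters()

    model = MyLightningModule(learning_rate=3e-4)
    print(model.hparams.learning_rate)  # 0.0003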

docs/source-pytorch/extensions/logging.rst

Lines changed: 1 addition & 1 deletion
@@ -137,7 +137,7 @@ The :meth:`~lightning.pytorch.core.LightningModule.log` method has a few options
 * ``logger``: Logs to the logger like ``Tensorboard``, or any other custom logger passed to the :class:`~lightning.pytorch.trainer.trainer.Trainer` (Default: ``True``).
 * ``reduce_fx``: Reduction function over step values for end of epoch. Uses :func:`torch.mean` by default and is not applied when a :class:`torchmetrics.Metric` is logged.
 * ``enable_graph``: If True, will not auto detach the graph.
-* ``sync_dist``: If True, reduces the metric across devices. Use with care as this may lead to a significant communication overhead.
+* ``sync_dist``: If True, averages the metric across devices. Use with care as this may lead to a significant communication overhead.
 * ``sync_dist_group``: The DDP group to sync across.
 * ``add_dataloader_idx``: If True, appends the index of the current dataloader to the name (when using multiple dataloaders). If False, user needs to give unique names for each dataloader to not mix the values.
 * ``batch_size``: Current batch size used for accumulating logs logged with ``on_epoch=True``. This will be directly inferred from the loaded batch, but for some data structures you might need to explicitly provide it.
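
As a hedged illustration (not part of the commit) of how several of these options combine in a single call; the module, metric name, and values are made up:

    import torch
    from torch.nn import functional as F
    from lightning.pytorch import LightningModule

    class LitClassifier(LightningModule):  # hypothetical module, for illustration only
        def __init__(self):
            super().__init__()
            self.layer = torch.nn.Linear(32, 2)

        def training_step(self, batch, batch_idx):
            x, y = batch
            loss = F.cross_entropy(self.layer(x), y)
            self.log(
                "train_loss",
                loss,
                on_step=True,          # log the raw per-step value
                on_epoch=True,         # also accumulate an epoch-level value
                reduce_fx="mean",      # epoch reduction (this is the default)
                sync_dist=True,        # average across devices; costs communication
                batch_size=x.size(0),  # explicit batch size for weighting the epoch mean
            )
            return loss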

examples/fabric/image_classifier/train_fabric.py

Lines changed: 1 addition & 1 deletion
@@ -158,7 +158,7 @@ def run(hparams):
     # When using distributed training, use `fabric.save`
     # to ensure the current process is allowed to save a checkpoint
     if hparams.save_model:
-        fabric.save(model.state_dict(), "mnist_cnn.pt")
+        fabric.save(path="mnist_cnn.pt", state=model.state_dict())


 if __name__ == "__main__":
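
Not part of the commit, but relevant context: `Fabric.save` takes the target path first and the state second, which is why the keyword form above is safer than the old positional call. A minimal sketch with made-up file and model names:

    import torch
    from lightning.fabric import Fabric

    fabric = Fabric(accelerator="cpu", devices=1)
    model = torch.nn.Linear(4, 2)

    # `Fabric.save` expects the target path first, then the state to serialize;
    # in distributed runs only the rank(s) allowed to write actually touch the file.
    fabric.save(path="model.pt", state=model.state_dict())

    # Later: with no `state` argument, `fabric.load` returns the saved contents directly.
    model.load_state_dict(fabric.load("model.pt"))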

examples/fabric/kfold_cv/train_fabric.py

Lines changed: 1 addition & 1 deletion
@@ -161,7 +161,7 @@ def run(hparams):
     # When using distributed training, use `fabric.save`
     # to ensure the current process is allowed to save a checkpoint
     if hparams.save_model:
-        fabric.save(model.state_dict(), "mnist_cnn.pt")
+        fabric.save(path="mnist_cnn.pt", state=model.state_dict())


 if __name__ == "__main__":

examples/fabric/tensor_parallel/train.py

Lines changed: 1 addition & 1 deletion
@@ -67,7 +67,7 @@ def train():
     # See `fabric consolidate --help` if you need to convert the checkpoint to a single file
     fabric.print("Saving a (distributed) checkpoint ...")
     state = {"model": model, "optimizer": optimizer, "iteration": i}
-    fabric.save("checkpoint.pt", state)
+    fabric.save(path="checkpoint.pt", state=state)

     fabric.print("Training successfully completed!")
     fabric.print(f"Peak memory usage: {torch.cuda.max_memory_allocated() / 1e9:.02f} GB")

src/lightning/pytorch/CHANGELOG.md

Lines changed: 1 addition & 0 deletions
@@ -26,6 +26,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 ### Fixed

 - fix progress bar console clearing for Rich `14.1+` ([#21016](https://github.com/Lightning-AI/pytorch-lightning/pull/21016))
+- fix `AdvancedProfiler` to handle nested profiling actions for Python 3.12+ ([#20809](https://github.com/Lightning-AI/pytorch-lightning/pull/20809))


 ---

src/lightning/pytorch/loops/loop.py

Lines changed: 8 additions & 0 deletions
@@ -23,6 +23,7 @@ class _Loop:
     def __init__(self, trainer: "pl.Trainer") -> None:
         self._restarting = False
         self._loaded_from_state_dict = False
+        self._resuming_from_checkpoint = False
         self.trainer = trainer

     @property
@@ -38,6 +39,11 @@ def restarting(self, restarting: bool) -> None:
             if isinstance(loop, _Loop):
                 loop.restarting = restarting

+    @property
+    def is_resuming(self) -> bool:
+        """Indicates whether training is being resumed from a checkpoint."""
+        return self._resuming_from_checkpoint
+
     def reset_restart_stage(self) -> None:
         pass

@@ -87,6 +93,7 @@ def load_state_dict(
                 v.load_state_dict(state_dict.copy(), prefix + k + ".")
         self.restarting = True
         self._loaded_from_state_dict = True
+        self._resuming_from_checkpoint = True

     def _load_from_state_dict(self, state_dict: dict, prefix: str) -> None:
         for k, v in self.__dict__.items():
@@ -102,4 +109,5 @@ def _load_from_state_dict(self, state_dict: dict, prefix: str) -> None:
     def on_iteration_done(self) -> None:
         self._restarting = False
         self._loaded_from_state_dict = False
+        self._resuming_from_checkpoint = False
         self.reset_restart_stage()
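
Not part of the commit, but one way the new property could be consumed downstream; the callback and messages are hypothetical:

    from lightning.pytorch.callbacks import Callback

    class ResumeAwareCallback(Callback):  # hypothetical example, not in the codebase
        def on_train_start(self, trainer, pl_module):
            # The new property reflects whether the loop state was loaded from a checkpoint.
            if trainer.fit_loop.is_resuming:
                print("Fit is resuming from a checkpoint")
            else:
                print("Fit is starting from scratch")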

src/lightning/pytorch/loops/training_epoch_loop.py

Lines changed: 5 additions & 1 deletion
@@ -237,7 +237,11 @@ def reset(self) -> None:

     def on_run_start(self, data_fetcher: _DataFetcher) -> None:
         # `iter()` was called once in `FitLoop.setup_data()` already
-        if self.trainer.current_epoch > 0 and not self.restarting:
+        # Call `iter()` again only when:
+        # 1. Not restarting
+        # 2. Not resuming from checkpoint (not is_resuming)
+        # 3. Past first epoch (current_epoch > 0)
+        if self.trainer.current_epoch > 0 and not self.trainer.fit_loop.is_resuming and not self.restarting:
             iter(data_fetcher)  # creates the iterator inside the fetcher

         # add the previous `fetched` value to properly track `is_last_batch` with no prefetching
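
A plain-Python analogy (not Lightning code) for why the extra guard matters: calling `iter()` again builds a fresh iterator and would discard the position a resumed run had already fast-forwarded to.

    batches = range(10)

    it = iter(batches)
    next(it)              # a resumed run has already advanced...
    next(it)              # ...past the first two batches

    it = iter(batches)    # re-creating the iterator starts over from the beginning
    assert next(it) == 0  # the fast-forwarded progress is lost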

src/lightning/pytorch/profilers/advanced.py

Lines changed: 6 additions & 4 deletions
@@ -19,6 +19,7 @@
 import os
 import pstats
 import tempfile
+from collections import defaultdict
 from pathlib import Path
 from typing import Optional, Union

@@ -66,14 +67,15 @@ def __init__(
                 If you attempt to stop recording an action which was never started.
         """
         super().__init__(dirpath=dirpath, filename=filename)
-        self.profiled_actions: dict[str, cProfile.Profile] = {}
+        self.profiled_actions: dict[str, cProfile.Profile] = defaultdict(cProfile.Profile)
         self.line_count_restriction = line_count_restriction
         self.dump_stats = dump_stats

     @override
     def start(self, action_name: str) -> None:
-        if action_name not in self.profiled_actions:
-            self.profiled_actions[action_name] = cProfile.Profile()
+        # Disable all profilers before starting a new one
+        for pr in self.profiled_actions.values():
+            pr.disable()
         self.profiled_actions[action_name].enable()

     @override
@@ -114,7 +116,7 @@ def summary(self) -> str:
     @override
     def teardown(self, stage: Optional[str]) -> None:
         super().teardown(stage=stage)
-        self.profiled_actions = {}
+        self.profiled_actions.clear()

     def __reduce__(self) -> tuple:
         # avoids `TypeError: cannot pickle 'cProfile.Profile' object`
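
To see why the change helps on newer interpreters, here is a standalone sketch of the same pattern (simplified: free functions instead of profiler methods). On Python 3.12+, enabling a second `cProfile.Profile` while another is active raises an error, so every other profiler is paused before a new action starts and `defaultdict` creates profilers on demand:

    import cProfile
    from collections import defaultdict

    profiled_actions: dict[str, cProfile.Profile] = defaultdict(cProfile.Profile)

    def start(action_name: str) -> None:
        # Pause every other action first; Python 3.12+ refuses to enable a profiler
        # while a different one is still collecting.
        for profile in profiled_actions.values():
            profile.disable()
        profiled_actions[action_name].enable()  # created on first use by defaultdict

    def stop(action_name: str) -> None:
        profiled_actions[action_name].disable()

    start("run_training_epoch")
    start("run_training_batch")  # nested action: the outer profiler is paused, not broken
    stop("run_training_batch")
    stop("run_training_epoch")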
