
Commit 773001a

Merge branch 'master' into deepspeed_mics_init

2 parents: 6ca2bac + 9709c64

10 files changed: +343 −6 lines changed

docs/source-pytorch/common/checkpointing_basic.rst

Lines changed: 23 additions & 1 deletion
@@ -20,6 +20,13 @@ PyTorch Lightning checkpoints are fully usable in plain PyTorch.

 ----

+.. important::
+
+    **Important Update: Deprecated Method**
+
+    Starting from PyTorch Lightning v1.0.0, the `resume_from_checkpoint` argument has been deprecated. To resume training from a checkpoint, use the `ckpt_path` argument in the `fit()` method.
+    Please update your code accordingly to avoid potential compatibility issues.
+
 ************************
 Contents of a checkpoint
 ************************
@@ -197,16 +204,31 @@ You can disable checkpointing by passing:

 ----

+
 *********************
 Resume training state
 *********************

 If you don't just want to load weights, but instead restore the full training, do the following:

+Correct usage:
+
 .. code-block:: python

     model = LitModel()
     trainer = Trainer()

     # automatically restores model, epoch, step, LR schedulers, etc...
-    trainer.fit(model, ckpt_path="some/path/to/my_checkpoint.ckpt")
+    trainer.fit(model, ckpt_path="path/to/your/checkpoint.ckpt")
+
+.. warning::
+
+    The argument `resume_from_checkpoint` has been deprecated in versions of PyTorch Lightning >= 1.0.0.
+    To resume training from a checkpoint, use the `ckpt_path` argument in the `fit()` method instead.
+
+Incorrect (deprecated) usage:
+
+.. code-block:: python
+
+    trainer = Trainer(resume_from_checkpoint="path/to/your/checkpoint.ckpt")
+    trainer.fit(model)
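
For readers following along, a minimal end-to-end sketch of the recommended flow: `LitModel` is the placeholder module from the docs snippet above, "example.ckpt" is a hypothetical path, and `Trainer.save_checkpoint` is the standard Lightning API for writing a checkpoint manually.

    from lightning.pytorch import Trainer

    model = LitModel()  # hypothetical LightningModule, as in the docs snippet above
    trainer = Trainer(max_epochs=1)
    trainer.fit(model)
    trainer.save_checkpoint("example.ckpt")  # hypothetical path

    # Resume the full training state via `ckpt_path`; the old
    # `resume_from_checkpoint` Trainer argument is deprecated.
    trainer = Trainer(max_epochs=2)
    trainer.fit(model, ckpt_path="example.ckpt")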

src/lightning/fabric/plugins/precision/fsdp.py

Lines changed: 6 additions & 0 deletions
@@ -74,6 +74,12 @@ def __init__(self, precision: _PRECISION_INPUT, scaler: Optional["ShardedGradSca
         }
         self._desired_input_dtype = precision_to_type[self.precision]

+    @override
+    def convert_module(self, module: Module) -> Module:
+        if "true" in self.precision:
+            return module.to(dtype=self._desired_input_dtype)
+        return module
+
     @property
     def mixed_precision_config(self) -> "TorchMixedPrecision":
         from torch.distributed.fsdp.fully_sharded_data_parallel import MixedPrecision as TorchMixedPrecision
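
A brief usage sketch of the new `convert_module` hook (the import path follows the file above; the expected dtypes mirror the new test added further down in this commit):

    import torch
    from lightning.fabric.plugins.precision.fsdp import FSDPPrecision

    module = torch.nn.Linear(2, 2)                   # parameters start in float32
    true_prec = FSDPPrecision(precision="bf16-true")
    module = true_prec.convert_module(module)        # "-true" precisions cast eagerly
    assert module.weight.dtype == torch.bfloat16

    mixed_prec = FSDPPrecision(precision="bf16-mixed")
    untouched = mixed_prec.convert_module(torch.nn.Linear(2, 2))
    assert untouched.weight.dtype == torch.float32   # "-mixed" keeps float32 params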

src/lightning/pytorch/CHANGELOG.md

Lines changed: 1 addition & 0 deletions
@@ -35,6 +35,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed `_LoggerConnector`'s `_ResultMetric` to move all registered keys to the device of the logged value if needed ([#19814](https://github.com/Lightning-AI/pytorch-lightning/issues/19814))
 - Fixed `_optimizer_to_device` logic for special 'step' key in optimizer state causing performance regression ([#20019](https://github.com/Lightning-AI/lightning/pull/20019))
 - Fixed parameter counts in `ModelSummary` when model has distributed parameters (DTensor) ([#20163](https://github.com/Lightning-AI/pytorch-lightning/pull/20163))
+- Fixed PyTorch Lightning FSDP takes more memory than PyTorch FSDP ([#20323](https://github.com/Lightning-AI/pytorch-lightning/pull/20323))


 ## [2.3.0] - 2024-06-13

src/lightning/pytorch/core/datamodule.py

Lines changed: 74 additions & 1 deletion
@@ -14,7 +14,8 @@
 """LightningDataModule for loading DataLoaders with ease."""

 import inspect
-from collections.abc import Iterable
+import os
+from collections.abc import Iterable, Sized
 from typing import IO, Any, Optional, Union, cast

 from lightning_utilities import apply_to_collection
@@ -244,3 +245,75 @@ def load_from_checkpoint(
             **kwargs,
         )
         return cast(Self, loaded)
+
+    def __str__(self) -> str:
+        """Return a string representation of the datasets that are set up.
+
+        Returns:
+            A string representation of the datasets that are setup.
+
+        """
+
+        class dataset_info:
+            def __init__(self, available: bool, length: str) -> None:
+                self.available = available
+                self.length = length
+
+        def retrieve_dataset_info(loader: DataLoader) -> dataset_info:
+            """Helper function to compute dataset information."""
+            dataset = loader.dataset
+            size: str = str(len(dataset)) if isinstance(dataset, Sized) else "NA"
+
+            return dataset_info(True, size)
+
+        def loader_info(
+            loader: Union[DataLoader, Iterable[DataLoader]],
+        ) -> Union[dataset_info, Iterable[dataset_info]]:
+            """Helper function to compute dataset information."""
+            return apply_to_collection(loader, DataLoader, retrieve_dataset_info)
+
+        def extract_loader_info(methods: list[tuple[str, str]]) -> dict:
+            """Helper function to extract information for each dataloader method."""
+            info: dict[str, Union[dataset_info, Iterable[dataset_info]]] = {}
+            for loader_name, func_name in methods:
+                loader_method = getattr(self, func_name, None)
+
+                try:
+                    loader = loader_method()  # type: ignore
+                    info[loader_name] = loader_info(loader)
+                except Exception:
+                    info[loader_name] = dataset_info(False, "")
+
+            return info
+
+        def format_loader_info(info: dict[str, Union[dataset_info, Iterable[dataset_info]]]) -> str:
+            """Helper function to format loader information."""
+            output = []
+            for loader_name, loader_info in info.items():
+                # Single dataset
+                if isinstance(loader_info, dataset_info):
+                    loader_info_formatted = "None" if not loader_info.available else f"size={loader_info.length}"
+                # Iterable of datasets
+                else:
+                    loader_info_formatted = " ; ".join(
+                        "None" if not loader_info_i.available else f"{i}. size={loader_info_i.length}"
+                        for i, loader_info_i in enumerate(loader_info, start=1)
+                    )
+
+                output.append(f"{{{loader_name}: {loader_info_formatted}}}")
+
+            return os.linesep.join(output)
+
+        # Available dataloader methods
+        datamodule_loader_methods: list[tuple[str, str]] = [
+            ("Train dataloader", "train_dataloader"),
+            ("Validation dataloader", "val_dataloader"),
+            ("Test dataloader", "test_dataloader"),
+            ("Predict dataloader", "predict_dataloader"),
+        ]
+
+        # Retrieve information for each dataloader method
+        dataloader_info = extract_loader_info(datamodule_loader_methods)
+        # Format the information
+        dataloader_str = format_loader_info(dataloader_info)
+        return dataloader_str
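
With this addition, printing a `LightningDataModule` summarizes the datasets returned by each dataloader hook. A minimal sketch of what the new `__str__` reports, assuming this commit is applied (`ToyDataModule` and the sizes below are hypothetical, chosen only for illustration):

    import torch
    from torch.utils.data import DataLoader, TensorDataset

    from lightning.pytorch import LightningDataModule


    class ToyDataModule(LightningDataModule):
        def setup(self, stage: str) -> None:
            self.train_set = TensorDataset(torch.randn(64, 4))

        def train_dataloader(self) -> DataLoader:
            return DataLoader(self.train_set, batch_size=8)


    dm = ToyDataModule()
    dm.setup("fit")
    print(dm)
    # Roughly expected output, one line per dataloader hook; hooks that are missing
    # or raise are reported as "None", and datasets without __len__ show "size=NA":
    # {Train dataloader: size=64}
    # {Validation dataloader: None}
    # {Test dataloader: None}
    # {Predict dataloader: None}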

src/lightning/pytorch/demos/boring_classes.py

Lines changed: 82 additions & 1 deletion
@@ -11,12 +11,13 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from collections.abc import Iterator
+from collections.abc import Iterable, Iterator
 from typing import Any, Optional

 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+from lightning_utilities import apply_to_collection
 from torch import Tensor
 from torch.optim import Optimizer
 from torch.optim.lr_scheduler import LRScheduler
@@ -188,6 +189,86 @@ def predict_dataloader(self) -> DataLoader:
         return DataLoader(self.random_predict)


+class BoringDataModuleNoLen(LightningDataModule):
+    """
+    .. warning:: This is meant for testing/debugging and is experimental.
+    """
+
+    def __init__(self) -> None:
+        super().__init__()
+
+    def setup(self, stage: str) -> None:
+        if stage == "fit":
+            self.random_train = RandomIterableDataset(32, 512)
+
+        if stage in ("fit", "validate"):
+            self.random_val = RandomIterableDataset(32, 128)
+
+        if stage == "test":
+            self.random_test = RandomIterableDataset(32, 256)
+
+        if stage == "predict":
+            self.random_predict = RandomIterableDataset(32, 64)
+
+    def train_dataloader(self) -> DataLoader:
+        return DataLoader(self.random_train)
+
+    def val_dataloader(self) -> DataLoader:
+        return DataLoader(self.random_val)
+
+    def test_dataloader(self) -> DataLoader:
+        return DataLoader(self.random_test)
+
+    def predict_dataloader(self) -> DataLoader:
+        return DataLoader(self.random_predict)
+
+
+class IterableBoringDataModule(LightningDataModule):
+    def __init__(self) -> None:
+        super().__init__()
+
+    def setup(self, stage: str) -> None:
+        if stage == "fit":
+            self.train_datasets = [
+                RandomDataset(4, 16),
+                RandomIterableDataset(4, 16),
+            ]
+
+        if stage in ("fit", "validate"):
+            self.val_datasets = [
+                RandomDataset(4, 32),
+                RandomIterableDataset(4, 32),
+            ]
+
+        if stage == "test":
+            self.test_datasets = [
+                RandomDataset(4, 64),
+                RandomIterableDataset(4, 64),
+            ]
+
+        if stage == "predict":
+            self.predict_datasets = [
+                RandomDataset(4, 128),
+                RandomIterableDataset(4, 128),
+            ]
+
+    def train_dataloader(self) -> Iterable[DataLoader]:
+        combined_train = apply_to_collection(self.train_datasets, Dataset, lambda x: DataLoader(x))
+        return combined_train
+
+    def val_dataloader(self) -> DataLoader:
+        combined_val = apply_to_collection(self.val_datasets, Dataset, lambda x: DataLoader(x))
+        return combined_val
+
+    def test_dataloader(self) -> DataLoader:
+        combined_test = apply_to_collection(self.test_datasets, Dataset, lambda x: DataLoader(x))
+        return combined_test
+
+    def predict_dataloader(self) -> DataLoader:
+        combined_predict = apply_to_collection(self.predict_datasets, Dataset, lambda x: DataLoader(x))
+        return combined_predict
+
+
 class ManualOptimBoringModel(BoringModel):
     """
     .. warning:: This is meant for testing/debugging and is experimental.
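
The list-valued dataloaders above rely on `apply_to_collection` from `lightning_utilities` to map each dataset in a collection to its own `DataLoader`. A small standalone sketch of that pattern (the datasets here are hypothetical stand-ins):

    import torch
    from lightning_utilities import apply_to_collection
    from torch.utils.data import DataLoader, Dataset, TensorDataset

    datasets = [TensorDataset(torch.randn(16, 4)), TensorDataset(torch.randn(32, 4))]
    # Applies the function to every `Dataset` found in the collection while
    # preserving its structure: a list of datasets becomes a list of loaders.
    loaders = apply_to_collection(datasets, Dataset, lambda ds: DataLoader(ds, batch_size=4))
    assert all(isinstance(dl, DataLoader) for dl in loaders)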

src/lightning/pytorch/loops/prediction_loop.py

Lines changed: 3 additions & 2 deletions
@@ -233,8 +233,9 @@ def _predict_step(

         self.batch_progress.increment_ready()

-        if not using_dataloader_iter:
-            any_on_epoch = self._store_data_for_prediction_writer(batch_idx, dataloader_idx)
+        any_on_epoch = (
+            self._store_data_for_prediction_writer(batch_idx, dataloader_idx) if not using_dataloader_iter else False
+        )

         # the `_step` methods don't take a batch_idx when `dataloader_iter` is used, but all other hooks still do,
         # so we need different kwargs

src/lightning/pytorch/plugins/precision/fsdp.py

Lines changed: 7 additions & 0 deletions
@@ -17,6 +17,7 @@
 import torch
 from lightning_utilities import apply_to_collection
 from torch import Tensor
+from torch.nn import Module
 from typing_extensions import get_args, override

 import lightning.pytorch as pl
@@ -73,6 +74,12 @@ def __init__(self, precision: _PRECISION_INPUT, scaler: Optional["ShardedGradSca
         }
         self._desired_input_dtype = precision_to_type[self.precision]

+    @override
+    def convert_module(self, module: Module) -> Module:
+        if "true" in self.precision:
+            return module.to(dtype=self._desired_input_dtype)
+        return module
+
     @override
     def clip_grad_by_norm(self, *_: Any, **__: Any) -> None:
         # see https://pytorch.org/docs/stable/fsdp.html#torch.distributed.fsdp.FullyShardedDataParallel.clip_grad_norm_

tests/tests_fabric/plugins/precision/test_fsdp.py

Lines changed: 18 additions & 0 deletions
@@ -127,3 +127,21 @@ def test_invalid_precision_with_fsdp_precision():

     with pytest.raises(ValueError, match="is not supported in FSDP. `precision` must be one of"):
         FSDPPrecision(precision="64-true")
+
+
+@pytest.mark.parametrize(
+    ("precision", "expected_dtype"),
+    [
+        ("32-true", torch.float32),
+        ("bf16-mixed", torch.float32),
+        ("16-mixed", torch.float32),
+        ("bf16-true", torch.bfloat16),
+        ("16-true", torch.float16),
+    ],
+)
+def test_convert_module(precision, expected_dtype):
+    precision = FSDPPrecision(precision=precision)
+    module = torch.nn.Linear(2, 2)
+    assert module.weight.dtype == module.bias.dtype == torch.float32
+    module = precision.convert_module(module)
+    assert module.weight.dtype == module.bias.dtype == expected_dtype
