
Commit 8726af4

speediedan authored and lantiga committed
Accommodate FSDP full-precision param_dtype training with PyTorch < 2.0 (#18278)
(cherry picked from commit c081b48)
1 parent 21563f8 commit 8726af4

File tree: 7 files changed (+38, -15 lines)


src/lightning/fabric/CHANGELOG.md

Lines changed: 3 additions & 0 deletions
@@ -30,6 +30,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ### Fixed
 
+- Fixed FSDP full-precision `param_dtype` training (`16-mixed`, `bf16-mixed` and `32-true` configurations) to avoid FSDP assertion errors with PyTorch < 2.0 ([#18278](https://github.com/Lightning-AI/lightning/pull/18278))
+
+
 - Fixed issue where DDP subprocesses that used Hydra would set hydra's working directory to current directory ([#18145](https://github.com/Lightning-AI/lightning/pull/18145))
 - Fixed an issue that would prevent the user to set the multiprocessing start method after importing lightning ([#18177](https://github.com/Lightning-AI/lightning/pull/18177))
 - Fixed an issue with `Fabric.all_reduce()` not performing an inplace operation for all backends consistently ([#18235](https://github.com/Lightning-AI/lightning/pull/18235))

src/lightning/fabric/plugins/precision/fsdp.py

Lines changed: 6 additions & 3 deletions
@@ -16,7 +16,7 @@
 import torch
 
 from lightning.fabric.plugins.precision.amp import MixedPrecision
-from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_1_12
+from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_1_12, _TORCH_GREATER_EQUAL_2_0
 
 if TYPE_CHECKING:
     from torch.distributed.fsdp.fully_sharded_data_parallel import MixedPrecision as TorchMixedPrecision
@@ -48,11 +48,14 @@ def __init__(
     def mixed_precision_config(self) -> "TorchMixedPrecision":
         from torch.distributed.fsdp.fully_sharded_data_parallel import MixedPrecision as TorchMixedPrecision
 
+        # With PyTorch < 2.0, FSDP uses the noneness of `param_dtype` as a proxy for the `_uses_param_mixed_precision`
+        # property. In order to avoid FSDP assertion failures, we therefore avoid setting `param_dtype` to
+        # `torch.float32` here with PyTorch < 2.0.
         if self.precision == "16-mixed":
-            param_dtype = torch.float32
+            param_dtype = None if not _TORCH_GREATER_EQUAL_2_0 else torch.float32
             reduce_dtype = buffer_dtype = torch.float16
         elif self.precision == "bf16-mixed":
-            param_dtype = torch.float32
+            param_dtype = None if not _TORCH_GREATER_EQUAL_2_0 else torch.float32
             reduce_dtype = buffer_dtype = torch.bfloat16
         elif self.precision == "16-true":
             param_dtype = reduce_dtype = buffer_dtype = torch.float16
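For reference, here is a minimal standalone sketch of what the patched `16-mixed` branch now hands to FSDP. The `packaging`-based version check is a stand-in (an assumption) for Lightning's `_TORCH_GREATER_EQUAL_2_0` flag; it only needs to answer "is the installed torch >= 2.0?".

import torch
from packaging.version import Version
from torch.distributed.fsdp.fully_sharded_data_parallel import MixedPrecision

# Stand-in for Lightning's `_TORCH_GREATER_EQUAL_2_0` compatibility flag.
_TORCH_GREATER_EQUAL_2_0 = Version(torch.__version__.split("+")[0]) >= Version("2.0.0")

# Mirrors the patched `16-mixed` branch: on PyTorch < 2.0, leave `param_dtype` unset
# (None) so FSDP does not treat the full-precision params as "mixed" params.
param_dtype = None if not _TORCH_GREATER_EQUAL_2_0 else torch.float32
reduce_dtype = buffer_dtype = torch.float16

config = MixedPrecision(param_dtype=param_dtype, reduce_dtype=reduce_dtype, buffer_dtype=buffer_dtype)
print(config)

The resulting (param_dtype, reduce_dtype, buffer_dtype) triple matches what the updated tests further down assert for each torch version.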

src/lightning/fabric/strategies/deepspeed.py

Lines changed: 1 addition & 2 deletions
@@ -33,8 +33,7 @@
 from lightning.fabric.strategies.registry import _StrategyRegistry
 from lightning.fabric.strategies.strategy import _Sharded
 from lightning.fabric.utilities.distributed import log
-from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_0
-from lightning.fabric.utilities.rank_zero import rank_zero_info, rank_zero_only, rank_zero_warn
+from lightning.fabric.utilities.rank_zero import rank_zero_info, rank_zero_warn
 from lightning.fabric.utilities.seed import reset_seed
 from lightning.fabric.utilities.types import _PATH
 

src/lightning/pytorch/CHANGELOG.md

Lines changed: 3 additions & 0 deletions
@@ -12,6 +12,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - On XLA, avoid setting the global rank before processes have been launched as this will initialize the PJRT computation client in the main process ([#16966](https://github.com/Lightning-AI/lightning/pull/16966))
 
 
+- Fixed FSDP full-precision `param_dtype` training (`16-mixed`, `bf16-mixed` and `32-true` configurations) to avoid FSDP assertion errors with PyTorch < 2.0 ([#18278](https://github.com/Lightning-AI/lightning/pull/18278))
+
+
 ## [2.0.7] - 2023-08-14
 
 ### Added

src/lightning/pytorch/plugins/precision/fsdp.py

Lines changed: 6 additions & 3 deletions
@@ -15,7 +15,7 @@
 
 import torch
 
-from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_1_12
+from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_1_12, _TORCH_GREATER_EQUAL_2_0
 from lightning.pytorch.plugins.precision.amp import MixedPrecisionPlugin
 from lightning.pytorch.utilities.exceptions import MisconfigurationException
 
@@ -57,11 +57,14 @@ def clip_grad_by_norm(self, *_: Any, **__: Any) -> None:
     def mixed_precision_config(self) -> Optional[MixedPrecision]:
         assert MixedPrecision is not None
 
+        # With PyTorch < 2.0, FSDP uses the noneness of `param_dtype` as a proxy for the `_uses_param_mixed_precision`
+        # property. In order to avoid FSDP assertion failures, we therefore avoid setting `param_dtype` to
+        # `torch.float32` here with PyTorch < 2.0.
         if self.precision == "16-mixed":
-            param_dtype = torch.float32
+            param_dtype = None if not _TORCH_GREATER_EQUAL_2_0 else torch.float32
             reduce_dtype = buffer_dtype = torch.float16
         elif self.precision == "bf16-mixed":
-            param_dtype = torch.float32
+            param_dtype = None if not _TORCH_GREATER_EQUAL_2_0 else torch.float32
             reduce_dtype = buffer_dtype = torch.bfloat16
         elif self.precision == "16-true":
             param_dtype = reduce_dtype = buffer_dtype = torch.float16
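For users, this code path is normally reached by configuring the Trainer with the FSDP strategy and a mixed-precision setting (e.g. `strategy="fsdp"`, `precision="16-mixed"`). The sketch below inspects the resulting config directly; the constructor arguments shown (`precision`, `device`, optional `scaler`) and `mixed_precision_config` being a property are assumptions inferred from this diff, not guaranteed by it.

import torch
from lightning.pytorch.plugins.precision.fsdp import FSDPMixedPrecisionPlugin

# Assumed constructor signature; the Trainer normally builds this plugin for you
# when `strategy="fsdp"` is combined with `precision="16-mixed"` or `"bf16-mixed"`.
plugin = FSDPMixedPrecisionPlugin(precision="16-mixed", device="cuda")

config = plugin.mixed_precision_config
# PyTorch >= 2.0: param_dtype == torch.float32; PyTorch < 2.0: param_dtype is None.
print(config.param_dtype, config.reduce_dtype, config.buffer_dtype)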

tests/tests_fabric/plugins/precision/test_fsdp.py

Lines changed: 15 additions & 3 deletions
@@ -30,9 +30,21 @@ def test_fsdp_precision_support(*_):
 @pytest.mark.parametrize(
     ("precision", "expected"),
     [
-        ("16-mixed", (torch.float32, torch.float16, torch.float16)),
-        ("bf16-mixed", (torch.float32, torch.bfloat16, torch.bfloat16)),
-        # TODO: add 16-true and bf16-true once supported
+        pytest.param(
+            "16-mixed", (torch.float32, torch.float16, torch.float16), marks=RunIf(min_torch="2.0"), id="16-mixed-ge2_0"
+        ),
+        pytest.param(
+            "16-mixed", (None, torch.float16, torch.float16), marks=RunIf(max_torch="2.0"), id="16-mixed-lt2_0"
+        ),
+        pytest.param(
+            "bf16-mixed",
+            (torch.float32, torch.bfloat16, torch.bfloat16),
+            marks=RunIf(min_torch="2.0"),
+            id="bf16-mixed-ge2_0",
+        ),
+        pytest.param(
+            "bf16-mixed", (None, torch.bfloat16, torch.bfloat16), marks=RunIf(max_torch="2.0"), id="bf16-mixed-lt2_0"
+        ),
     ],
 )
 def test_fsdp_precision_config(precision, expected):
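The `RunIf(min_torch=...)` / `RunIf(max_torch=...)` marks select which expectation runs on a given PyTorch build. A rough standalone equivalent is sketched below, assuming those marks reduce to `pytest.mark.skipif` checks against the installed torch version (the boundary semantics are an assumption).

import pytest
import torch
from packaging.version import Version

_TORCH = Version(torch.__version__.split("+")[0])
requires_torch_ge_2_0 = pytest.mark.skipif(_TORCH < Version("2.0.0"), reason="needs torch>=2.0")
requires_torch_lt_2_0 = pytest.mark.skipif(_TORCH >= Version("2.0.0"), reason="needs torch<2.0")


@pytest.mark.parametrize(
    ("precision", "expected_param_dtype"),
    [
        pytest.param("16-mixed", torch.float32, marks=requires_torch_ge_2_0, id="16-mixed-ge2_0"),
        pytest.param("16-mixed", None, marks=requires_torch_lt_2_0, id="16-mixed-lt2_0"),
    ],
)
def test_expected_param_dtype(precision, expected_param_dtype):
    # Only the case matching the installed torch version runs; the other is skipped.
    assert expected_param_dtype == (torch.float32 if _TORCH >= Version("2.0.0") else None)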

tests/tests_pytorch/strategies/test_fsdp.py

Lines changed: 4 additions & 4 deletions
@@ -74,10 +74,10 @@ def _assert_layer_fsdp_instance(self) -> None:
         assert isinstance(self.trainer.strategy.precision_plugin, FSDPMixedPrecisionPlugin)
 
         if self.trainer.precision == "16-mixed":
-            param_dtype = torch.float32
+            param_dtype = None if not _TORCH_GREATER_EQUAL_2_0 else torch.float32
             reduce_dtype = buffer_dtype = torch.float16
         elif self.trainer.precision == "bf16-mixed":
-            param_dtype = torch.float32
+            param_dtype = None if not _TORCH_GREATER_EQUAL_2_0 else torch.float32
             reduce_dtype = buffer_dtype = torch.bfloat16
         elif self.trainer.precision == "16-true":
             param_dtype = reduce_dtype = buffer_dtype = torch.float16
@@ -122,10 +122,10 @@ def _assert_layer_fsdp_instance(self) -> None:
         assert isinstance(self.trainer.strategy.precision_plugin, FSDPMixedPrecisionPlugin)
 
         if self.trainer.precision == "16-mixed":
-            param_dtype = torch.float32
+            param_dtype = None if not _TORCH_GREATER_EQUAL_2_0 else torch.float32
             reduce_dtype = buffer_dtype = torch.float16
         elif self.trainer.precision == "bf16-mixed":
-            param_dtype = torch.float32
+            param_dtype = None if not _TORCH_GREATER_EQUAL_2_0 else torch.float32
             reduce_dtype = buffer_dtype = torch.bfloat16
         elif self.trainer.precision == "16-true":
             param_dtype = reduce_dtype = buffer_dtype = torch.float16
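As background for the inline comment about `_uses_param_mixed_precision`, the snippet below illustrates the proxy behavior the comment describes; it is an illustration of that described behavior, not the actual PyTorch < 2.0 source.

import torch
from torch.distributed.fsdp.fully_sharded_data_parallel import MixedPrecision


def uses_param_mixed_precision(mp: MixedPrecision) -> bool:
    # Per the comment in the diff: PyTorch < 2.0 infers "param mixed precision"
    # from `param_dtype is not None`, and asserts based on that inference.
    return mp.param_dtype is not None


print(uses_param_mixed_precision(MixedPrecision(param_dtype=torch.float32)))  # True, although fp32 is full precision
print(uses_param_mixed_precision(MixedPrecision(param_dtype=None)))           # False, the full-precision path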
