Commit f9e7214

deependujha authored and Iruos8805 committed
Fix FSDP mixed precision semantics and add user warning (Lightning-AI#21361)
1 parent 44bf04c commit f9e7214

9 files changed: +69 / -54 lines changed


Makefile

Lines changed: 1 addition & 1 deletion
@@ -10,7 +10,7 @@ export PACKAGE_NAME=pytorch
 
 # In Lightning Studio, the `lightning` package comes pre-installed.
 # Uninstall it first to ensure the editable install works correctly.
-setup:
+setup: update
 	uv pip uninstall lightning pytorch-lightning lightning-fabric || true
 	uv pip install -r requirements.txt \
 		-r requirements/pytorch/base.txt \

src/lightning/fabric/CHANGELOG.md

Lines changed: 3 additions & 0 deletions
@@ -30,6 +30,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Learning rate scheduler is stepped at the end of epoch when `on_train_batch_start` returns -1 ([#21296](https://github.com/Lightning-AI/pytorch-lightning/issues/21296)).
 
 
+- Fixed FSDP mixed precision semantics and added user warning ([#21361](https://github.com/Lightning-AI/pytorch-lightning/pull/21361))
+
+
 ---
 
 ## [2.5.5] - 2025-09-05

src/lightning/fabric/plugins/precision/fsdp.py

Lines changed: 10 additions & 10 deletions
@@ -24,6 +24,7 @@
 from lightning.fabric.plugins.precision.amp import _optimizer_handles_unscaling
 from lightning.fabric.plugins.precision.precision import Precision
 from lightning.fabric.plugins.precision.utils import _convert_fp_tensor, _DtypeContextManager
+from lightning.fabric.utilities import rank_zero_warn
 from lightning.fabric.utilities.types import Optimizable
 
 if TYPE_CHECKING:
@@ -84,19 +85,18 @@ def convert_module(self, module: Module) -> Module:
     def mixed_precision_config(self) -> "TorchMixedPrecision":
         from torch.distributed.fsdp.fully_sharded_data_parallel import MixedPrecision as TorchMixedPrecision
 
-        if self.precision == "16-mixed":
-            param_dtype = torch.float32
-            reduce_dtype = buffer_dtype = torch.float16
-        elif self.precision == "bf16-mixed":
-            param_dtype = torch.float32
-            reduce_dtype = buffer_dtype = torch.bfloat16
-        elif self.precision == "16-true":
+        if self.precision in ("16-true", "bf16-true"):
+            rank_zero_warn(
+                f"FSDP with `{self.precision}` enables computation in lower precision. "
+                "FSDP will always retain a full-precision copy of the model parameters for sharding."
+            )
+
+        if self.precision in ("16-true", "16-mixed"):
             param_dtype = reduce_dtype = buffer_dtype = torch.float16
-        elif self.precision == "bf16-true":
+        elif self.precision in ("bf16-true", "bf16-mixed"):
             param_dtype = reduce_dtype = buffer_dtype = torch.bfloat16
         elif self.precision == "32-true":
-            param_dtype = torch.float32
-            reduce_dtype = buffer_dtype = torch.float32
+            param_dtype = reduce_dtype = buffer_dtype = torch.float32
         else:
             raise ValueError(f"Was unable to infer precision type, received {self.precision!r}.")
 
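For readers skimming the diff: both plugins now derive a single dtype from the precision flag and use it for parameters, gradient reduction, and buffers alike, warning on the "-true" modes because FSDP still shards a full-precision copy of the parameters. A minimal standalone sketch of that mapping (the helper name `fsdp_mixed_precision` is ours, and `warnings.warn` stands in for Lightning's `rank_zero_warn`):

import warnings

import torch
from torch.distributed.fsdp.fully_sharded_data_parallel import MixedPrecision


def fsdp_mixed_precision(precision: str) -> MixedPrecision:
    # "-true" modes only lower the computation dtype; FSDP still keeps a
    # full-precision copy of the parameters for sharding, hence the warning.
    if precision in ("16-true", "bf16-true"):
        warnings.warn(
            f"FSDP with `{precision}` enables computation in lower precision. "
            "FSDP will always retain a full-precision copy of the model parameters for sharding."
        )

    if precision in ("16-true", "16-mixed"):
        dtype = torch.float16
    elif precision in ("bf16-true", "bf16-mixed"):
        dtype = torch.bfloat16
    elif precision == "32-true":
        dtype = torch.float32
    else:
        raise ValueError(f"Was unable to infer precision type, received {precision!r}.")

    # Params, reductions, and buffers now all follow the selected precision,
    # for the "-mixed" modes as well as the "-true" ones.
    return MixedPrecision(param_dtype=dtype, reduce_dtype=dtype, buffer_dtype=dtype)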

src/lightning/pytorch/CHANGELOG.md

Lines changed: 2 additions & 0 deletions
@@ -79,6 +79,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed synchronization of gradients in manual optimization with `DDPStrategy(static_graph=True)` ([#21251](https://github.com/Lightning-AI/pytorch-lightning/pull/21251))
 
 
+- Fixed FSDP mixed precision semantics and added user warning ([#21361](https://github.com/Lightning-AI/pytorch-lightning/pull/21361))
+
 
 ---
 

src/lightning/pytorch/plugins/precision/fsdp.py

Lines changed: 10 additions & 10 deletions
@@ -24,6 +24,7 @@
 from lightning.fabric.plugins.precision.amp import _optimizer_handles_unscaling
 from lightning.fabric.plugins.precision.fsdp import _PRECISION_INPUT
 from lightning.fabric.plugins.precision.utils import _convert_fp_tensor, _DtypeContextManager
+from lightning.fabric.utilities import rank_zero_warn
 from lightning.fabric.utilities.types import Optimizable
 from lightning.pytorch.plugins.precision.precision import Precision
 from lightning.pytorch.utilities.exceptions import MisconfigurationException
@@ -94,19 +95,18 @@ def clip_grad_by_norm(self, *_: Any, **__: Any) -> None:
     def mixed_precision_config(self) -> "TorchMixedPrecision":
         from torch.distributed.fsdp.fully_sharded_data_parallel import MixedPrecision as TorchMixedPrecision
 
-        if self.precision == "16-mixed":
-            param_dtype = torch.float32
-            reduce_dtype = buffer_dtype = torch.float16
-        elif self.precision == "bf16-mixed":
-            param_dtype = torch.float32
-            reduce_dtype = buffer_dtype = torch.bfloat16
-        elif self.precision == "16-true":
+        if self.precision in ("16-true", "bf16-true"):
+            rank_zero_warn(
+                f"FSDP with `{self.precision}` enables computation in lower precision. "
+                "FSDP will always retain a full-precision copy of the model parameters for sharding."
+            )
+
+        if self.precision in ("16-true", "16-mixed"):
             param_dtype = reduce_dtype = buffer_dtype = torch.float16
-        elif self.precision == "bf16-true":
+        elif self.precision in ("bf16-true", "bf16-mixed"):
             param_dtype = reduce_dtype = buffer_dtype = torch.bfloat16
         elif self.precision == "32-true":
-            param_dtype = torch.float32
-            reduce_dtype = buffer_dtype = torch.float32
+            param_dtype = reduce_dtype = buffer_dtype = torch.float32
         else:
             raise MisconfigurationException(f"Was unable to infer precision type, received {self.precision!r}.")
 
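As a quick sanity check of the new semantics at the plugin level (a sketch, not part of the commit; the import path follows the file shown above and the property access mirrors the updated tests below):

import torch
from lightning.pytorch.plugins.precision.fsdp import FSDPPrecision

# After this fix, "bf16-mixed" selects bfloat16 for parameters, gradient
# reduction, and buffers (previously param_dtype stayed float32 here).
config = FSDPPrecision(precision="bf16-mixed").mixed_precision_config
assert config.param_dtype == torch.bfloat16
assert config.reduce_dtype == torch.bfloat16
assert config.buffer_dtype == torch.bfloat16

# "bf16-true" yields the same dtypes but also emits the new UserWarning,
# since FSDP still shards a full-precision copy of the parameters.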

tests/tests_fabric/plugins/precision/test_fsdp.py

Lines changed: 15 additions & 3 deletions
@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from contextlib import contextmanager
 from unittest.mock import Mock
 
 import pytest
@@ -21,19 +22,30 @@
 from tests_fabric.helpers.runif import RunIf
 
 
+# Pytest passes args/kwargs to the context manager used with `pytest.warns`.
+# `contextlib.nullcontext` doesn't accept them, so this no-op version does.
+@contextmanager
+def null_ctx(*args, **kwargs):
+    yield
+
+
 @pytest.mark.parametrize(
     ("precision", "expected"),
     [
         ("16-true", (torch.float16, torch.float16, torch.float16)),
         ("bf16-true", (torch.bfloat16, torch.bfloat16, torch.bfloat16)),
-        ("16-mixed", (torch.float32, torch.float16, torch.float16)),
-        ("bf16-mixed", (torch.float32, torch.bfloat16, torch.bfloat16)),
+        ("16-mixed", (torch.float16, torch.float16, torch.float16)),
+        ("bf16-mixed", (torch.bfloat16, torch.bfloat16, torch.bfloat16)),
         ("32-true", (torch.float32, torch.float32, torch.float32)),
     ],
 )
 def test_fsdp_precision_config(precision, expected):
     plugin = FSDPPrecision(precision=precision)
-    config = plugin.mixed_precision_config
+
+    warning_ctx = pytest.warns if precision in ("16-true", "bf16-true") else null_ctx
+
+    with warning_ctx(UserWarning, match="enables computation in lower precision"):
+        config = plugin.mixed_precision_config
 
     assert config.param_dtype == expected[0]
     assert config.buffer_dtype == expected[1]
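On the `null_ctx` helper above: the test calls the chosen context manager as `ctx(UserWarning, match=...)`, which `pytest.warns` accepts but `contextlib.nullcontext` (whose only parameter is `enter_result`) does not, hence the tiny no-op replacement. A minimal illustration, not part of the commit:

from contextlib import contextmanager, nullcontext


@contextmanager
def null_ctx(*args, **kwargs):
    # Accepts (and ignores) the same call signature as pytest.warns.
    yield


# Works: the arguments are swallowed and nothing is asserted.
with null_ctx(UserWarning, match="enables computation in lower precision"):
    pass

# Fails: nullcontext() takes one optional positional argument and no keywords.
try:
    with nullcontext(UserWarning, match="enables computation in lower precision"):
        pass
except TypeError as err:
    print(f"nullcontext rejects the pytest.warns-style call: {err}")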

tests/tests_fabric/strategies/test_fsdp_integration.py

Lines changed: 2 additions & 8 deletions
@@ -87,15 +87,9 @@ def step(self, model, batch):
 
         precision = self.fabric._precision
         assert isinstance(precision, FSDPPrecision)
-        if precision.precision == "16-mixed":
-            param_dtype = torch.float32
-            reduce_dtype = buffer_dtype = torch.float16
-        elif precision.precision == "bf16-mixed":
-            param_dtype = torch.float32
-            reduce_dtype = buffer_dtype = torch.bfloat16
-        elif precision.precision == "16-true":
+        if precision.precision in ("16-true", "16-mixed"):
             param_dtype = reduce_dtype = buffer_dtype = torch.float16
-        elif precision.precision == "bf16-true":
+        elif precision.precision in ("bf16-true", "bf16-mixed"):
             param_dtype = reduce_dtype = buffer_dtype = torch.bfloat16
         else:
             raise ValueError(f"Unknown precision {precision.precision}")

tests/tests_pytorch/plugins/precision/test_fsdp.py

Lines changed: 15 additions & 3 deletions
@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from contextlib import contextmanager
 from unittest.mock import ANY, MagicMock, Mock
 
 import pytest
@@ -21,19 +22,30 @@
 from tests_pytorch.helpers.runif import RunIf
 
 
+# Pytest passes args/kwargs to the context manager used with `pytest.warns`.
+# `contextlib.nullcontext` doesn't accept them, so this no-op version does.
+@contextmanager
+def null_ctx(*args, **kwargs):
+    yield
+
+
 @pytest.mark.parametrize(
     ("precision", "expected"),
     [
         ("16-true", (torch.float16, torch.float16, torch.float16)),
         ("bf16-true", (torch.bfloat16, torch.bfloat16, torch.bfloat16)),
-        ("16-mixed", (torch.float32, torch.float16, torch.float16)),
-        ("bf16-mixed", (torch.float32, torch.bfloat16, torch.bfloat16)),
+        ("16-mixed", (torch.float16, torch.float16, torch.float16)),
+        ("bf16-mixed", (torch.bfloat16, torch.bfloat16, torch.bfloat16)),
         ("32-true", (torch.float32, torch.float32, torch.float32)),
     ],
 )
 def test_fsdp_precision_config(precision, expected):
     plugin = FSDPPrecision(precision=precision)
-    config = plugin.mixed_precision_config
+
+    warning_ctx = pytest.warns if precision in ("16-true", "bf16-true") else null_ctx
+
+    with warning_ctx(UserWarning, match="enables computation in lower precision"):
+        config = plugin.mixed_precision_config
 
     assert config.param_dtype == expected[0]
     assert config.buffer_dtype == expected[1]

tests/tests_pytorch/strategies/test_fsdp.py

Lines changed: 11 additions & 19 deletions
@@ -77,16 +77,12 @@ def _assert_layer_fsdp_instance(self) -> None:
         assert isinstance(self.layer, FullyShardedDataParallel)
         assert isinstance(self.trainer.strategy.precision_plugin, FSDPPrecision)
 
-        if self.trainer.precision == "16-mixed":
-            param_dtype = torch.float32
-            reduce_dtype = buffer_dtype = torch.float16
-        elif self.trainer.precision == "bf16-mixed":
-            param_dtype = torch.float32
-            reduce_dtype = buffer_dtype = torch.bfloat16
-        elif self.trainer.precision == "16-true":
+        if self.trainer.precision in ("16-true", "16-mixed"):
             param_dtype = reduce_dtype = buffer_dtype = torch.float16
-        elif self.trainer.precision == "bf16-true":
+        elif self.trainer.precision in ("bf16-true", "bf16-mixed"):
             param_dtype = reduce_dtype = buffer_dtype = torch.bfloat16
+        elif self.trainer.precision == "32-true":
+            param_dtype = reduce_dtype = buffer_dtype = torch.float32
         else:
             raise ValueError(f"Unknown precision {self.trainer.precision}")
 
@@ -138,16 +134,12 @@ def _assert_layer_fsdp_instance(self) -> None:
         assert isinstance(self.layer, torch.nn.Sequential)
         assert isinstance(self.trainer.strategy.precision_plugin, FSDPPrecision)
 
-        if self.trainer.precision == "16-mixed":
-            param_dtype = torch.float32
-            reduce_dtype = buffer_dtype = torch.float16
-        elif self.trainer.precision == "bf16-mixed":
-            param_dtype = torch.float32
-            reduce_dtype = buffer_dtype = torch.bfloat16
-        elif self.trainer.precision == "16-true":
+        if self.trainer.precision in ("16-true", "16-mixed"):
             param_dtype = reduce_dtype = buffer_dtype = torch.float16
-        elif self.trainer.precision == "bf16-true":
+        elif self.trainer.precision in ("bf16-true", "bf16-mixed"):
             param_dtype = reduce_dtype = buffer_dtype = torch.bfloat16
+        elif self.trainer.precision == "32-true":
+            param_dtype = reduce_dtype = buffer_dtype = torch.float32
         else:
             raise ValueError(f"Unknown precision {self.trainer.precision}")
 
@@ -227,7 +219,7 @@ def test_strategy_sync_batchnorm(tmp_path):
         accelerator="gpu",
         devices=2,
         strategy="fsdp",
-        precision="16-mixed",
+        precision="32-true",
         max_epochs=1,
         sync_batchnorm=True,
     )
@@ -267,7 +259,7 @@ def training_step(self, batch, batch_idx):
 
 @pytest.mark.filterwarnings("ignore::FutureWarning")
 @RunIf(min_cuda_gpus=2, skip_windows=True, standalone=True)
-@pytest.mark.parametrize("precision", ["16-mixed", pytest.param("bf16-mixed", marks=RunIf(bf16_cuda=True))])
+@pytest.mark.parametrize("precision", ["32-true", pytest.param("bf16-mixed", marks=RunIf(bf16_cuda=True))])
 @pytest.mark.parametrize("state_dict_type", ["sharded", "full"])
 def test_strategy_checkpoint(state_dict_type, precision, tmp_path):
     """Test to ensure that checkpoint is saved correctly when using a single GPU, and all stages can be run."""
@@ -359,7 +351,7 @@ def test_checkpoint_multi_gpus(tmp_path, model, strategy, strategy_cfg):
         accelerator="gpu",
         devices=2,
         strategy=strategy,
-        precision="16-mixed",
+        precision="32-true",
         max_epochs=1,
         limit_train_batches=2,
         limit_val_batches=2,
