Commit df3eb79

Merge branch 'master' into update-pyproject-py310
2 parents: 2df182c + eec92a9

10 files changed: +72, -68 lines changed


Makefile

Lines changed: 1 addition & 1 deletion
@@ -10,7 +10,7 @@ export PACKAGE_NAME=pytorch
 
 # In Lightning Studio, the `lightning` package comes pre-installed.
 # Uninstall it first to ensure the editable install works correctly.
-setup:
+setup: update
 	uv pip uninstall lightning pytorch-lightning lightning-fabric || true
 	uv pip install -r requirements.txt \
 		-r requirements/pytorch/base.txt \

src/lightning/fabric/CHANGELOG.md

Lines changed: 4 additions & 12 deletions
@@ -4,18 +4,7 @@ All notable changes to this project will be documented in this file.
 
 The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
-
-## [unreleased] - YYYY-MM-DD
-
-### Added
-
--
-
-
-### Removed
-
--
-
+## [2.6.0] - 2025-11-21
 
 ### Changed
 
@@ -30,6 +19,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Learning rate scheduler is stepped at the end of epoch when `on_train_batch_start` returns -1 ([#21296](https://github.com/Lightning-AI/pytorch-lightning/issues/21296)).
 
 
+- Fixed FSDP mixed precision semantics and added user warning ([#21361](https://github.com/Lightning-AI/pytorch-lightning/pull/21361))
+
+
 ---
 
 ## [2.5.5] - 2025-09-05

src/lightning/fabric/plugins/precision/fsdp.py

Lines changed: 10 additions & 10 deletions
@@ -24,6 +24,7 @@
 from lightning.fabric.plugins.precision.amp import _optimizer_handles_unscaling
 from lightning.fabric.plugins.precision.precision import Precision
 from lightning.fabric.plugins.precision.utils import _convert_fp_tensor, _DtypeContextManager
+from lightning.fabric.utilities import rank_zero_warn
 from lightning.fabric.utilities.types import Optimizable
 
 if TYPE_CHECKING:
@@ -84,19 +85,18 @@ def convert_module(self, module: Module) -> Module:
     def mixed_precision_config(self) -> "TorchMixedPrecision":
         from torch.distributed.fsdp.fully_sharded_data_parallel import MixedPrecision as TorchMixedPrecision
 
-        if self.precision == "16-mixed":
-            param_dtype = torch.float32
-            reduce_dtype = buffer_dtype = torch.float16
-        elif self.precision == "bf16-mixed":
-            param_dtype = torch.float32
-            reduce_dtype = buffer_dtype = torch.bfloat16
-        elif self.precision == "16-true":
+        if self.precision in ("16-true", "bf16-true"):
+            rank_zero_warn(
+                f"FSDP with `{self.precision}` enables computation in lower precision. "
+                "FSDP will always retain a full-precision copy of the model parameters for sharding."
+            )
+
+        if self.precision in ("16-true", "16-mixed"):
             param_dtype = reduce_dtype = buffer_dtype = torch.float16
-        elif self.precision == "bf16-true":
+        elif self.precision in ("bf16-true", "bf16-mixed"):
             param_dtype = reduce_dtype = buffer_dtype = torch.bfloat16
         elif self.precision == "32-true":
-            param_dtype = torch.float32
-            reduce_dtype = buffer_dtype = torch.float32
+            param_dtype = reduce_dtype = buffer_dtype = torch.float32
         else:
             raise ValueError(f"Was unable to infer precision type, received {self.precision!r}.")
 
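For context, a minimal sketch (not part of the diff) of what the updated `mixed_precision_config` property now produces for each precision flag, assuming `lightning` and `torch` are installed. The expected dtypes are taken from the updated tests further down in this commit.

import torch
from lightning.fabric.plugins.precision.fsdp import FSDPPrecision

for precision in ("16-true", "16-mixed", "bf16-true", "bf16-mixed", "32-true"):
    # Accessing the property builds a torch MixedPrecision config; the "*-true"
    # flags now also emit a rank-zero UserWarning about the full-precision copy.
    config = FSDPPrecision(precision=precision).mixed_precision_config
    print(precision, config.param_dtype, config.reduce_dtype, config.buffer_dtype)

# Per the updated tests: 16-true and 16-mixed yield float16 for param, reduce, and
# buffer dtypes; bf16-true and bf16-mixed yield bfloat16; 32-true yields float32.
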
src/lightning/pytorch/CHANGELOG.md

Lines changed: 3 additions & 1 deletion
@@ -6,7 +6,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ---
 
-## [unreleased] - YYYY-MM-DD
+## [2.6.0] - 2025-11-21
 
 ### Added
 
@@ -79,6 +79,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed synchronization of gradients in manual optimization with `DDPStrategy(static_graph=True)` ([#21251](https://github.com/Lightning-AI/pytorch-lightning/pull/21251))
 
 
+- Fixed FSDP mixed precision semantics and added user warning ([#21361](https://github.com/Lightning-AI/pytorch-lightning/pull/21361))
+
 
 ---
 

src/lightning/pytorch/plugins/precision/fsdp.py

Lines changed: 10 additions & 10 deletions
@@ -25,6 +25,7 @@
 from lightning.fabric.plugins.precision.amp import _optimizer_handles_unscaling
 from lightning.fabric.plugins.precision.fsdp import _PRECISION_INPUT
 from lightning.fabric.plugins.precision.utils import _convert_fp_tensor, _DtypeContextManager
+from lightning.fabric.utilities import rank_zero_warn
 from lightning.fabric.utilities.types import Optimizable
 from lightning.pytorch.plugins.precision.precision import Precision
 from lightning.pytorch.utilities.exceptions import MisconfigurationException
@@ -95,19 +96,18 @@ def clip_grad_by_norm(self, *_: Any, **__: Any) -> None:
     def mixed_precision_config(self) -> "TorchMixedPrecision":
         from torch.distributed.fsdp.fully_sharded_data_parallel import MixedPrecision as TorchMixedPrecision
 
-        if self.precision == "16-mixed":
-            param_dtype = torch.float32
-            reduce_dtype = buffer_dtype = torch.float16
-        elif self.precision == "bf16-mixed":
-            param_dtype = torch.float32
-            reduce_dtype = buffer_dtype = torch.bfloat16
-        elif self.precision == "16-true":
+        if self.precision in ("16-true", "bf16-true"):
+            rank_zero_warn(
+                f"FSDP with `{self.precision}` enables computation in lower precision. "
+                "FSDP will always retain a full-precision copy of the model parameters for sharding."
+            )
+
+        if self.precision in ("16-true", "16-mixed"):
             param_dtype = reduce_dtype = buffer_dtype = torch.float16
-        elif self.precision == "bf16-true":
+        elif self.precision in ("bf16-true", "bf16-mixed"):
             param_dtype = reduce_dtype = buffer_dtype = torch.bfloat16
         elif self.precision == "32-true":
-            param_dtype = torch.float32
-            reduce_dtype = buffer_dtype = torch.float32
+            param_dtype = reduce_dtype = buffer_dtype = torch.float32
         else:
             raise MisconfigurationException(f"Was unable to infer precision type, received {self.precision!r}.")
 
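As a hypothetical usage sketch (not taken from this commit), requesting a "*-true" precision under the FSDP strategy now surfaces the new rank-zero warning. `BoringModel` is Lightning's demo model; the two-GPU setup is an assumption.

from lightning.pytorch import Trainer
from lightning.pytorch.demos.boring_classes import BoringModel

# Assumed setup: two CUDA devices available; any LightningModule works here.
trainer = Trainer(accelerator="gpu", devices=2, strategy="fsdp", precision="16-true")
trainer.fit(BoringModel())
# Expected at rank zero:
# UserWarning: FSDP with `16-true` enables computation in lower precision. FSDP will
# always retain a full-precision copy of the model parameters for sharding.
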
src/version.info

Lines changed: 1 addition & 1 deletion
@@ -1 +1 @@
-2.6.0dev0
+2.6.0

tests/tests_fabric/plugins/precision/test_fsdp.py

Lines changed: 15 additions & 3 deletions
@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from contextlib import contextmanager
 from unittest.mock import Mock
 
 import pytest
@@ -21,19 +22,30 @@
 from tests_fabric.helpers.runif import RunIf
 
 
+# Pytest passes args/kwargs to the context manager used with `pytest.warns`.
+# `contextlib.nullcontext` doesn't accept them, so this no-op version does.
+@contextmanager
+def null_ctx(*args, **kwargs):
+    yield
+
+
 @pytest.mark.parametrize(
     ("precision", "expected"),
     [
         ("16-true", (torch.float16, torch.float16, torch.float16)),
         ("bf16-true", (torch.bfloat16, torch.bfloat16, torch.bfloat16)),
-        ("16-mixed", (torch.float32, torch.float16, torch.float16)),
-        ("bf16-mixed", (torch.float32, torch.bfloat16, torch.bfloat16)),
+        ("16-mixed", (torch.float16, torch.float16, torch.float16)),
+        ("bf16-mixed", (torch.bfloat16, torch.bfloat16, torch.bfloat16)),
         ("32-true", (torch.float32, torch.float32, torch.float32)),
     ],
 )
 def test_fsdp_precision_config(precision, expected):
     plugin = FSDPPrecision(precision=precision)
-    config = plugin.mixed_precision_config
+
+    warning_ctx = pytest.warns if precision in ("16-true", "bf16-true") else null_ctx
+
+    with warning_ctx(UserWarning, match="enables computation in lower precision"):
+        config = plugin.mixed_precision_config
 
     assert config.param_dtype == expected[0]
     assert config.buffer_dtype == expected[1]

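The `null_ctx` helper above is just a warning-tolerant stand-in for `contextlib.nullcontext`. A standalone sketch of the pattern, assuming only that `pytest` is installed; the `check` helper is hypothetical and not part of the test suite.

import warnings
from contextlib import contextmanager

import pytest


@contextmanager
def null_ctx(*args, **kwargs):
    # Accept and ignore the (UserWarning, match=...) arguments that `pytest.warns` takes.
    yield


def check(expect_warning: bool) -> None:
    ctx = pytest.warns if expect_warning else null_ctx
    with ctx(UserWarning, match="enables computation in lower precision"):
        if expect_warning:
            warnings.warn("FSDP with `16-true` enables computation in lower precision.", UserWarning)


check(True)   # passes: the warning is raised and matches the pattern
check(False)  # passes: no warning expected, the no-op context just yields
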
tests/tests_fabric/strategies/test_fsdp_integration.py

Lines changed: 2 additions & 8 deletions
@@ -87,15 +87,9 @@ def step(self, model, batch):
 
         precision = self.fabric._precision
         assert isinstance(precision, FSDPPrecision)
-        if precision.precision == "16-mixed":
-            param_dtype = torch.float32
-            reduce_dtype = buffer_dtype = torch.float16
-        elif precision.precision == "bf16-mixed":
-            param_dtype = torch.float32
-            reduce_dtype = buffer_dtype = torch.bfloat16
-        elif precision.precision == "16-true":
+        if precision.precision in ("16-true", "16-mixed"):
             param_dtype = reduce_dtype = buffer_dtype = torch.float16
-        elif precision.precision == "bf16-true":
+        elif precision.precision in ("bf16-true", "bf16-mixed"):
             param_dtype = reduce_dtype = buffer_dtype = torch.bfloat16
         else:
             raise ValueError(f"Unknown precision {precision.precision}")

tests/tests_pytorch/plugins/precision/test_fsdp.py

Lines changed: 15 additions & 3 deletions
@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from contextlib import contextmanager
 from unittest.mock import ANY, MagicMock, Mock
 
 import pytest
@@ -21,19 +22,30 @@
 from tests_pytorch.helpers.runif import RunIf
 
 
+# Pytest passes args/kwargs to the context manager used with `pytest.warns`.
+# `contextlib.nullcontext` doesn't accept them, so this no-op version does.
+@contextmanager
+def null_ctx(*args, **kwargs):
+    yield
+
+
 @pytest.mark.parametrize(
     ("precision", "expected"),
     [
         ("16-true", (torch.float16, torch.float16, torch.float16)),
         ("bf16-true", (torch.bfloat16, torch.bfloat16, torch.bfloat16)),
-        ("16-mixed", (torch.float32, torch.float16, torch.float16)),
-        ("bf16-mixed", (torch.float32, torch.bfloat16, torch.bfloat16)),
+        ("16-mixed", (torch.float16, torch.float16, torch.float16)),
+        ("bf16-mixed", (torch.bfloat16, torch.bfloat16, torch.bfloat16)),
         ("32-true", (torch.float32, torch.float32, torch.float32)),
     ],
 )
 def test_fsdp_precision_config(precision, expected):
     plugin = FSDPPrecision(precision=precision)
-    config = plugin.mixed_precision_config
+
+    warning_ctx = pytest.warns if precision in ("16-true", "bf16-true") else null_ctx
+
+    with warning_ctx(UserWarning, match="enables computation in lower precision"):
+        config = plugin.mixed_precision_config
 
     assert config.param_dtype == expected[0]
     assert config.buffer_dtype == expected[1]

tests/tests_pytorch/strategies/test_fsdp.py

Lines changed: 11 additions & 19 deletions
@@ -76,16 +76,12 @@ def _assert_layer_fsdp_instance(self) -> None:
         assert isinstance(self.layer, FullyShardedDataParallel)
         assert isinstance(self.trainer.strategy.precision_plugin, FSDPPrecision)
 
-        if self.trainer.precision == "16-mixed":
-            param_dtype = torch.float32
-            reduce_dtype = buffer_dtype = torch.float16
-        elif self.trainer.precision == "bf16-mixed":
-            param_dtype = torch.float32
-            reduce_dtype = buffer_dtype = torch.bfloat16
-        elif self.trainer.precision == "16-true":
+        if self.trainer.precision in ("16-true", "16-mixed"):
             param_dtype = reduce_dtype = buffer_dtype = torch.float16
-        elif self.trainer.precision == "bf16-true":
+        elif self.trainer.precision in ("bf16-true", "bf16-mixed"):
             param_dtype = reduce_dtype = buffer_dtype = torch.bfloat16
+        elif self.trainer.precision == "32-true":
+            param_dtype = reduce_dtype = buffer_dtype = torch.float32
         else:
             raise ValueError(f"Unknown precision {self.trainer.precision}")
 
@@ -137,16 +133,12 @@ def _assert_layer_fsdp_instance(self) -> None:
         assert isinstance(self.layer, torch.nn.Sequential)
         assert isinstance(self.trainer.strategy.precision_plugin, FSDPPrecision)
 
-        if self.trainer.precision == "16-mixed":
-            param_dtype = torch.float32
-            reduce_dtype = buffer_dtype = torch.float16
-        elif self.trainer.precision == "bf16-mixed":
-            param_dtype = torch.float32
-            reduce_dtype = buffer_dtype = torch.bfloat16
-        elif self.trainer.precision == "16-true":
+        if self.trainer.precision in ("16-true", "16-mixed"):
             param_dtype = reduce_dtype = buffer_dtype = torch.float16
-        elif self.trainer.precision == "bf16-true":
+        elif self.trainer.precision in ("bf16-true", "bf16-mixed"):
             param_dtype = reduce_dtype = buffer_dtype = torch.bfloat16
+        elif self.trainer.precision == "32-true":
+            param_dtype = reduce_dtype = buffer_dtype = torch.float32
         else:
             raise ValueError(f"Unknown precision {self.trainer.precision}")
 
@@ -226,7 +218,7 @@ def test_strategy_sync_batchnorm(tmp_path):
         accelerator="gpu",
         devices=2,
         strategy="fsdp",
-        precision="16-mixed",
+        precision="32-true",
         max_epochs=1,
         sync_batchnorm=True,
     )
@@ -266,7 +258,7 @@ def training_step(self, batch, batch_idx):
 
 @pytest.mark.filterwarnings("ignore::FutureWarning")
 @RunIf(min_cuda_gpus=2, skip_windows=True, standalone=True)
-@pytest.mark.parametrize("precision", ["16-mixed", pytest.param("bf16-mixed", marks=RunIf(bf16_cuda=True))])
+@pytest.mark.parametrize("precision", ["32-true", pytest.param("bf16-mixed", marks=RunIf(bf16_cuda=True))])
 @pytest.mark.parametrize("state_dict_type", ["sharded", "full"])
 def test_strategy_checkpoint(state_dict_type, precision, tmp_path):
     """Test to ensure that checkpoint is saved correctly when using a single GPU, and all stages can be run."""
@@ -358,7 +350,7 @@ def test_checkpoint_multi_gpus(tmp_path, model, strategy, strategy_cfg):
         accelerator="gpu",
         devices=2,
         strategy=strategy,
-        precision="16-mixed",
+        precision="32-true",
         max_epochs=1,
         limit_train_batches=2,
         limit_val_batches=2,
