diff --git a/src/lightning/fabric/plugins/precision/fsdp.py b/src/lightning/fabric/plugins/precision/fsdp.py
index 270a67e3a2338..189135e7b19e8 100644
--- a/src/lightning/fabric/plugins/precision/fsdp.py
+++ b/src/lightning/fabric/plugins/precision/fsdp.py
@@ -24,6 +24,7 @@
 from lightning.fabric.plugins.precision.amp import _optimizer_handles_unscaling
 from lightning.fabric.plugins.precision.precision import Precision
 from lightning.fabric.plugins.precision.utils import _convert_fp_tensor, _DtypeContextManager
+from lightning.fabric.utilities import rank_zero_warn
 from lightning.fabric.utilities.types import Optimizable

 if TYPE_CHECKING:
@@ -84,19 +85,18 @@ def convert_module(self, module: Module) -> Module:
     def mixed_precision_config(self) -> "TorchMixedPrecision":
         from torch.distributed.fsdp.fully_sharded_data_parallel import MixedPrecision as TorchMixedPrecision

-        if self.precision == "16-mixed":
-            param_dtype = torch.float32
-            reduce_dtype = buffer_dtype = torch.float16
-        elif self.precision == "bf16-mixed":
-            param_dtype = torch.float32
-            reduce_dtype = buffer_dtype = torch.bfloat16
-        elif self.precision == "16-true":
+        if self.precision in ("16-true", "bf16-true"):
+            rank_zero_warn(
+                f"FSDP with `{self.precision}` enables computation in lower precision. "
+                "FSDP will always retain a full-precision copy of the model parameters for sharding."
+            )
+
+        if self.precision in ("16-true", "16-mixed"):
             param_dtype = reduce_dtype = buffer_dtype = torch.float16
-        elif self.precision == "bf16-true":
+        elif self.precision in ("bf16-true", "bf16-mixed"):
             param_dtype = reduce_dtype = buffer_dtype = torch.bfloat16
         elif self.precision == "32-true":
-            param_dtype = torch.float32
-            reduce_dtype = buffer_dtype = torch.float32
+            param_dtype = reduce_dtype = buffer_dtype = torch.float32
         else:
             raise ValueError(f"Was unable to infer precision type, received {self.precision!r}.")

diff --git a/src/lightning/pytorch/plugins/precision/fsdp.py b/src/lightning/pytorch/plugins/precision/fsdp.py
index f3bab3e915e91..337c6a465278d 100644
--- a/src/lightning/pytorch/plugins/precision/fsdp.py
+++ b/src/lightning/pytorch/plugins/precision/fsdp.py
@@ -24,6 +24,7 @@
 from lightning.fabric.plugins.precision.amp import _optimizer_handles_unscaling
 from lightning.fabric.plugins.precision.fsdp import _PRECISION_INPUT
 from lightning.fabric.plugins.precision.utils import _convert_fp_tensor, _DtypeContextManager
+from lightning.fabric.utilities import rank_zero_warn
 from lightning.fabric.utilities.types import Optimizable
 from lightning.pytorch.plugins.precision.precision import Precision
 from lightning.pytorch.utilities.exceptions import MisconfigurationException
@@ -94,19 +95,18 @@ def clip_grad_by_norm(self, *_: Any, **__: Any) -> None:
     def mixed_precision_config(self) -> "TorchMixedPrecision":
         from torch.distributed.fsdp.fully_sharded_data_parallel import MixedPrecision as TorchMixedPrecision

-        if self.precision == "16-mixed":
-            param_dtype = torch.float32
-            reduce_dtype = buffer_dtype = torch.float16
-        elif self.precision == "bf16-mixed":
-            param_dtype = torch.float32
-            reduce_dtype = buffer_dtype = torch.bfloat16
-        elif self.precision == "16-true":
+        if self.precision in ("16-true", "bf16-true"):
+            rank_zero_warn(
+                f"FSDP with `{self.precision}` enables computation in lower precision. "
+                "FSDP will always retain a full-precision copy of the model parameters for sharding."
+            )
+
+        if self.precision in ("16-true", "16-mixed"):
             param_dtype = reduce_dtype = buffer_dtype = torch.float16
-        elif self.precision == "bf16-true":
+        elif self.precision in ("bf16-true", "bf16-mixed"):
             param_dtype = reduce_dtype = buffer_dtype = torch.bfloat16
         elif self.precision == "32-true":
-            param_dtype = torch.float32
-            reduce_dtype = buffer_dtype = torch.float32
+            param_dtype = reduce_dtype = buffer_dtype = torch.float32
         else:
             raise MisconfigurationException(f"Was unable to infer precision type, received {self.precision!r}.")

diff --git a/tests/tests_fabric/plugins/precision/test_fsdp.py b/tests/tests_fabric/plugins/precision/test_fsdp.py
index b15e8e6c65f57..7507002ab4fd1 100644
--- a/tests/tests_fabric/plugins/precision/test_fsdp.py
+++ b/tests/tests_fabric/plugins/precision/test_fsdp.py
@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from contextlib import contextmanager
 from unittest.mock import Mock

 import pytest
@@ -21,19 +22,30 @@
 from tests_fabric.helpers.runif import RunIf


+# Pytest passes args/kwargs to the context manager used with `pytest.warns`.
+# `contextlib.nullcontext` doesn't accept them, so this no-op version does.
+@contextmanager
+def null_ctx(*args, **kwargs):
+    yield
+
+
 @pytest.mark.parametrize(
     ("precision", "expected"),
     [
         ("16-true", (torch.float16, torch.float16, torch.float16)),
         ("bf16-true", (torch.bfloat16, torch.bfloat16, torch.bfloat16)),
-        ("16-mixed", (torch.float32, torch.float16, torch.float16)),
-        ("bf16-mixed", (torch.float32, torch.bfloat16, torch.bfloat16)),
+        ("16-mixed", (torch.float16, torch.float16, torch.float16)),
+        ("bf16-mixed", (torch.bfloat16, torch.bfloat16, torch.bfloat16)),
         ("32-true", (torch.float32, torch.float32, torch.float32)),
     ],
 )
 def test_fsdp_precision_config(precision, expected):
     plugin = FSDPPrecision(precision=precision)
-    config = plugin.mixed_precision_config
+
+    warning_ctx = pytest.warns if precision in ("16-true", "bf16-true") else null_ctx
+
+    with warning_ctx(UserWarning, match="enables computation in lower precision"):
+        config = plugin.mixed_precision_config

     assert config.param_dtype == expected[0]
     assert config.buffer_dtype == expected[1]
diff --git a/tests/tests_fabric/strategies/test_fsdp_integration.py b/tests/tests_fabric/strategies/test_fsdp_integration.py
index 5da9b50399a94..532f0f9b8ca94 100644
--- a/tests/tests_fabric/strategies/test_fsdp_integration.py
+++ b/tests/tests_fabric/strategies/test_fsdp_integration.py
@@ -87,15 +87,9 @@ def step(self, model, batch):
         precision = self.fabric._precision
         assert isinstance(precision, FSDPPrecision)

-        if precision.precision == "16-mixed":
-            param_dtype = torch.float32
-            reduce_dtype = buffer_dtype = torch.float16
-        elif precision.precision == "bf16-mixed":
-            param_dtype = torch.float32
-            reduce_dtype = buffer_dtype = torch.bfloat16
-        elif precision.precision == "16-true":
+        if precision.precision in ("16-true", "16-mixed"):
             param_dtype = reduce_dtype = buffer_dtype = torch.float16
-        elif precision.precision == "bf16-true":
+        elif precision.precision in ("bf16-true", "bf16-mixed"):
             param_dtype = reduce_dtype = buffer_dtype = torch.bfloat16
         else:
             raise ValueError(f"Unknown precision {precision.precision}")
diff --git a/tests/tests_pytorch/plugins/precision/test_fsdp.py b/tests/tests_pytorch/plugins/precision/test_fsdp.py
index f8731aa424b38..0834ef1f98400 100644
--- a/tests/tests_pytorch/plugins/precision/test_fsdp.py
+++ b/tests/tests_pytorch/plugins/precision/test_fsdp.py
@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from contextlib import contextmanager
 from unittest.mock import ANY, MagicMock, Mock

 import pytest
@@ -21,19 +22,30 @@
 from tests_pytorch.helpers.runif import RunIf


+# Pytest passes args/kwargs to the context manager used with `pytest.warns`.
+# `contextlib.nullcontext` doesn't accept them, so this no-op version does.
+@contextmanager
+def null_ctx(*args, **kwargs):
+    yield
+
+
 @pytest.mark.parametrize(
     ("precision", "expected"),
     [
         ("16-true", (torch.float16, torch.float16, torch.float16)),
         ("bf16-true", (torch.bfloat16, torch.bfloat16, torch.bfloat16)),
-        ("16-mixed", (torch.float32, torch.float16, torch.float16)),
-        ("bf16-mixed", (torch.float32, torch.bfloat16, torch.bfloat16)),
+        ("16-mixed", (torch.float16, torch.float16, torch.float16)),
+        ("bf16-mixed", (torch.bfloat16, torch.bfloat16, torch.bfloat16)),
         ("32-true", (torch.float32, torch.float32, torch.float32)),
     ],
 )
 def test_fsdp_precision_config(precision, expected):
     plugin = FSDPPrecision(precision=precision)
-    config = plugin.mixed_precision_config
+
+    warning_ctx = pytest.warns if precision in ("16-true", "bf16-true") else null_ctx
+
+    with warning_ctx(UserWarning, match="enables computation in lower precision"):
+        config = plugin.mixed_precision_config

     assert config.param_dtype == expected[0]
     assert config.buffer_dtype == expected[1]
diff --git a/tests/tests_pytorch/strategies/test_fsdp.py b/tests/tests_pytorch/strategies/test_fsdp.py
index f7c15b5930be8..8fd60d84d61a1 100644
--- a/tests/tests_pytorch/strategies/test_fsdp.py
+++ b/tests/tests_pytorch/strategies/test_fsdp.py
@@ -77,15 +77,9 @@ def _assert_layer_fsdp_instance(self) -> None:
         assert isinstance(self.layer, FullyShardedDataParallel)
         assert isinstance(self.trainer.strategy.precision_plugin, FSDPPrecision)

-        if self.trainer.precision == "16-mixed":
-            param_dtype = torch.float32
-            reduce_dtype = buffer_dtype = torch.float16
-        elif self.trainer.precision == "bf16-mixed":
-            param_dtype = torch.float32
-            reduce_dtype = buffer_dtype = torch.bfloat16
-        elif self.trainer.precision == "16-true":
+        if self.trainer.precision in ("16-true", "16-mixed"):
             param_dtype = reduce_dtype = buffer_dtype = torch.float16
-        elif self.trainer.precision == "bf16-true":
+        elif self.trainer.precision in ("bf16-true", "bf16-mixed"):
             param_dtype = reduce_dtype = buffer_dtype = torch.bfloat16
         else:
             raise ValueError(f"Unknown precision {self.trainer.precision}")
@@ -138,15 +132,9 @@ def _assert_layer_fsdp_instance(self) -> None:
         assert isinstance(self.layer, torch.nn.Sequential)
         assert isinstance(self.trainer.strategy.precision_plugin, FSDPPrecision)

-        if self.trainer.precision == "16-mixed":
-            param_dtype = torch.float32
-            reduce_dtype = buffer_dtype = torch.float16
-        elif self.trainer.precision == "bf16-mixed":
-            param_dtype = torch.float32
-            reduce_dtype = buffer_dtype = torch.bfloat16
-        elif self.trainer.precision == "16-true":
+        if self.trainer.precision in ("16-true", "16-mixed"):
             param_dtype = reduce_dtype = buffer_dtype = torch.float16
-        elif self.trainer.precision == "bf16-true":
+        elif self.trainer.precision in ("bf16-true", "bf16-mixed"):
             param_dtype = reduce_dtype = buffer_dtype = torch.bfloat16
         else:
             raise ValueError(f"Unknown precision {self.trainer.precision}")
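
For reference, a minimal sketch (not part of the diff) of how the updated mapping behaves once this change lands. It assumes only what the diff shows: the `FSDPPrecision` class edited in `src/lightning/fabric/plugins/precision/fsdp.py`, its `mixed_precision_config` property, and the dtype expectations from the updated tests; the import path mirrors that module location.

# Illustrative sketch only -- prints param_dtype / reduce_dtype / buffer_dtype per precision flag.
from lightning.fabric.plugins.precision.fsdp import FSDPPrecision

for precision in ("16-mixed", "bf16-mixed", "16-true", "bf16-true", "32-true"):
    # The "-true" settings emit the new rank-zero UserWarning about low-precision compute.
    config = FSDPPrecision(precision=precision).mixed_precision_config
    print(precision, config.param_dtype, config.reduce_dtype, config.buffer_dtype)

# Expected after this change, per the updated tests (param / reduce / buffer):
#   16-mixed   -> float16 / float16 / float16     (previously float32 / float16 / float16)
#   bf16-mixed -> bfloat16 / bfloat16 / bfloat16  (previously float32 / bfloat16 / bfloat16)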