diff --git a/src/lightning/fabric/plugins/precision/fsdp.py b/src/lightning/fabric/plugins/precision/fsdp.py
index 0b78ad72a441f..270a67e3a2338 100644
--- a/src/lightning/fabric/plugins/precision/fsdp.py
+++ b/src/lightning/fabric/plugins/precision/fsdp.py
@@ -74,6 +74,12 @@ def __init__(self, precision: _PRECISION_INPUT, scaler: Optional["ShardedGradSca
         }
         self._desired_input_dtype = precision_to_type[self.precision]
 
+    @override
+    def convert_module(self, module: Module) -> Module:
+        if "true" in self.precision:
+            return module.to(dtype=self._desired_input_dtype)
+        return module
+
     @property
     def mixed_precision_config(self) -> "TorchMixedPrecision":
         from torch.distributed.fsdp.fully_sharded_data_parallel import MixedPrecision as TorchMixedPrecision
diff --git a/src/lightning/pytorch/CHANGELOG.md b/src/lightning/pytorch/CHANGELOG.md
index 1b251b8fb06fa..18dc8891144ad 100644
--- a/src/lightning/pytorch/CHANGELOG.md
+++ b/src/lightning/pytorch/CHANGELOG.md
@@ -35,6 +35,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed `_LoggerConnector`'s `_ResultMetric` to move all registered keys to the device of the logged value if needed ([#19814](https://github.com/Lightning-AI/pytorch-lightning/issues/19814))
 - Fixed `_optimizer_to_device` logic for special 'step' key in optimizer state causing performance regression ([#20019](https://github.com/Lightning-AI/lightning/pull/20019))
 - Fixed parameter counts in `ModelSummary` when model has distributed parameters (DTensor) ([#20163](https://github.com/Lightning-AI/pytorch-lightning/pull/20163))
+- Fixed PyTorch Lightning FSDP takes more memory than PyTorch FSDP ([#20323](https://github.com/Lightning-AI/pytorch-lightning/pull/20323))
 
 
 ## [2.3.0] - 2024-06-13
diff --git a/src/lightning/pytorch/plugins/precision/fsdp.py b/src/lightning/pytorch/plugins/precision/fsdp.py
index 7029497c177cc..f3bab3e915e91 100644
--- a/src/lightning/pytorch/plugins/precision/fsdp.py
+++ b/src/lightning/pytorch/plugins/precision/fsdp.py
@@ -17,6 +17,7 @@
 import torch
 from lightning_utilities import apply_to_collection
 from torch import Tensor
+from torch.nn import Module
 from typing_extensions import get_args, override
 
 import lightning.pytorch as pl
@@ -73,6 +74,12 @@ def __init__(self, precision: _PRECISION_INPUT, scaler: Optional["ShardedGradSca
         }
         self._desired_input_dtype = precision_to_type[self.precision]
 
+    @override
+    def convert_module(self, module: Module) -> Module:
+        if "true" in self.precision:
+            return module.to(dtype=self._desired_input_dtype)
+        return module
+
     @override
     def clip_grad_by_norm(self, *_: Any, **__: Any) -> None:
         # see https://pytorch.org/docs/stable/fsdp.html#torch.distributed.fsdp.FullyShardedDataParallel.clip_grad_norm_
diff --git a/tests/tests_fabric/plugins/precision/test_fsdp.py b/tests/tests_fabric/plugins/precision/test_fsdp.py
index e42df493dd725..6a4968736ea86 100644
--- a/tests/tests_fabric/plugins/precision/test_fsdp.py
+++ b/tests/tests_fabric/plugins/precision/test_fsdp.py
@@ -127,3 +127,21 @@ def test_invalid_precision_with_fsdp_precision():
     with pytest.raises(ValueError, match="is not supported in FSDP. `precision` must be one of"):
         FSDPPrecision(precision="64-true")
+
+
+@pytest.mark.parametrize(
+    ("precision", "expected_dtype"),
+    [
+        ("32-true", torch.float32),
+        ("bf16-mixed", torch.float32),
+        ("16-mixed", torch.float32),
+        ("bf16-true", torch.bfloat16),
+        ("16-true", torch.float16),
+    ],
+)
+def test_convert_module(precision, expected_dtype):
+    precision = FSDPPrecision(precision=precision)
+    module = torch.nn.Linear(2, 2)
+    assert module.weight.dtype == module.bias.dtype == torch.float32
+    module = precision.convert_module(module)
+    assert module.weight.dtype == module.bias.dtype == expected_dtype
diff --git a/tests/tests_pytorch/plugins/precision/test_fsdp.py b/tests/tests_pytorch/plugins/precision/test_fsdp.py
index 8b595c2c74a32..3ad3af1f1b56b 100644
--- a/tests/tests_pytorch/plugins/precision/test_fsdp.py
+++ b/tests/tests_pytorch/plugins/precision/test_fsdp.py
@@ -40,6 +40,24 @@ def test_fsdp_precision_config(precision, expected):
     assert config.reduce_dtype == expected[2]
 
 
+@pytest.mark.parametrize(
+    ("precision", "expected_dtype"),
+    [
+        ("32-true", torch.float32),
+        ("bf16-mixed", torch.float32),
+        ("16-mixed", torch.float32),
+        ("bf16-true", torch.bfloat16),
+        ("16-true", torch.float16),
+    ],
+)
+def test_convert_module(precision, expected_dtype):
+    precision = FSDPPrecision(precision=precision)
+    module = torch.nn.Linear(2, 2)
+    assert module.weight.dtype == module.bias.dtype == torch.float32
+    module = precision.convert_module(module)
+    assert module.weight.dtype == module.bias.dtype == expected_dtype
+
+
def test_fsdp_precision_default_scaler():
    from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
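
Note: a minimal sketch of what the new convert_module override does, mirroring the tests added above. It assumes a Lightning version that includes this change and only exercises the precision plugin directly; the FSDP strategy itself (which is expected to call convert_module during setup) is not run here.

import torch

from lightning.fabric.plugins.precision.fsdp import FSDPPrecision

# "-true" precisions now cast the module's parameters up front. The changelog
# entry above attributes the memory saving to the parameters no longer being
# held in float32 by the time FSDP wraps the module.
true_precision = FSDPPrecision(precision="bf16-true")
linear = true_precision.convert_module(torch.nn.Linear(2, 2))
assert linear.weight.dtype == torch.bfloat16

# "-mixed" precisions return the module unchanged (still float32); the dtype
# handling stays with FSDP's MixedPrecision config instead.
mixed_precision = FSDPPrecision(precision="16-mixed")
assert mixed_precision.convert_module(torch.nn.Linear(2, 2)).weight.dtype == torch.float32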