src/lightning/fabric/plugins/precision/fsdp.py (21 changes: 11 additions & 10 deletions)
@@ -24,6 +24,7 @@
 from lightning.fabric.plugins.precision.amp import _optimizer_handles_unscaling
 from lightning.fabric.plugins.precision.precision import Precision
 from lightning.fabric.plugins.precision.utils import _convert_fp_tensor, _DtypeContextManager
+from lightning.fabric.utilities import rank_zero_warn
 from lightning.fabric.utilities.types import Optimizable

 if TYPE_CHECKING:
@@ -84,19 +85,19 @@ def convert_module(self, module: Module) -> Module:
     def mixed_precision_config(self) -> "TorchMixedPrecision":
         from torch.distributed.fsdp.fully_sharded_data_parallel import MixedPrecision as TorchMixedPrecision

-        if self.precision == "16-mixed":
-            param_dtype = torch.float32
-            reduce_dtype = buffer_dtype = torch.float16
-        elif self.precision == "bf16-mixed":
-            param_dtype = torch.float32
-            reduce_dtype = buffer_dtype = torch.bfloat16
-        elif self.precision == "16-true":
+        if self.precision in ("16-true", "bf16-true"):
+            rank_zero_warn(
+                f"FSDPPrecision `{self.precision}` enables mixed-precision execution. "
+                "Model parameters remain in full precision `torch.float32`, while forward and backward passes "
+                f"run with reduced precision `{self._desired_input_dtype}` for speed and memory efficiency."
+            )
+
+        if self.precision in ("16-true", "16-mixed"):
             param_dtype = reduce_dtype = buffer_dtype = torch.float16
-        elif self.precision == "bf16-true":
+        elif self.precision in ("bf16-true", "bf16-mixed"):
             param_dtype = reduce_dtype = buffer_dtype = torch.bfloat16
         elif self.precision == "32-true":
-            param_dtype = torch.float32
-            reduce_dtype = buffer_dtype = torch.float32
+            param_dtype = reduce_dtype = buffer_dtype = torch.float32
         else:
             raise ValueError(f"Was unable to infer precision type, received {self.precision!r}.")

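For reference, here is a minimal standalone sketch (not part of the diff) of the mapping the updated `mixed_precision_config` now produces. The helper name `resolve_fsdp_dtypes` is hypothetical and used only for illustration; `MixedPrecision` is the real `torch.distributed.fsdp` dataclass the plugin returns.

import torch
from torch.distributed.fsdp import MixedPrecision

def resolve_fsdp_dtypes(precision: str) -> MixedPrecision:
    # Mirrors the updated plugin logic: the "-true" and "-mixed" variants of a
    # given dtype now share one FSDP config; FSDP keeps the sharded master
    # parameters in their original (float32) dtype and casts to param_dtype
    # only for forward/backward compute.
    if precision in ("16-true", "16-mixed"):
        dtype = torch.float16
    elif precision in ("bf16-true", "bf16-mixed"):
        dtype = torch.bfloat16
    elif precision == "32-true":
        dtype = torch.float32
    else:
        raise ValueError(f"Was unable to infer precision type, received {precision!r}.")
    return MixedPrecision(param_dtype=dtype, reduce_dtype=dtype, buffer_dtype=dtype)

print(resolve_fsdp_dtypes("bf16-mixed"))  # param/reduce/buffer dtypes are all torch.bfloat16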
src/lightning/pytorch/plugins/precision/fsdp.py (21 changes: 11 additions & 10 deletions)
@@ -24,6 +24,7 @@
 from lightning.fabric.plugins.precision.amp import _optimizer_handles_unscaling
 from lightning.fabric.plugins.precision.fsdp import _PRECISION_INPUT
 from lightning.fabric.plugins.precision.utils import _convert_fp_tensor, _DtypeContextManager
+from lightning.fabric.utilities import rank_zero_warn
 from lightning.fabric.utilities.types import Optimizable
 from lightning.pytorch.plugins.precision.precision import Precision
 from lightning.pytorch.utilities.exceptions import MisconfigurationException
@@ -94,19 +95,19 @@ def clip_grad_by_norm(self, *_: Any, **__: Any) -> None:
     def mixed_precision_config(self) -> "TorchMixedPrecision":
         from torch.distributed.fsdp.fully_sharded_data_parallel import MixedPrecision as TorchMixedPrecision

-        if self.precision == "16-mixed":
-            param_dtype = torch.float32
-            reduce_dtype = buffer_dtype = torch.float16
-        elif self.precision == "bf16-mixed":
-            param_dtype = torch.float32
-            reduce_dtype = buffer_dtype = torch.bfloat16
-        elif self.precision == "16-true":
+        if self.precision in ("16-true", "bf16-true"):
+            rank_zero_warn(
+                f"FSDPPrecision `{self.precision}` enables mixed-precision execution. "
+                "Model parameters remain in full precision `torch.float32`, while forward and backward passes "
+                f"run with reduced precision `{self._desired_input_dtype}` for speed and memory efficiency."
+            )
+
+        if self.precision in ("16-true", "16-mixed"):
             param_dtype = reduce_dtype = buffer_dtype = torch.float16
-        elif self.precision == "bf16-true":
+        elif self.precision in ("bf16-true", "bf16-mixed"):
             param_dtype = reduce_dtype = buffer_dtype = torch.bfloat16
         elif self.precision == "32-true":
-            param_dtype = torch.float32
-            reduce_dtype = buffer_dtype = torch.float32
+            param_dtype = reduce_dtype = buffer_dtype = torch.float32
         else:
             raise MisconfigurationException(f"Was unable to infer precision type, received {self.precision!r}.")

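A hedged usage sketch of how this change surfaces through the Trainer; the Trainer arguments shown are standard Lightning options, while the model and datamodule names are placeholders.

import lightning.pytorch as pl

# With the updated plugin, "16-mixed" now requests float16 for FSDP's
# param/reduce/buffer dtypes, and "16-true"/"bf16-true" emit the rank-zero
# warning added in this diff when the config is built.
trainer = pl.Trainer(accelerator="gpu", devices=2, strategy="fsdp", precision="16-mixed")
# trainer.fit(MyLightningModule(), datamodule=MyDataModule())  # placeholder names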
tests/tests_fabric/plugins/precision/test_fsdp.py (18 changes: 15 additions & 3 deletions)
@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from contextlib import contextmanager
 from unittest.mock import Mock

 import pytest
@@ -21,19 +22,30 @@
 from tests_fabric.helpers.runif import RunIf


+# Pytest passes args/kwargs to the context manager used with `pytest.warns`.
+# `contextlib.nullcontext` doesn't accept them, so this no-op version does.
+@contextmanager
+def null_ctx(*args, **kwargs):
+    yield
+
+
 @pytest.mark.parametrize(
     ("precision", "expected"),
     [
         ("16-true", (torch.float16, torch.float16, torch.float16)),
         ("bf16-true", (torch.bfloat16, torch.bfloat16, torch.bfloat16)),
-        ("16-mixed", (torch.float32, torch.float16, torch.float16)),
-        ("bf16-mixed", (torch.float32, torch.bfloat16, torch.bfloat16)),
+        ("16-mixed", (torch.float16, torch.float16, torch.float16)),
+        ("bf16-mixed", (torch.bfloat16, torch.bfloat16, torch.bfloat16)),
         ("32-true", (torch.float32, torch.float32, torch.float32)),
     ],
 )
 def test_fsdp_precision_config(precision, expected):
     plugin = FSDPPrecision(precision=precision)
-    config = plugin.mixed_precision_config

+    warning_ctx = pytest.warns if precision in ("16-true", "bf16-true") else null_ctx
+
+    with warning_ctx(UserWarning, match="enables mixed-precision execution"):
+        config = plugin.mixed_precision_config
+
     assert config.param_dtype == expected[0]
     assert config.buffer_dtype == expected[1]
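The `null_ctx` helper is needed because the test calls its context factory with `pytest.warns`-style arguments. An alternative sketch, under the assumption that one would rather return ready-made context managers than a no-op factory:

from contextlib import nullcontext

import pytest

def make_warning_ctx(precision: str):
    # Return a finished context manager, so nullcontext never sees the
    # UserWarning/match arguments it cannot accept.
    if precision in ("16-true", "bf16-true"):
        return pytest.warns(UserWarning, match="enables mixed-precision execution")
    return nullcontext()

# Usage inside the test body:
#     with make_warning_ctx(precision):
#         config = plugin.mixed_precision_config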
tests/tests_fabric/strategies/test_fsdp_integration.py (10 changes: 2 additions & 8 deletions)
@@ -87,15 +87,9 @@ def step(self, model, batch):

         precision = self.fabric._precision
         assert isinstance(precision, FSDPPrecision)
-        if precision.precision == "16-mixed":
-            param_dtype = torch.float32
-            reduce_dtype = buffer_dtype = torch.float16
-        elif precision.precision == "bf16-mixed":
-            param_dtype = torch.float32
-            reduce_dtype = buffer_dtype = torch.bfloat16
-        elif precision.precision == "16-true":
+        if precision.precision in ("16-true", "16-mixed"):
             param_dtype = reduce_dtype = buffer_dtype = torch.float16
-        elif precision.precision == "bf16-true":
+        elif precision.precision in ("bf16-true", "bf16-mixed"):
             param_dtype = reduce_dtype = buffer_dtype = torch.bfloat16
         else:
             raise ValueError(f"Unknown precision {precision.precision}")
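For context, a minimal sketch of what these expected dtypes correspond to at the FSDP level. This is a generic snippet, not code from the PR; it assumes an initialized process group (e.g. launched via torchrun) and a CUDA device.

import torch
from torch.distributed.fsdp import FullyShardedDataParallel, MixedPrecision

# The plugin's mixed_precision_config is ultimately handed to the FSDP wrapper
# like this; bf16 shown as an example.
mp = MixedPrecision(param_dtype=torch.bfloat16, reduce_dtype=torch.bfloat16, buffer_dtype=torch.bfloat16)
wrapped = FullyShardedDataParallel(torch.nn.Linear(32, 32).cuda(), mixed_precision=mp)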
tests/tests_pytorch/plugins/precision/test_fsdp.py (18 changes: 15 additions & 3 deletions)
@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from contextlib import contextmanager
 from unittest.mock import ANY, MagicMock, Mock

 import pytest
@@ -21,19 +22,30 @@
 from tests_pytorch.helpers.runif import RunIf


+# Pytest passes args/kwargs to the context manager used with `pytest.warns`.
+# `contextlib.nullcontext` doesn't accept them, so this no-op version does.
+@contextmanager
+def null_ctx(*args, **kwargs):
+    yield
+
+
 @pytest.mark.parametrize(
     ("precision", "expected"),
     [
         ("16-true", (torch.float16, torch.float16, torch.float16)),
         ("bf16-true", (torch.bfloat16, torch.bfloat16, torch.bfloat16)),
-        ("16-mixed", (torch.float32, torch.float16, torch.float16)),
-        ("bf16-mixed", (torch.float32, torch.bfloat16, torch.bfloat16)),
+        ("16-mixed", (torch.float16, torch.float16, torch.float16)),
+        ("bf16-mixed", (torch.bfloat16, torch.bfloat16, torch.bfloat16)),
         ("32-true", (torch.float32, torch.float32, torch.float32)),
     ],
 )
 def test_fsdp_precision_config(precision, expected):
     plugin = FSDPPrecision(precision=precision)
-    config = plugin.mixed_precision_config

+    warning_ctx = pytest.warns if precision in ("16-true", "bf16-true") else null_ctx
+
+    with warning_ctx(UserWarning, match="enables mixed-precision execution"):
+        config = plugin.mixed_precision_config
+
     assert config.param_dtype == expected[0]
     assert config.buffer_dtype == expected[1]
tests/tests_pytorch/strategies/test_fsdp.py (20 changes: 4 additions & 16 deletions)
@@ -77,15 +77,9 @@ def _assert_layer_fsdp_instance(self) -> None:
         assert isinstance(self.layer, FullyShardedDataParallel)
         assert isinstance(self.trainer.strategy.precision_plugin, FSDPPrecision)

-        if self.trainer.precision == "16-mixed":
-            param_dtype = torch.float32
-            reduce_dtype = buffer_dtype = torch.float16
-        elif self.trainer.precision == "bf16-mixed":
-            param_dtype = torch.float32
-            reduce_dtype = buffer_dtype = torch.bfloat16
-        elif self.trainer.precision == "16-true":
+        if self.trainer.precision in ("16-true", "16-mixed"):
             param_dtype = reduce_dtype = buffer_dtype = torch.float16
-        elif self.trainer.precision == "bf16-true":
+        elif self.trainer.precision in ("bf16-true", "bf16-mixed"):
             param_dtype = reduce_dtype = buffer_dtype = torch.bfloat16
         else:
             raise ValueError(f"Unknown precision {self.trainer.precision}")
@@ -138,15 +132,9 @@ def _assert_layer_fsdp_instance(self) -> None:
         assert isinstance(self.layer, torch.nn.Sequential)
         assert isinstance(self.trainer.strategy.precision_plugin, FSDPPrecision)

-        if self.trainer.precision == "16-mixed":
-            param_dtype = torch.float32
-            reduce_dtype = buffer_dtype = torch.float16
-        elif self.trainer.precision == "bf16-mixed":
-            param_dtype = torch.float32
-            reduce_dtype = buffer_dtype = torch.bfloat16
-        elif self.trainer.precision == "16-true":
+        if self.trainer.precision in ("16-true", "16-mixed"):
             param_dtype = reduce_dtype = buffer_dtype = torch.float16
-        elif self.trainer.precision == "bf16-true":
+        elif self.trainer.precision in ("bf16-true", "bf16-mixed"):
             param_dtype = reduce_dtype = buffer_dtype = torch.bfloat16
         else:
             raise ValueError(f"Unknown precision {self.trainer.precision}")
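A sketch of the kind of assertion the collapsed context below presumably performs against the wrapped layers; it assumes the FSDP wrapper exposes the `MixedPrecision` config it was built with as a `mixed_precision` attribute, as recent torch releases do.

from torch.distributed.fsdp import FullyShardedDataParallel, MixedPrecision

def assert_fsdp_mixed_precision(layer: FullyShardedDataParallel, expected: MixedPrecision) -> None:
    # Attribute name assumed; compare the wrapper's config to the expected dtypes.
    assert layer.mixed_precision.param_dtype == expected.param_dtype
    assert layer.mixed_precision.reduce_dtype == expected.reduce_dtype
    assert layer.mixed_precision.buffer_dtype == expected.buffer_dtype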