src/lightning/fabric/plugins/precision/fsdp.py (22 changes: 12 additions & 10 deletions)
@@ -24,6 +24,7 @@
 from lightning.fabric.plugins.precision.amp import _optimizer_handles_unscaling
 from lightning.fabric.plugins.precision.precision import Precision
 from lightning.fabric.plugins.precision.utils import _convert_fp_tensor, _DtypeContextManager
+from lightning.fabric.utilities import rank_zero_warn
 from lightning.fabric.utilities.types import Optimizable

 if TYPE_CHECKING:
@@ -63,6 +64,14 @@ def __init__(self, precision: _PRECISION_INPUT, scaler: Optional["ShardedGradSca
             raise ValueError(f"`precision={precision!r}` does not use a scaler, found {scaler}.")

         self.scaler = ShardedGradScaler() if scaler is None and precision == "16-mixed" else None
+
+        if precision != "32-true":
+            rank_zero_warn(
+                f"FSDPPrecision `{precision}` runs computations in reduced precision (float16/bfloat16). "
+                "The '-mixed' settings keep the model weights in full precision and cast them for compute, "
+                "while the '-true' settings also store the weights in reduced precision. This may affect "
+                "accuracy or numerical stability compared to full precision (`precision='32-true'`)."
+            )
         self.precision = precision

         precision_to_type = {
@@ -84,19 +93,12 @@ def convert_module(self, module: Module) -> Module:
     def mixed_precision_config(self) -> "TorchMixedPrecision":
         from torch.distributed.fsdp.fully_sharded_data_parallel import MixedPrecision as TorchMixedPrecision

-        if self.precision == "16-mixed":
-            param_dtype = torch.float32
-            reduce_dtype = buffer_dtype = torch.float16
-        elif self.precision == "bf16-mixed":
-            param_dtype = torch.float32
-            reduce_dtype = buffer_dtype = torch.bfloat16
-        elif self.precision == "16-true":
+        if self.precision in ("16-true", "16-mixed"):
             param_dtype = reduce_dtype = buffer_dtype = torch.float16
-        elif self.precision == "bf16-true":
+        elif self.precision in ("bf16-true", "bf16-mixed"):
             param_dtype = reduce_dtype = buffer_dtype = torch.bfloat16
         elif self.precision == "32-true":
-            param_dtype = torch.float32
-            reduce_dtype = buffer_dtype = torch.float32
+            param_dtype = reduce_dtype = buffer_dtype = torch.float32
         else:
             raise ValueError(f"Was unable to infer precision type, received {self.precision!r}.")

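Note on the behavioural change above: the '-mixed' modes now use the reduced dtype for `param_dtype`, `reduce_dtype`, and `buffer_dtype`, where `param_dtype` previously stayed at `torch.float32`, and any setting other than `32-true` now emits the new warning. The following is an illustrative sketch only (not part of the PR), assuming Lightning and torch are installed and importing `FSDPPrecision` from the module changed above:

# Illustrative only: inspect the MixedPrecision config produced by the new mapping.
import torch
from lightning.fabric.plugins.precision.fsdp import FSDPPrecision

plugin = FSDPPrecision(precision="16-mixed")  # also triggers the new reduced-precision warning
config = plugin.mixed_precision_config

# After this diff, "16-mixed" matches "16-true" for the FSDP dtype config:
assert config.param_dtype == torch.float16   # previously torch.float32
assert config.reduce_dtype == torch.float16
assert config.buffer_dtype == torch.float16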
src/lightning/pytorch/plugins/precision/fsdp.py (24 changes: 13 additions & 11 deletions)
@@ -24,6 +24,7 @@
 from lightning.fabric.plugins.precision.amp import _optimizer_handles_unscaling
 from lightning.fabric.plugins.precision.fsdp import _PRECISION_INPUT
 from lightning.fabric.plugins.precision.utils import _convert_fp_tensor, _DtypeContextManager
+from lightning.fabric.utilities import rank_zero_warn
 from lightning.fabric.utilities.types import Optimizable
 from lightning.pytorch.plugins.precision.precision import Precision
 from lightning.pytorch.utilities.exceptions import MisconfigurationException
@@ -62,7 +63,15 @@ def __init__(self, precision: _PRECISION_INPUT, scaler: Optional["ShardedGradSca
         if scaler is not None and self.precision != "16-mixed":
             raise ValueError(f"`precision={precision!r}` does not use a scaler, found {scaler}.")

-        self.scaler = ShardedGradScaler() if scaler is None and precision == "16-mixed" else None
+        self.scaler = ShardedGradScaler() if scaler is None and precision in ("16-mixed", "16-true") else None
+
+        if precision != "32-true":
+            rank_zero_warn(
+                f"FSDPPrecision `{precision}` runs computations in reduced precision (float16/bfloat16). "
+                "The '-mixed' settings keep the model weights in full precision and cast them for compute, "
+                "while the '-true' settings also store the weights in reduced precision. This may affect "
+                "accuracy or numerical stability compared to full precision (`precision='32-true'`)."
+            )
         self.precision = precision

         precision_to_type = {
@@ -94,19 +103,12 @@ def clip_grad_by_norm(self, *_: Any, **__: Any) -> None:
     def mixed_precision_config(self) -> "TorchMixedPrecision":
         from torch.distributed.fsdp.fully_sharded_data_parallel import MixedPrecision as TorchMixedPrecision

-        if self.precision == "16-mixed":
-            param_dtype = torch.float32
-            reduce_dtype = buffer_dtype = torch.float16
-        elif self.precision == "bf16-mixed":
-            param_dtype = torch.float32
-            reduce_dtype = buffer_dtype = torch.bfloat16
-        elif self.precision == "16-true":
+        if self.precision in ("16-true", "16-mixed"):
             param_dtype = reduce_dtype = buffer_dtype = torch.float16
-        elif self.precision == "bf16-true":
+        elif self.precision in ("bf16-true", "bf16-mixed"):
             param_dtype = reduce_dtype = buffer_dtype = torch.bfloat16
         elif self.precision == "32-true":
-            param_dtype = torch.float32
-            reduce_dtype = buffer_dtype = torch.float32
+            param_dtype = reduce_dtype = buffer_dtype = torch.float32
         else:
             raise MisconfigurationException(f"Was unable to infer precision type, received {self.precision!r}.")

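One further behavioural note for the PyTorch plugin: the changed line above now creates a `ShardedGradScaler` for `16-true` as well as `16-mixed`, whereas the Fabric plugin earlier in this diff still creates one only for `16-mixed`. A rough, illustrative check of what the new line implies (not part of the PR; assumes Lightning and a torch build that provides `torch.distributed.fsdp.sharded_grad_scaler`):

# Illustrative only: scaler creation per precision flag after this diff (PyTorch plugin).
from lightning.pytorch.plugins.precision.fsdp import FSDPPrecision
from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler

assert isinstance(FSDPPrecision(precision="16-true").scaler, ShardedGradScaler)   # new with this diff
assert isinstance(FSDPPrecision(precision="16-mixed").scaler, ShardedGradScaler)  # unchanged
assert FSDPPrecision(precision="bf16-true").scaler is None                        # bfloat16 needs no grad scaling
assert FSDPPrecision(precision="32-true").scaler is None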
tests/tests_fabric/plugins/precision/test_fsdp.py (24 changes: 16 additions & 8 deletions)
@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import warnings
 from unittest.mock import Mock

 import pytest
@@ -22,17 +23,24 @@


 @pytest.mark.parametrize(
-    ("precision", "expected"),
+    ("precision", "expected", "expect_warn"),
     [
-        ("16-true", (torch.float16, torch.float16, torch.float16)),
-        ("bf16-true", (torch.bfloat16, torch.bfloat16, torch.bfloat16)),
-        ("16-mixed", (torch.float32, torch.float16, torch.float16)),
-        ("bf16-mixed", (torch.float32, torch.bfloat16, torch.bfloat16)),
-        ("32-true", (torch.float32, torch.float32, torch.float32)),
+        ("16-true", (torch.float16, torch.float16, torch.float16), True),
+        ("bf16-true", (torch.bfloat16, torch.bfloat16, torch.bfloat16), True),
+        ("16-mixed", (torch.float16, torch.float16, torch.float16), True),
+        ("bf16-mixed", (torch.bfloat16, torch.bfloat16, torch.bfloat16), True),
+        ("32-true", (torch.float32, torch.float32, torch.float32), False),
     ],
 )
-def test_fsdp_precision_config(precision, expected):
-    plugin = FSDPPrecision(precision=precision)
+def test_fsdp_precision_config(precision, expected, expect_warn):
+    with warnings.catch_warnings(record=True) as w:
+        warnings.simplefilter("always")  # capture all warnings
+        plugin = FSDPPrecision(precision=precision)
+
+    # check that the warning was (or wasn't) emitted
+    has_warn = any("FSDPPrecision" in str(warning.message) for warning in w)
+    assert has_warn == expect_warn, f"Unexpected warning state for {precision}"
+
     config = plugin.mixed_precision_config

     assert config.param_dtype == expected[0]
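A possible simplification for this test (and its `tests_pytorch` twin below): pytest can assert the presence or absence of the warning directly, without `warnings.catch_warnings`. Sketch only, assuming `rank_zero_warn` forwards to `warnings.warn` with its default `UserWarning` category:

# Illustrative alternative, not part of the PR.
import pytest
from lightning.fabric.plugins.precision.fsdp import FSDPPrecision

def test_fsdp_precision_warns_when_reduced():
    with pytest.warns(UserWarning, match="FSDPPrecision"):
        FSDPPrecision(precision="16-mixed")

def test_fsdp_precision_silent_for_full_precision(recwarn):
    FSDPPrecision(precision="32-true")
    assert not [w for w in recwarn.list if "FSDPPrecision" in str(w.message)]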
tests/tests_fabric/strategies/test_fsdp_integration.py (10 changes: 2 additions & 8 deletions)
@@ -87,15 +87,9 @@ def step(self, model, batch):

         precision = self.fabric._precision
         assert isinstance(precision, FSDPPrecision)
-        if precision.precision == "16-mixed":
-            param_dtype = torch.float32
-            reduce_dtype = buffer_dtype = torch.float16
-        elif precision.precision == "bf16-mixed":
-            param_dtype = torch.float32
-            reduce_dtype = buffer_dtype = torch.bfloat16
-        elif precision.precision == "16-true":
+        if precision.precision in ("16-true", "16-mixed"):
             param_dtype = reduce_dtype = buffer_dtype = torch.float16
-        elif precision.precision == "bf16-true":
+        elif precision.precision in ("bf16-true", "bf16-mixed"):
             param_dtype = reduce_dtype = buffer_dtype = torch.bfloat16
         else:
             raise ValueError(f"Unknown precision {precision.precision}")
tests/tests_pytorch/plugins/precision/test_fsdp.py (24 changes: 16 additions & 8 deletions)
@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import warnings
 from unittest.mock import ANY, MagicMock, Mock

 import pytest
@@ -22,17 +23,24 @@


 @pytest.mark.parametrize(
-    ("precision", "expected"),
+    ("precision", "expected", "expect_warn"),
     [
-        ("16-true", (torch.float16, torch.float16, torch.float16)),
-        ("bf16-true", (torch.bfloat16, torch.bfloat16, torch.bfloat16)),
-        ("16-mixed", (torch.float32, torch.float16, torch.float16)),
-        ("bf16-mixed", (torch.float32, torch.bfloat16, torch.bfloat16)),
-        ("32-true", (torch.float32, torch.float32, torch.float32)),
+        ("16-true", (torch.float16, torch.float16, torch.float16), True),
+        ("bf16-true", (torch.bfloat16, torch.bfloat16, torch.bfloat16), True),
+        ("16-mixed", (torch.float16, torch.float16, torch.float16), True),
+        ("bf16-mixed", (torch.bfloat16, torch.bfloat16, torch.bfloat16), True),
+        ("32-true", (torch.float32, torch.float32, torch.float32), False),
     ],
 )
-def test_fsdp_precision_config(precision, expected):
-    plugin = FSDPPrecision(precision=precision)
+def test_fsdp_precision_config(precision, expected, expect_warn):
+    with warnings.catch_warnings(record=True) as w:
+        warnings.simplefilter("always")  # capture all warnings
+        plugin = FSDPPrecision(precision=precision)
+
+    # check that the warning was (or wasn't) emitted
+    has_warn = any("FSDPPrecision" in str(warning.message) for warning in w)
+    assert has_warn == expect_warn, f"Unexpected warning state for {precision}"
+
     config = plugin.mixed_precision_config

     assert config.param_dtype == expected[0]
tests/tests_pytorch/strategies/test_fsdp.py (20 changes: 4 additions & 16 deletions)
@@ -77,15 +77,9 @@ def _assert_layer_fsdp_instance(self) -> None:
         assert isinstance(self.layer, FullyShardedDataParallel)
         assert isinstance(self.trainer.strategy.precision_plugin, FSDPPrecision)

-        if self.trainer.precision == "16-mixed":
-            param_dtype = torch.float32
-            reduce_dtype = buffer_dtype = torch.float16
-        elif self.trainer.precision == "bf16-mixed":
-            param_dtype = torch.float32
-            reduce_dtype = buffer_dtype = torch.bfloat16
-        elif self.trainer.precision == "16-true":
+        if self.trainer.precision in ("16-true", "16-mixed"):
             param_dtype = reduce_dtype = buffer_dtype = torch.float16
-        elif self.trainer.precision == "bf16-true":
+        elif self.trainer.precision in ("bf16-true", "bf16-mixed"):
             param_dtype = reduce_dtype = buffer_dtype = torch.bfloat16
         else:
             raise ValueError(f"Unknown precision {self.trainer.precision}")
@@ -138,15 +132,9 @@ def _assert_layer_fsdp_instance(self) -> None:
         assert isinstance(self.layer, torch.nn.Sequential)
         assert isinstance(self.trainer.strategy.precision_plugin, FSDPPrecision)

-        if self.trainer.precision == "16-mixed":
-            param_dtype = torch.float32
-            reduce_dtype = buffer_dtype = torch.float16
-        elif self.trainer.precision == "bf16-mixed":
-            param_dtype = torch.float32
-            reduce_dtype = buffer_dtype = torch.bfloat16
-        elif self.trainer.precision == "16-true":
+        if self.trainer.precision in ("16-true", "16-mixed"):
             param_dtype = reduce_dtype = buffer_dtype = torch.float16
-        elif self.trainer.precision == "bf16-true":
+        elif self.trainer.precision in ("bf16-true", "bf16-mixed"):
             param_dtype = reduce_dtype = buffer_dtype = torch.bfloat16
         else:
             raise ValueError(f"Unknown precision {self.trainer.precision}")