Commit cba0c48 ("update")

1 parent: 1e63163

2 files changed in src/lightning: 14 additions, 16 deletions
src/lightning/fabric/plugins/precision/fsdp.py

Lines changed: 7 additions & 8 deletions
@@ -64,14 +64,6 @@ def __init__(self, precision: _PRECISION_INPUT, scaler: Optional["ShardedGradSca
             raise ValueError(f"`precision={precision!r}` does not use a scaler, found {scaler}.")
 
         self.scaler = ShardedGradScaler() if scaler is None and precision == "16-mixed" else None
-
-        if precision != "32-true":
-            rank_zero_warn(
-                f"FSDPPrecision `{precision}` runs computations in reduced precision "
-                "(e.g., float16/bfloat16) while keeping model weights stored in full precision. "
-                "These modes are still experimental and may produce slightly different accuracy or stability "
-                "compared to full precision (`precision='32-true'`)."
-            )
         self.precision = precision
 
         precision_to_type = {
@@ -93,6 +85,13 @@ def convert_module(self, module: Module) -> Module:
     def mixed_precision_config(self) -> "TorchMixedPrecision":
         from torch.distributed.fsdp.fully_sharded_data_parallel import MixedPrecision as TorchMixedPrecision
 
+        if "true" in self.precision and self.precision != "32-true":
+            rank_zero_warn(
+                f"FSDPPrecision `{self.precision}` enables mixed-precision execution. "
+                "Model parameters remain in full precision `torch.float32`, while forward and backward passes "
+                f"run with reduced precision `{self._desired_input_dtype}` for speed and memory efficiency."
+            )
+
         if self.precision in ("16-true", "16-mixed"):
             param_dtype = reduce_dtype = buffer_dtype = torch.float16
         elif self.precision in ("bf16-true", "bf16-mixed"):
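The behavioral change is easiest to see by comparing the removed and added warning predicates directly. A minimal sketch, illustrative only and not part of the commit, using the precision strings that appear in this diff:

# Compare the predicate removed from __init__ with the one added to mixed_precision_config.
precisions = ["32-true", "16-true", "bf16-true", "16-mixed", "bf16-mixed"]

for p in precisions:
    old_warns = p != "32-true"                  # removed check (warned for every non-32-true mode)
    new_warns = "true" in p and p != "32-true"  # added check (warns only for 16-true / bf16-true)
    print(f"{p:<11} old: {str(old_warns):<6} new: {new_warns}")

The mixed modes ("16-mixed", "bf16-mixed") no longer trigger a warning; only the reduced true-precision modes do, and only when the FSDP mixed-precision config is actually built rather than at plugin construction.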

src/lightning/pytorch/plugins/precision/fsdp.py

Lines changed: 7 additions & 8 deletions
@@ -64,14 +64,6 @@ def __init__(self, precision: _PRECISION_INPUT, scaler: Optional["ShardedGradSca
             raise ValueError(f"`precision={precision!r}` does not use a scaler, found {scaler}.")
 
         self.scaler = ShardedGradScaler() if scaler is None and precision in ("16-mixed", "16-true") else None
-
-        if precision != "32-true":
-            rank_zero_warn(
-                f"FSDPPrecision `{precision}` runs computations in reduced precision "
-                "(e.g., float16/bfloat16) while keeping model weights stored in full precision. "
-                "These modes are still experimental and may produce slightly different accuracy or stability "
-                "compared to full precision (`precision='32-true'`)."
-            )
         self.precision = precision
 
         precision_to_type = {
@@ -103,6 +95,13 @@ def clip_grad_by_norm(self, *_: Any, **__: Any) -> None:
     def mixed_precision_config(self) -> "TorchMixedPrecision":
         from torch.distributed.fsdp.fully_sharded_data_parallel import MixedPrecision as TorchMixedPrecision
 
+        if "true" in self.precision and self.precision != "32-true":
+            rank_zero_warn(
+                f"FSDPPrecision `{self.precision}` enables mixed-precision execution. "
+                "Model parameters remain in full precision `torch.float32`, while forward and backward passes "
+                f"run with reduced precision `{self._desired_input_dtype}` for speed and memory efficiency."
+            )
+
         if self.precision in ("16-true", "16-mixed"):
             param_dtype = reduce_dtype = buffer_dtype = torch.float16
         elif self.precision in ("bf16-true", "bf16-mixed"):
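The same shift applies to the PyTorch plugin: the warning is no longer emitted at construction time but when the FSDP mixed-precision config is built. A hypothetical usage sketch, assuming the class is exported as FSDPPrecision (as the warning text suggests) and that mixed_precision_config is a property as in the upstream plugins:

# Hypothetical sketch, not part of this commit; import path and property access are assumptions.
from lightning.pytorch.plugins.precision.fsdp import FSDPPrecision

plugin = FSDPPrecision(precision="bf16-true")  # no warning at construction anymore
config = plugin.mixed_precision_config         # warning now fires here for the "*-true" modes
print(config.param_dtype, config.reduce_dtype, config.buffer_dtype)

Note, unrelated to the lines changed here, that the PyTorch plugin's context shows the scaler being created for both "16-mixed" and "16-true", while the Fabric plugin above creates a ShardedGradScaler only for "16-mixed".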
