@@ -64,14 +64,6 @@ def __init__(self, precision: _PRECISION_INPUT, scaler: Optional["ShardedGradScaler"]
             raise ValueError(f"`precision={precision!r}` does not use a scaler, found {scaler}.")
 
         self.scaler = ShardedGradScaler() if scaler is None and precision in ("16-mixed", "16-true") else None
-
-        if precision != "32-true":
-            rank_zero_warn(
-                f"FSDPPrecision `{precision}` runs computations in reduced precision "
-                "(e.g., float16/bfloat16) while keeping model weights stored in full precision. "
-                "These modes are still experimental and may produce slightly different accuracy or stability "
-                "compared to full precision (`precision='32-true'`)."
-            )
         self.precision = precision
 
         precision_to_type = {
@@ -103,6 +95,13 @@ def clip_grad_by_norm(self, *_: Any, **__: Any) -> None:
     def mixed_precision_config(self) -> "TorchMixedPrecision":
         from torch.distributed.fsdp.fully_sharded_data_parallel import MixedPrecision as TorchMixedPrecision
 
+        if "true" in self.precision and self.precision != "32-true":
+            rank_zero_warn(
+                f"FSDPPrecision `{self.precision}` enables mixed-precision execution. "
+                "Model parameters remain in full precision `torch.float32`, while forward and backward passes "
+                f"run with reduced precision `{self._desired_input_dtype}` for speed and memory efficiency."
+            )
+
         if self.precision in ("16-true", "16-mixed"):
             param_dtype = reduce_dtype = buffer_dtype = torch.float16
         elif self.precision in ("bf16-true", "bf16-mixed"):
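For reference, a minimal standalone sketch of the guard moved into `mixed_precision_config` above: it reproduces only the condition `"true" in precision and precision != "32-true"` from the diff, so the warning now fires for the `*-true` modes rather than for every non-`32-true` setting as the removed `__init__` check did. The precision strings are the values referenced in this file; `warnings.warn` stands in for `rank_zero_warn`, and the helper name `maybe_warn` is hypothetical.

```python
import warnings

# Precision strings referenced in the diff above.
PRECISIONS = ["32-true", "16-true", "bf16-true", "16-mixed", "bf16-mixed"]


def maybe_warn(precision: str) -> bool:
    """Mirror the guard added to `mixed_precision_config`: warn only for `*-true`
    modes other than `32-true`. Returns True if a warning was emitted."""
    if "true" in precision and precision != "32-true":
        # `rank_zero_warn` in the diff; plain `warnings.warn` keeps this sketch standalone.
        warnings.warn(
            f"FSDPPrecision `{precision}` enables mixed-precision execution; parameters stay "
            "in torch.float32 while forward/backward run in a reduced-precision dtype."
        )
        return True
    return False


if __name__ == "__main__":
    for p in PRECISIONS:
        print(p, "-> warns" if maybe_warn(p) else "-> silent")
    # Only 16-true and bf16-true warn; 32-true, 16-mixed, and bf16-mixed stay silent.
```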