Commit d8a4d84

add fsdp2 precision plugin
1 parent 9f537c1 commit d8a4d84

2 files changed: +116 -6 lines changed

src/lightning/pytorch/plugins/precision/fsdp2.py

Lines changed: 110 additions & 0 deletions

@@ -0,0 +1,110 @@
# Copyright The Lightning AI team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from contextlib import AbstractContextManager
from typing import Any

import torch
from lightning_utilities import apply_to_collection
from torch import Tensor
from torch.nn import Module
from typing_extensions import get_args, override

from lightning.fabric.plugins.precision.fsdp import _PRECISION_INPUT
from lightning.fabric.plugins.precision.utils import _convert_fp_tensor, _DtypeContextManager
from lightning.pytorch.plugins.precision.precision import Precision
from lightning.pytorch.utilities.exceptions import MisconfigurationException


class FSDP2Precision(Precision):
    """Precision plugin for training with FSDP2 (Fully Sharded Data Parallel v2).

    .. warning:: This is an :ref:`experimental <versioning:Experimental API>` feature.

    Args:
        precision: Full precision (32-true) or half precision (16-true, bf16-true).
            Mixed precision (16-mixed, bf16-mixed) is not supported by this plugin; configure it
            through FSDP2's mixed-precision policy (``mp_policy``) instead.
        scaler: Not supported by this plugin; gradient scaling is likewise configured through the
            mixed-precision policy.

    Raises:
        ValueError:
            If an unsupported ``precision`` or a ``scaler`` is provided.

    """

    def __init__(self, precision: _PRECISION_INPUT, scaler: Any = None) -> None:
        supported_precision = get_args(_PRECISION_INPUT)
        if precision not in supported_precision:
            raise ValueError(
                f"`precision={precision!r}` is not supported in FSDP."
                f" `precision` must be one of: {supported_precision}."
            )

        if scaler is not None:
            raise ValueError(
                f"`scaler` is not supported in `{self.__class__.__name__}`, found {scaler}."
                " Use the `mixed-precision policy` instead to configure the scaler."
            )

        if "mixed" in precision:
            raise ValueError(
                f"`precision={precision!r}` is not supported in `{self.__class__.__name__}`."
                " Only `true` precision is supported."
                " Use the `mixed-precision policy (mp_policy)` instead to configure mixed precision."
            )

        self.precision = precision

        precision_to_type = {
            "bf16-true": torch.bfloat16,
            "16-true": torch.float16,
            "32-true": torch.float32,
        }
        self._desired_input_dtype = precision_to_type[self.precision]

    @override
    def convert_module(self, module: Module) -> Module:
        if "true" in self.precision:
            return module.to(dtype=self._desired_input_dtype)
        return module

    @override
    def clip_grad_by_norm(self, *_: Any, **__: Any) -> None:
        # see https://pytorch.org/docs/stable/fsdp.html#torch.distributed.fsdp.FullyShardedDataParallel.clip_grad_norm_
        # section `Gradient Clipping`: using `torch.nn.utils.clip_grad_norm_` is incorrect with FSDP.
        # To overcome this we would need to call `root_sharded_module.clip_grad_norm_(clip_val)`, but we don't have a
        # reference to the root module.
        raise MisconfigurationException(
            f"`gradient_clip_algorithm='norm'` is currently not supported for `{self.__class__.__name__}`"
        )

    @override
    def tensor_init_context(self) -> AbstractContextManager:
        return _DtypeContextManager(self._desired_input_dtype)

    @override
    def module_init_context(self) -> AbstractContextManager:
        # Initialize module parameters directly in the desired dtype.
        return _DtypeContextManager(self._desired_input_dtype)

    @override
    def forward_context(self) -> AbstractContextManager:
        return _DtypeContextManager(self._desired_input_dtype)

    @override
    def convert_input(self, data: Any) -> Any:
        return apply_to_collection(data, function=_convert_fp_tensor, dtype=Tensor, dst_type=self._desired_input_dtype)

    @override
    def convert_output(self, data: Any) -> Any:
        return apply_to_collection(data, function=_convert_fp_tensor, dtype=Tensor, dst_type=torch.get_default_dtype())
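
A minimal usage sketch of the plugin above (an illustration, not part of the commit; it assumes a Lightning build that contains this change): convert_module casts parameters to the "true" dtype, convert_input casts incoming floating-point tensors, and convert_output casts results back to the default dtype.

import torch
from torch import nn

from lightning.pytorch.plugins.precision.fsdp2 import FSDP2Precision

plugin = FSDP2Precision("bf16-true")

model = plugin.convert_module(nn.Linear(4, 4))      # parameters cast to torch.bfloat16 ("true" precision)
batch = plugin.convert_input(torch.randn(2, 4))     # floating-point inputs cast to torch.bfloat16

with plugin.forward_context():                      # default dtype switched to bfloat16 via _DtypeContextManager
    out = model(batch)

out = plugin.convert_output(out)                    # cast back to torch.get_default_dtype(), i.e. float32 here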

src/lightning/pytorch/strategies/fsdp2.py

Lines changed: 6 additions & 6 deletions

@@ -65,7 +65,7 @@
 from lightning.fabric.utilities.types import _PATH, ReduceOp
 from lightning.pytorch.core.optimizer import LightningOptimizer
 from lightning.pytorch.plugins.precision import Precision
-from lightning.pytorch.plugins.precision.fsdp import FSDPPrecision
+from lightning.pytorch.plugins.precision.fsdp2 import FSDP2Precision
 from lightning.pytorch.strategies.launchers.subprocess_script import _SubprocessScriptLauncher
 from lightning.pytorch.strategies.parallel import ParallelStrategy
 from lightning.pytorch.strategies.strategy import TBroadcast

@@ -173,19 +173,19 @@ def process_group_backend(self) -> Optional[str]:

     @property
     @override
-    def precision_plugin(self) -> FSDPPrecision:
+    def precision_plugin(self) -> FSDP2Precision:
         plugin = self._precision_plugin
         if plugin is not None:
-            assert isinstance(plugin, FSDPPrecision)
+            assert isinstance(plugin, FSDP2Precision)
             return plugin
-        return FSDPPrecision("32-true")
+        return FSDP2Precision("32-true")

     @precision_plugin.setter
     @override
     def precision_plugin(self, precision_plugin: Optional[Precision]) -> None:
-        if precision_plugin is not None and not isinstance(precision_plugin, FSDPPrecision):
+        if precision_plugin is not None and not isinstance(precision_plugin, FSDP2Precision):
             raise TypeError(
-                f"The FSDP strategy can only work with the `FSDPPrecision` plugin, found {precision_plugin}"
+                f"The FSDP2 strategy can only work with the `FSDP2Precision` plugin, found {precision_plugin}"
             )
         self._precision_plugin = precision_plugin
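
A short illustrative sketch of the contract enforced by the precision_plugin property above. The strategy class name FSDP2Strategy and its no-argument constructor are assumptions (the class definition is outside this diff); the plugin import matches the new module added in this commit.

from lightning.pytorch.plugins.precision import Precision
from lightning.pytorch.plugins.precision.fsdp2 import FSDP2Precision
from lightning.pytorch.strategies.fsdp2 import FSDP2Strategy  # class name assumed, not shown in this diff

strategy = FSDP2Strategy()                                # assumes a default constructor, as with FSDPStrategy
print(strategy.precision_plugin)                          # falls back to FSDP2Precision("32-true")

strategy.precision_plugin = FSDP2Precision("bf16-true")   # accepted: isinstance check passes

try:
    strategy.precision_plugin = Precision()               # a plain Precision (or any non-FSDP2 plugin) is rejected
except TypeError as err:
    print(err)  # "The FSDP2 strategy can only work with the `FSDP2Precision` plugin, ..."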
