Commit c36f40c

Support gradient norm clipping for FSDP
1 parent aa6e482 commit c36f40c
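
In practical terms, this commit lets the Trainer's built-in gradient clipping run under the FSDP strategy when `gradient_clip_algorithm="norm"` is selected, instead of raising a MisconfigurationException. A minimal usage sketch, assuming the standard Trainer arguments and a placeholder LightningModule called MyLitModel:

import lightning.pytorch as pl

trainer = pl.Trainer(
    strategy="fsdp",
    accelerator="gpu",
    devices=2,
    precision="16-mixed",
    gradient_clip_val=1.0,            # clip gradients to a maximum total norm of 1.0
    gradient_clip_algorithm="norm",   # previously rejected under FSDP
)
# trainer.fit(MyLitModel())  # MyLitModel is a hypothetical LightningModule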

File tree

5 files changed: +20 -12 lines changed


src/lightning/pytorch/core/module.py

Lines changed: 3 additions & 1 deletion
@@ -1207,7 +1207,9 @@ def clip_gradients(
         )
 
         gradient_clip_algorithm = GradClipAlgorithmType(gradient_clip_algorithm)
-        self.trainer.precision_plugin.clip_gradients(optimizer, gradient_clip_val, gradient_clip_algorithm)
+        self.trainer.precision_plugin.clip_gradients(
+            self.trainer.model, optimizer, gradient_clip_val, gradient_clip_algorithm
+        )
 
     def configure_gradient_clipping(
         self,
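
The user-facing hook keeps its old signature; only the internal call into the precision plugin gains `self.trainer.model` (the strategy-wrapped module). A sketch of custom clipping in a LightningModule, following the documented hook pattern (the fallback values are illustrative):

import lightning.pytorch as pl


class LitModel(pl.LightningModule):  # hypothetical module for illustration
    def configure_gradient_clipping(self, optimizer, gradient_clip_val=None, gradient_clip_algorithm=None):
        # clip_gradients() now forwards self.trainer.model to the precision plugin internally
        self.clip_gradients(
            optimizer,
            gradient_clip_val=gradient_clip_val or 1.0,
            gradient_clip_algorithm=gradient_clip_algorithm or "norm",
        )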

src/lightning/pytorch/plugins/precision/amp.py

Lines changed: 5 additions & 1 deletion
@@ -15,6 +15,7 @@
 
 import torch
 from torch import Tensor
+from torch.nn import Module
 from torch.optim import LBFGS, Optimizer
 from typing_extensions import override
 
@@ -100,6 +101,7 @@ def optimizer_step( # type: ignore[override]
     @override
     def clip_gradients(
         self,
+        module: Module,
         optimizer: Optimizer,
         clip_val: Union[int, float] = 0.0,
         gradient_clip_algorithm: GradClipAlgorithmType = GradClipAlgorithmType.NORM,
@@ -109,7 +111,9 @@ def clip_gradients(
                 f"The current optimizer, {type(optimizer).__qualname__}, does not allow for gradient clipping"
                 " because it performs unscaling of gradients internally. HINT: Are you using a 'fused' optimizer?"
             )
-        super().clip_gradients(optimizer=optimizer, clip_val=clip_val, gradient_clip_algorithm=gradient_clip_algorithm)
+        super().clip_gradients(
+            module=module, optimizer=optimizer, clip_val=clip_val, gradient_clip_algorithm=gradient_clip_algorithm
+        )
 
     def autocast_context_manager(self) -> torch.autocast:
         return torch.autocast(self.device, dtype=(torch.bfloat16 if self.precision == "bf16-mixed" else torch.half))
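
The AMP change is plumbing: the new `module` argument is threaded through to the base class so that the FSDP override below can use it, and the existing guard against fused optimizers stays in place. As the error message above hints, optimizers that unscale gradients internally still cannot be clipped; a sketch of the configuration that trips the guard (the exact exception type is outside this hunk):

import torch
from torch import nn

# Fused optimizers require CUDA parameters and unscale gradients internally, so with
# precision="16-mixed" and gradient_clip_val > 0 the AMP plugin raises the error quoted above.
model = nn.Linear(4, 4).cuda()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, fused=True)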

src/lightning/pytorch/plugins/precision/fsdp.py

Lines changed: 4 additions & 5 deletions
@@ -12,12 +12,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from contextlib import AbstractContextManager
-from typing import TYPE_CHECKING, Any, Callable, Optional
+from typing import TYPE_CHECKING, Any, Callable, Optional, Union
 
 import torch
 from lightning_utilities import apply_to_collection
 from torch import Tensor
 from torch.nn import Module
+from torch.optim import Optimizer
 from typing_extensions import get_args, override
 
 import lightning.pytorch as pl
@@ -81,11 +82,9 @@ def convert_module(self, module: Module) -> Module:
         return module
 
     @override
-    def clip_grad_by_norm(self, *_: Any, **__: Any) -> None:
+    def clip_grad_by_norm(self, module: Module, optimizer: Optimizer, clip_val: Union[int, float]) -> None:
         # see https://pytorch.org/docs/stable/fsdp.html#torch.distributed.fsdp.FullyShardedDataParallel.clip_grad_norm_
-        raise MisconfigurationException(
-            f"`gradient_clip_algorithm='norm'` is currently not supported for `{self.__class__.__name__}`"
-        )
+        module.clip_grad_norm_(clip_val)
 
     @property
     def mixed_precision_config(self) -> "TorchMixedPrecision":
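
This is the substantive change: instead of rejecting `gradient_clip_algorithm='norm'`, the FSDP precision plugin now delegates to `FullyShardedDataParallel.clip_grad_norm_`, which (per the linked docs) computes the total gradient norm across all sharded parameters and ranks before clipping. Calling plain `torch.nn.utils.clip_grad_norm_` on the local shards would only see a per-rank partial norm. Roughly the equivalent raw-PyTorch pattern, as a sketch with the distributed setup elided:

import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

# Assumes a process group is already initialized and `model` / `batch` exist on this rank.
fsdp_model = FSDP(model)
optimizer = torch.optim.SGD(fsdp_model.parameters(), lr=0.1)

loss = fsdp_model(batch).sum()            # placeholder forward pass and loss
loss.backward()
fsdp_model.clip_grad_norm_(max_norm=1.0)  # collective call: must run on every rank
optimizer.step()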

src/lightning/pytorch/plugins/precision/precision.py

Lines changed: 3 additions & 2 deletions
@@ -143,6 +143,7 @@ def _clip_gradients(
 
     def clip_gradients(
         self,
+        module: Module,
         optimizer: Optimizer,
         clip_val: Union[int, float] = 0.0,
         gradient_clip_algorithm: GradClipAlgorithmType = GradClipAlgorithmType.NORM,
@@ -153,14 +154,14 @@ def clip_gradients(
         if gradient_clip_algorithm == GradClipAlgorithmType.VALUE:
             self.clip_grad_by_value(optimizer, clip_val)
         elif gradient_clip_algorithm == GradClipAlgorithmType.NORM:
-            self.clip_grad_by_norm(optimizer, clip_val)
+            self.clip_grad_by_norm(module, optimizer, clip_val)
 
     def clip_grad_by_value(self, optimizer: Optimizer, clip_val: Union[int, float]) -> None:
         """Clip gradients by value."""
         parameters = self.main_params(optimizer)
         torch.nn.utils.clip_grad_value_(parameters, clip_value=clip_val)
 
-    def clip_grad_by_norm(self, optimizer: Optimizer, clip_val: Union[int, float]) -> None:
+    def clip_grad_by_norm(self, module: Module, optimizer: Optimizer, clip_val: Union[int, float]) -> None:
         """Clip gradients by norm."""
         parameters = self.main_params(optimizer)
         torch.nn.utils.clip_grad_norm_(parameters, clip_val)
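
Because the base `clip_grad_by_norm` signature changed, custom precision plugins that override it need to accept the extra `module` parameter, even if they ignore it. A minimal sketch of a subclass written against the new signature (the class name and logging are hypothetical, and the base-class import path is an assumption based on the file layout above):

from typing import Union

import torch
from torch.nn import Module
from torch.optim import Optimizer

from lightning.pytorch.plugins.precision.precision import Precision  # assumed import path


class NormLoggingPrecision(Precision):  # hypothetical subclass for illustration
    def clip_grad_by_norm(self, module: Module, optimizer: Optimizer, clip_val: Union[int, float]) -> None:
        # Same as the base implementation, but also reports the pre-clip total norm.
        parameters = self.main_params(optimizer)
        total_norm = torch.nn.utils.clip_grad_norm_(parameters, clip_val)
        print(f"gradient norm before clipping: {total_norm:.4f}")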

tests/tests_pytorch/plugins/precision/test_amp.py

Lines changed: 5 additions & 3 deletions
@@ -14,6 +14,7 @@
 from unittest.mock import Mock
 
 import pytest
+from torch.nn import Module
 from torch.optim import Optimizer
 
 from lightning.pytorch.plugins import MixedPrecision
@@ -22,22 +23,23 @@
 
 def test_clip_gradients():
     """Test that `.clip_gradients()` is a no-op when clipping is disabled."""
+    module = Mock(spec=Module)
     optimizer = Mock(spec=Optimizer)
     precision = MixedPrecision(precision="16-mixed", device="cuda:0", scaler=Mock())
     precision.clip_grad_by_value = Mock()
     precision.clip_grad_by_norm = Mock()
-    precision.clip_gradients(optimizer)
+    precision.clip_gradients(module, optimizer)
     precision.clip_grad_by_value.assert_not_called()
     precision.clip_grad_by_norm.assert_not_called()
 
-    precision.clip_gradients(optimizer, clip_val=1.0, gradient_clip_algorithm=GradClipAlgorithmType.VALUE)
+    precision.clip_gradients(module, optimizer, clip_val=1.0, gradient_clip_algorithm=GradClipAlgorithmType.VALUE)
     precision.clip_grad_by_value.assert_called_once()
    precision.clip_grad_by_norm.assert_not_called()
 
     precision.clip_grad_by_value.reset_mock()
     precision.clip_grad_by_norm.reset_mock()
 
-    precision.clip_gradients(optimizer, clip_val=1.0, gradient_clip_algorithm=GradClipAlgorithmType.NORM)
+    precision.clip_gradients(module, optimizer, clip_val=1.0, gradient_clip_algorithm=GradClipAlgorithmType.NORM)
     precision.clip_grad_by_value.assert_not_called()
     precision.clip_grad_by_norm.assert_called_once()
