32 commits
2fc2fb7  Update fsdp.py (amorehead, May 3, 2025)
aa6e482  Merge branch 'Lightning-AI:master' into fsdp-grad-clip-by-norm (amorehead, May 3, 2025)
c36f40c  Support gradient norm clipping for FSDP (amorehead, May 3, 2025)
8fad423  Update CHANGELOG.md (amorehead, May 3, 2025)
04fbaf1  Fix args for certain precisions (amorehead, May 3, 2025)
bce69ca  Standardize precision args (amorehead, May 3, 2025)
0df38f5  Guard for typing (amorehead, May 3, 2025)
a42b974  Fix argument typing (amorehead, May 3, 2025)
ed2fe05  Wrap AMP test module in FSDP (amorehead, May 3, 2025)
2f62a0a  Simplify guard (amorehead, May 3, 2025)
7f7987e  Remove FSDP traces in AMP precision unit test (amorehead, May 3, 2025)
0b9b2a3  Merge branch 'master' into fsdp-grad-clip-by-norm (amorehead, May 10, 2025)
f98ce47  Merge branch 'master' into fsdp-grad-clip-by-norm (Borda, Aug 19, 2025)
5814091  Merge branch 'master' into fsdp-grad-clip-by-norm (amorehead, Aug 19, 2025)
75d6d9f  Merge branch 'master' into fsdp-grad-clip-by-norm (Borda, Sep 3, 2025)
395c7fd  Merge branch 'master' into fsdp-grad-clip-by-norm (Borda, Sep 3, 2025)
6f04f9c  Merge branch 'master' into fsdp-grad-clip-by-norm (amorehead, Sep 3, 2025)
dee2225  Apply suggestions from code review (Borda, Sep 4, 2025)
169e20c  Merge branch 'master' into fsdp-grad-clip-by-norm (amorehead, Sep 4, 2025)
3d80102  Merge branch 'master' into fsdp-grad-clip-by-norm (amorehead, Sep 5, 2025)
de84676  Update module.py (amorehead, Sep 10, 2025)
188ca22  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Sep 10, 2025)
7c829b6  Update amp.py (amorehead, Sep 10, 2025)
161241e  Update deepspeed.py (amorehead, Sep 10, 2025)
eea0a94  Update fsdp.py (amorehead, Sep 10, 2025)
181a355  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Sep 10, 2025)
759c4a2  Update precision.py (amorehead, Sep 10, 2025)
9bc3991  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Sep 10, 2025)
63e9d3a  Update test_amp.py (amorehead, Sep 10, 2025)
46fb1b5  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Sep 10, 2025)
9d9cd47  Merge branch 'master' into fsdp-grad-clip-by-norm (amorehead, Sep 10, 2025)
6b631c3  Merge branch 'master' into fsdp-grad-clip-by-norm (justusschock, Oct 7, 2025)
3 changes: 3 additions & 0 deletions src/lightning/pytorch/CHANGELOG.md
@@ -10,6 +10,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Added

- Support `clip_grad_norm_()` for FSDP ([#20784](https://github.com/Lightning-AI/pytorch-lightning/pull/20784))


- Added `WeightAveraging` callback that wraps the PyTorch `AveragedModel` class ([#20545](https://github.com/Lightning-AI/pytorch-lightning/pull/20545))


7 changes: 6 additions & 1 deletion src/lightning/pytorch/core/module.py
@@ -1273,7 +1273,12 @@ def clip_gradients(
)

gradient_clip_algorithm = GradClipAlgorithmType(gradient_clip_algorithm)
self.trainer.precision_plugin.clip_gradients(optimizer, gradient_clip_val, gradient_clip_algorithm)
self.trainer.precision_plugin.clip_gradients(
optimizer,
gradient_clip_val,
gradient_clip_algorithm,
module=self.trainer.model,
)

def configure_gradient_clipping(
self,
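For context, a minimal sketch (not part of this diff) of the user-facing path the plumbing above enables: with `module=self.trainer.model` forwarded to the precision plugin, norm-based clipping can be requested for an FSDP run directly from the `Trainer`. Previously this combination raised a `MisconfigurationException` (see the removed branch in `fsdp.py` below). The model and dataloader names are hypothetical placeholders.

```python
import lightning.pytorch as pl

trainer = pl.Trainer(
    strategy="fsdp",                 # shard the model with FullyShardedDataParallel
    gradient_clip_val=1.0,           # forwarded through LightningModule.clip_gradients(...)
    gradient_clip_algorithm="norm",  # previously rejected for FSDP, now dispatched to the plugin
)
# trainer.fit(MyLightningModule(), train_dataloaders=my_dataloader)  # hypothetical names
```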
6 changes: 5 additions & 1 deletion src/lightning/pytorch/plugins/precision/amp.py
@@ -15,6 +15,7 @@

import torch
from torch import Tensor
from torch.nn import Module
from torch.optim import LBFGS, Optimizer
from typing_extensions import override

@@ -103,13 +104,16 @@ def clip_gradients(
optimizer: Optimizer,
clip_val: Union[int, float] = 0.0,
gradient_clip_algorithm: GradClipAlgorithmType = GradClipAlgorithmType.NORM,
module: Optional[Module] = None,
) -> None:
if clip_val > 0 and _optimizer_handles_unscaling(optimizer):
raise RuntimeError(
f"The current optimizer, {type(optimizer).__qualname__}, does not allow for gradient clipping"
" because it performs unscaling of gradients internally. HINT: Are you using a 'fused' optimizer?"
)
super().clip_gradients(optimizer=optimizer, clip_val=clip_val, gradient_clip_algorithm=gradient_clip_algorithm)
super().clip_gradients(
optimizer=optimizer, clip_val=clip_val, gradient_clip_algorithm=gradient_clip_algorithm, module=module
)

def autocast_context_manager(self) -> torch.autocast:
dtype = torch.bfloat16 if self.precision == "bf16-mixed" else torch.half
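The guard for optimizers that unscale gradients inside `step()` (for example fused Adam) is unchanged; only the `module` argument is threaded through. A rough illustration of that behaviour, mirroring the unit test further below and assuming the import path of the edited file:

```python
from unittest.mock import Mock

from lightning.pytorch.plugins.precision.amp import MixedPrecision

precision = MixedPrecision(precision="16-mixed", device="cuda:0", scaler=Mock())
optimizer = Mock(_step_supports_amp_scaling=True)  # what a fused optimizer advertises

try:
    precision.clip_gradients(optimizer, clip_val=1.0, module=Mock())
except RuntimeError as err:
    print(err)  # "... does not allow for gradient clipping ... 'fused' optimizer?"
```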
1 change: 1 addition & 0 deletions src/lightning/pytorch/plugins/precision/deepspeed.py
@@ -144,5 +144,6 @@ def clip_gradients(
optimizer: Optimizer,
clip_val: Union[int, float] = 0.0,
gradient_clip_algorithm: GradClipAlgorithmType = GradClipAlgorithmType.NORM,
module: Optional[Module] = None,
) -> None:
"""DeepSpeed handles gradient clipping internally."""
17 changes: 9 additions & 8 deletions src/lightning/pytorch/plugins/precision/fsdp.py
@@ -12,12 +12,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from contextlib import AbstractContextManager
from typing import TYPE_CHECKING, Any, Callable, Optional
from typing import TYPE_CHECKING, Any, Callable, Optional, Union

import torch
from lightning_utilities import apply_to_collection
from torch import Tensor
from torch.nn import Module
from torch.optim import Optimizer
from typing_extensions import get_args, override

import lightning.pytorch as pl
@@ -81,14 +82,14 @@ def convert_module(self, module: Module) -> Module:
return module

@override
def clip_grad_by_norm(self, *_: Any, **__: Any) -> None:
def clip_grad_by_norm(
self, optimizer: Optimizer, clip_val: Union[int, float], module: Optional[Module] = None
) -> None:
# see https://pytorch.org/docs/stable/fsdp.html#torch.distributed.fsdp.FullyShardedDataParallel.clip_grad_norm_
# section `Gradient Clipping`, using `torch.nn.utils.clip_grad_norm_` is incorrect with FSDP.
# To overcome this we need to call root_sharded_module.clip_grad_norm(clip_val), but we don't have a reference
# to the root module
raise MisconfigurationException(
f"`gradient_clip_algorithm='norm'` is currently not supported for `{self.__class__.__name__}`"
)
if module is None:
return
assert hasattr(module, "clip_grad_norm_")  # only the root FSDP-wrapped module exposes this
module.clip_grad_norm_(clip_val)

@property
def mixed_precision_config(self) -> "TorchMixedPrecision":
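For reference, a minimal standalone sketch of the PyTorch mechanism the plugin now delegates to: with FSDP the parameters are sharded across ranks, so `torch.nn.utils.clip_grad_norm_` would compute a per-rank (and therefore wrong) norm, while the root wrapper's `clip_grad_norm_` performs the necessary all-reduce. Assumes a CUDA device and an already initialised process group (e.g. launched via `torchrun`).

```python
import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

model = FSDP(torch.nn.Linear(16, 16).cuda())         # root sharded module
loss = model(torch.randn(4, 16, device="cuda")).sum()
loss.backward()

# Collective call; must be invoked on the root FSDP module, not via torch.nn.utils.
total_norm = model.clip_grad_norm_(max_norm=1.0)
```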
7 changes: 5 additions & 2 deletions src/lightning/pytorch/plugins/precision/precision.py
@@ -146,21 +146,24 @@ def clip_gradients(
optimizer: Optimizer,
clip_val: Union[int, float] = 0.0,
gradient_clip_algorithm: GradClipAlgorithmType = GradClipAlgorithmType.NORM,
module: Optional[Module] = None,
) -> None:
"""Clips the gradients."""
if clip_val <= 0:
return
if gradient_clip_algorithm == GradClipAlgorithmType.VALUE:
self.clip_grad_by_value(optimizer, clip_val)
elif gradient_clip_algorithm == GradClipAlgorithmType.NORM:
self.clip_grad_by_norm(optimizer, clip_val)
self.clip_grad_by_norm(optimizer, clip_val, module=module)

def clip_grad_by_value(self, optimizer: Optimizer, clip_val: Union[int, float]) -> None:
"""Clip gradients by value."""
parameters = self.main_params(optimizer)
torch.nn.utils.clip_grad_value_(parameters, clip_value=clip_val)

def clip_grad_by_norm(self, optimizer: Optimizer, clip_val: Union[int, float]) -> None:
def clip_grad_by_norm(
self, optimizer: Optimizer, clip_val: Union[int, float], module: Optional[Module] = None
) -> None:
"""Clip gradients by norm."""
parameters = self.main_params(optimizer)
torch.nn.utils.clip_grad_norm_(parameters, clip_val)
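Because the base class now threads `module` into `clip_grad_by_norm`, custom precision plugins can use (or ignore) it. A sketch of such an override, assuming the 2.x base class name `Precision` from the edited file; the logging is illustrative, not part of Lightning.

```python
from typing import Optional, Union

import torch
from torch.nn import Module
from torch.optim import Optimizer
from typing_extensions import override

from lightning.pytorch.plugins.precision.precision import Precision


class LoggingPrecision(Precision):
    @override
    def clip_grad_by_norm(
        self, optimizer: Optimizer, clip_val: Union[int, float], module: Optional[Module] = None
    ) -> None:
        # `module` is the strategy-wrapped model; plain (non-sharded) plugins may ignore it.
        total_norm = torch.nn.utils.clip_grad_norm_(self.main_params(optimizer), clip_val)
        print(f"clipped grad norm: {float(total_norm):.3f} (module provided: {module is not None})")
```

Such a plugin would be passed to the trainer via `Trainer(plugins=[LoggingPrecision()])`.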
12 changes: 8 additions & 4 deletions tests/tests_pytorch/plugins/precision/test_amp.py
@@ -24,22 +24,25 @@

def test_clip_gradients():
"""Test that `.clip_gradients()` is a no-op when clipping is disabled."""
module = Mock(spec=nn.Module)
optimizer = Mock(spec=Optimizer)
precision = MixedPrecision(precision="16-mixed", device="cuda:0", scaler=Mock())
precision.clip_grad_by_value = Mock()
precision.clip_grad_by_norm = Mock()
precision.clip_gradients(optimizer)
precision.clip_gradients(optimizer, module=module)
precision.clip_grad_by_value.assert_not_called()
precision.clip_grad_by_norm.assert_not_called()

precision.clip_gradients(optimizer, clip_val=1.0, gradient_clip_algorithm=GradClipAlgorithmType.VALUE)
precision.clip_gradients(
optimizer, clip_val=1.0, gradient_clip_algorithm=GradClipAlgorithmType.VALUE, module=module
)
precision.clip_grad_by_value.assert_called_once()
precision.clip_grad_by_norm.assert_not_called()

precision.clip_grad_by_value.reset_mock()
precision.clip_grad_by_norm.reset_mock()

precision.clip_gradients(optimizer, clip_val=1.0, gradient_clip_algorithm=GradClipAlgorithmType.NORM)
precision.clip_gradients(optimizer, clip_val=1.0, gradient_clip_algorithm=GradClipAlgorithmType.NORM, module=module)
precision.clip_grad_by_value.assert_not_called()
precision.clip_grad_by_norm.assert_called_once()

@@ -48,11 +51,12 @@ def test_optimizer_amp_scaling_support_in_step_method():
"""Test that the plugin checks if the optimizer takes over unscaling in its step, making it incompatible with
gradient clipping (example: fused Adam)."""

module = Mock(spec=nn.Module)
optimizer = Mock(_step_supports_amp_scaling=True)
precision = MixedPrecision(precision="16-mixed", device="cuda:0", scaler=Mock())

with pytest.raises(RuntimeError, match="The current optimizer.*does not allow for gradient clipping"):
precision.clip_gradients(optimizer, clip_val=1.0)
precision.clip_gradients(optimizer, clip_val=1.0, module=module)


def test_amp_with_no_grad():
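A hypothetical companion test (not part of this diff) for the FSDP plugin, assuming the guard in `fsdp.py` only requires the passed module to expose `clip_grad_norm_`: norm clipping should be delegated to the wrapped module.

```python
from unittest.mock import Mock

from torch.optim import Optimizer

from lightning.pytorch.plugins.precision.fsdp import FSDPPrecision


def test_fsdp_clip_grad_by_norm_delegates_to_module():
    precision = FSDPPrecision(precision="bf16-mixed")
    optimizer = Mock(spec=Optimizer)
    module = Mock()  # stands in for the root FSDP-wrapped model
    precision.clip_grad_by_norm(optimizer, 1.0, module=module)
    module.clip_grad_norm_.assert_called_once_with(1.0)
```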