Commits (showing changes from 17 of 32 commits)
2fc2fb7
Update fsdp.py
amorehead May 3, 2025
aa6e482
Merge branch 'Lightning-AI:master' into fsdp-grad-clip-by-norm
amorehead May 3, 2025
c36f40c
Support gradient norm clipping for FSDP
amorehead May 3, 2025
8fad423
Update CHANGELOG.md
amorehead May 3, 2025
04fbaf1
Fix args for certain precisions
amorehead May 3, 2025
bce69ca
Standardize precision args
amorehead May 3, 2025
0df38f5
Guard for typing
amorehead May 3, 2025
a42b974
Fix argument typing
amorehead May 3, 2025
ed2fe05
Wrap AMP test module in FSDP
amorehead May 3, 2025
2f62a0a
Simplify guard
amorehead May 3, 2025
7f7987e
Remove FSDP traces in AMP precision unit test
amorehead May 3, 2025
0b9b2a3
Merge branch 'master' into fsdp-grad-clip-by-norm
amorehead May 10, 2025
f98ce47
Merge branch 'master' into fsdp-grad-clip-by-norm
Borda Aug 19, 2025
5814091
Merge branch 'master' into fsdp-grad-clip-by-norm
amorehead Aug 19, 2025
75d6d9f
Merge branch 'master' into fsdp-grad-clip-by-norm
Borda Sep 3, 2025
395c7fd
Merge branch 'master' into fsdp-grad-clip-by-norm
Borda Sep 3, 2025
6f04f9c
Merge branch 'master' into fsdp-grad-clip-by-norm
amorehead Sep 3, 2025
dee2225
Apply suggestions from code review
Borda Sep 4, 2025
169e20c
Merge branch 'master' into fsdp-grad-clip-by-norm
amorehead Sep 4, 2025
3d80102
Merge branch 'master' into fsdp-grad-clip-by-norm
amorehead Sep 5, 2025
de84676
Update module.py
amorehead Sep 10, 2025
188ca22
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 10, 2025
7c829b6
Update amp.py
amorehead Sep 10, 2025
161241e
Update deepspeed.py
amorehead Sep 10, 2025
eea0a94
Update fsdp.py
amorehead Sep 10, 2025
181a355
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 10, 2025
759c4a2
Update precision.py
amorehead Sep 10, 2025
9bc3991
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 10, 2025
63e9d3a
Update test_amp.py
amorehead Sep 10, 2025
46fb1b5
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 10, 2025
9d9cd47
Merge branch 'master' into fsdp-grad-clip-by-norm
amorehead Sep 10, 2025
6b631c3
Merge branch 'master' into fsdp-grad-clip-by-norm
justusschock Oct 7, 2025
3 changes: 3 additions & 0 deletions src/lightning/pytorch/CHANGELOG.md
@@ -10,6 +10,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Added

- Support `grad_clip_norm_()` for FSDP ([#20784](https://github.com/Lightning-AI/pytorch-lightning/pull/20784))


- Added `WeightAveraging` callback that wraps the PyTorch `AveragedModel` class ([#20545](https://github.com/Lightning-AI/pytorch-lightning/pull/20545))


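For context, a minimal usage sketch (not part of this diff) of what the new changelog entry enables, assuming a multi-GPU environment with the FSDP strategy; the LightningModule and datamodule in the last line are placeholders:

# Hedged sketch: requesting norm-based gradient clipping under FSDP through the Trainer.
# Before this PR, the FSDP precision plugin raised a MisconfigurationException for
# gradient_clip_algorithm="norm"; with this change it delegates to the root module's
# clip_grad_norm_.
import lightning.pytorch as pl

trainer = pl.Trainer(
    strategy="fsdp",                 # shard the model with FullyShardedDataParallel
    accelerator="gpu",
    devices=2,
    gradient_clip_val=1.0,           # maximum total gradient norm
    gradient_clip_algorithm="norm",  # the algorithm this PR enables for FSDP
)
# trainer.fit(MyLightningModule(), datamodule=my_datamodule)  # placeholders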
4 changes: 3 additions & 1 deletion src/lightning/pytorch/core/module.py
@@ -1273,7 +1273,9 @@ def clip_gradients(
)

gradient_clip_algorithm = GradClipAlgorithmType(gradient_clip_algorithm)
self.trainer.precision_plugin.clip_gradients(optimizer, gradient_clip_val, gradient_clip_algorithm)
self.trainer.precision_plugin.clip_gradients(
self.trainer.model, optimizer, gradient_clip_val, gradient_clip_algorithm
)

def configure_gradient_clipping(
self,
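For context, a hedged sketch of the user-facing hook that sits above the changed call: LightningModule.clip_gradients now forwards self.trainer.model (the root, possibly FSDP-wrapped module) to the precision plugin, so an override like the following, which delegates to clip_gradients, picks up the new path. The class name is a placeholder.

import lightning.pytorch as pl


class LitModel(pl.LightningModule):  # hypothetical example module
    def configure_gradient_clipping(self, optimizer, gradient_clip_val=None, gradient_clip_algorithm=None):
        # Any custom logic can go here; delegating to `clip_gradients` routes through the
        # updated `precision_plugin.clip_gradients(self.trainer.model, ...)` call in the diff above.
        self.clip_gradients(
            optimizer,
            gradient_clip_val=gradient_clip_val,
            gradient_clip_algorithm=gradient_clip_algorithm,
        )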
6 changes: 5 additions & 1 deletion src/lightning/pytorch/plugins/precision/amp.py
@@ -15,6 +15,7 @@

import torch
from torch import Tensor
from torch.nn import Module
from torch.optim import LBFGS, Optimizer
from typing_extensions import override

@@ -100,6 +101,7 @@ def optimizer_step( # type: ignore[override]
@override
def clip_gradients(
self,
module: Optional[Module],
Review thread on the new module parameter:

Contributor: This would be a breaking change; it has to go to the end of the arguments.

Contributor Author: Are you referring to how other codebases like Fabric would call clip_gradients? As far as I can see from this PR's unit tests, none of the references in the Lightning codebase are broken by this change. And if you are, for clarification, would module have to be made module: Optional[Module] = None as the last argument in all of the modified functions below?

Contributor: I am saying that if a user is passing positional arguments, this will break for them.

Contributor Author (@amorehead, Sep 10, 2025): Great point. I've just made the new module argument fully optional by listing it as the last optional argument, module: Optional[Module] = None. Let me know if you see anything else that needs to be addressed.

optimizer: Optimizer,
clip_val: Union[int, float] = 0.0,
gradient_clip_algorithm: GradClipAlgorithmType = GradClipAlgorithmType.NORM,
@@ -109,7 +111,9 @@ def clip_gradients(
f"The current optimizer, {type(optimizer).__qualname__}, does not allow for gradient clipping"
" because it performs unscaling of gradients internally. HINT: Are you using a 'fused' optimizer?"
)
super().clip_gradients(optimizer=optimizer, clip_val=clip_val, gradient_clip_algorithm=gradient_clip_algorithm)
super().clip_gradients(
module=module, optimizer=optimizer, clip_val=clip_val, gradient_clip_algorithm=gradient_clip_algorithm
)

def autocast_context_manager(self) -> torch.autocast:
dtype = torch.bfloat16 if self.precision == "bf16-mixed" else torch.half
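A toy illustration (plain Python, not Lightning code) of the compatibility concern raised in the review thread above: inserting the new parameter before the existing ones breaks positional call sites, while appending it at the end with a default preserves them, which is what the author ultimately did.

from typing import Optional


def clip_gradients_old(optimizer, clip_val=0.0):
    return ("old", optimizer, clip_val)


def clip_gradients_module_first(module, optimizer, clip_val=0.0):
    # Breaking: an existing positional call now binds the optimizer to `module`.
    return ("module-first", module, optimizer, clip_val)


def clip_gradients_module_last(optimizer, clip_val=0.0, module: Optional[object] = None):
    # Non-breaking: old positional calls still line up; `module` defaults to None.
    return ("module-last", module, optimizer, clip_val)


print(clip_gradients_old("opt", 1.0))           # original positional call site
print(clip_gradients_module_last("opt", 1.0))   # unchanged call still works
print(clip_gradients_module_first("opt", 1.0))  # arguments silently shift: the breakage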
1 change: 1 addition & 0 deletions src/lightning/pytorch/plugins/precision/deepspeed.py
@@ -141,6 +141,7 @@ def optimizer_step( # type: ignore[override]
@override
def clip_gradients(
self,
module: Optional[Module],
optimizer: Optimizer,
clip_val: Union[int, float] = 0.0,
gradient_clip_algorithm: GradClipAlgorithmType = GradClipAlgorithmType.NORM,
14 changes: 6 additions & 8 deletions src/lightning/pytorch/plugins/precision/fsdp.py
@@ -12,12 +12,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from contextlib import AbstractContextManager
from typing import TYPE_CHECKING, Any, Callable, Optional
from typing import TYPE_CHECKING, Any, Callable, Optional, Union

import torch
from lightning_utilities import apply_to_collection
from torch import Tensor
from torch.nn import Module
from torch.optim import Optimizer
from typing_extensions import get_args, override

import lightning.pytorch as pl
@@ -81,14 +82,11 @@ def convert_module(self, module: Module) -> Module:
return module

@override
def clip_grad_by_norm(self, *_: Any, **__: Any) -> None:
def clip_grad_by_norm(self, module: Optional[Module], optimizer: Optimizer, clip_val: Union[int, float]) -> None:
# see https://pytorch.org/docs/stable/fsdp.html#torch.distributed.fsdp.FullyShardedDataParallel.clip_grad_norm_
# section `Gradient Clipping`, using `torch.nn.utils.clip_grad_norm_` is incorrect with FSDP.
# To overcome this we need to call root_sharded_module.clip_grad_norm(clip_val), but we don't have a reference
# to the root module
raise MisconfigurationException(
f"`gradient_clip_algorithm='norm'` is currently not supported for `{self.__class__.__name__}`"
)
if module is None:
return
module.clip_grad_norm_(clip_val)

@property
def mixed_precision_config(self) -> "TorchMixedPrecision":
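A hedged sketch of the PyTorch behavior the rewritten clip_grad_by_norm relies on, assuming an already-initialized process group and a CUDA device (for example, launched via torchrun): with FSDP, gradients are sharded across ranks, so the root wrapped module's clip_grad_norm_ computes and clips against the global norm, whereas torch.nn.utils.clip_grad_norm_ would only see each rank's local shards.

import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

model = torch.nn.Linear(32, 32).cuda()
sharded_model = FSDP(model)  # root FSDP module, analogous to `self.trainer.model`

loss = sharded_model(torch.randn(4, 32, device="cuda")).sum()
loss.backward()

# The call the updated plugin issues when gradient_clip_algorithm="norm":
total_norm = sharded_model.clip_grad_norm_(max_norm=1.0)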
5 changes: 3 additions & 2 deletions src/lightning/pytorch/plugins/precision/precision.py
@@ -143,6 +143,7 @@ def _clip_gradients(

def clip_gradients(
self,
module: Optional[Module],
optimizer: Optimizer,
clip_val: Union[int, float] = 0.0,
gradient_clip_algorithm: GradClipAlgorithmType = GradClipAlgorithmType.NORM,
@@ -153,14 +154,14 @@
if gradient_clip_algorithm == GradClipAlgorithmType.VALUE:
self.clip_grad_by_value(optimizer, clip_val)
elif gradient_clip_algorithm == GradClipAlgorithmType.NORM:
self.clip_grad_by_norm(optimizer, clip_val)
self.clip_grad_by_norm(module, optimizer, clip_val)

def clip_grad_by_value(self, optimizer: Optimizer, clip_val: Union[int, float]) -> None:
"""Clip gradients by value."""
parameters = self.main_params(optimizer)
torch.nn.utils.clip_grad_value_(parameters, clip_value=clip_val)

def clip_grad_by_norm(self, optimizer: Optimizer, clip_val: Union[int, float]) -> None:
def clip_grad_by_norm(self, module: Optional[Module], optimizer: Optimizer, clip_val: Union[int, float]) -> None:
"""Clip gradients by norm."""
parameters = self.main_params(optimizer)
torch.nn.utils.clip_grad_norm_(parameters, clip_val)
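For reference, a standalone sketch (plain PyTorch, independent of Lightning) of the two clipping modes this dispatcher chooses between; both calls are shown back to back purely for illustration, while in the plugin only one runs per step:

import torch

params = [torch.nn.Parameter(torch.randn(8)) for _ in range(3)]
loss = sum((p * 10.0).sum() for p in params)
loss.backward()

# GradClipAlgorithmType.VALUE -> clip_grad_by_value: clamp each gradient element.
torch.nn.utils.clip_grad_value_(params, clip_value=1.0)

# GradClipAlgorithmType.NORM -> clip_grad_by_norm (non-FSDP path): rescale all gradients
# together so their total norm does not exceed max_norm.
total_norm = torch.nn.utils.clip_grad_norm_(params, max_norm=1.0)
print(f"total gradient norm before rescaling: {total_norm.item():.3f}")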
10 changes: 6 additions & 4 deletions tests/tests_pytorch/plugins/precision/test_amp.py
@@ -24,22 +24,23 @@

def test_clip_gradients():
"""Test that `.clip_gradients()` is a no-op when clipping is disabled."""
module = Mock(spec=nn.Module)
optimizer = Mock(spec=Optimizer)
precision = MixedPrecision(precision="16-mixed", device="cuda:0", scaler=Mock())
precision.clip_grad_by_value = Mock()
precision.clip_grad_by_norm = Mock()
precision.clip_gradients(optimizer)
precision.clip_gradients(module, optimizer)
precision.clip_grad_by_value.assert_not_called()
precision.clip_grad_by_norm.assert_not_called()

precision.clip_gradients(optimizer, clip_val=1.0, gradient_clip_algorithm=GradClipAlgorithmType.VALUE)
precision.clip_gradients(module, optimizer, clip_val=1.0, gradient_clip_algorithm=GradClipAlgorithmType.VALUE)
precision.clip_grad_by_value.assert_called_once()
precision.clip_grad_by_norm.assert_not_called()

precision.clip_grad_by_value.reset_mock()
precision.clip_grad_by_norm.reset_mock()

precision.clip_gradients(optimizer, clip_val=1.0, gradient_clip_algorithm=GradClipAlgorithmType.NORM)
precision.clip_gradients(module, optimizer, clip_val=1.0, gradient_clip_algorithm=GradClipAlgorithmType.NORM)
precision.clip_grad_by_value.assert_not_called()
precision.clip_grad_by_norm.assert_called_once()

@@ -48,11 +49,12 @@ def test_optimizer_amp_scaling_support_in_step_method():
"""Test that the plugin checks if the optimizer takes over unscaling in its step, making it incompatible with
gradient clipping (example: fused Adam)."""

module = Mock(spec=nn.Module)
optimizer = Mock(_step_supports_amp_scaling=True)
precision = MixedPrecision(precision="16-mixed", device="cuda:0", scaler=Mock())

with pytest.raises(RuntimeError, match="The current optimizer.*does not allow for gradient clipping"):
precision.clip_gradients(optimizer, clip_val=1.0)
precision.clip_gradients(module, optimizer, clip_val=1.0)


def test_amp_with_no_grad():