
Commit c656307

Handle set_to_none when using DeepSpeed optimizer in Lite (#16275)
1 parent b195b7c commit c656307

3 files changed: +50 −1 lines changed

src/lightning_fabric/CHANGELOG.md

Lines changed: 3 additions & 0 deletions

@@ -31,6 +31,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Added support for managing callbacks via `Fabric(callbacks=...)` and emitting events through `Fabric.call()` ([#16074](https://github.com/Lightning-AI/lightning/issues/16074))
 
 
+- Added support for a consistent `.zero_grad(set_to_none=...)` on the wrapped optimizer regardless of which strategy is used ([#16275](https://github.com/Lightning-AI/lightning/issues/16275))
+
+
 ### Changed
 
 - Renamed the class `LightningLite` to `Fabric` ([#15932](https://github.com/Lightning-AI/lightning/issues/15932), [#15938](https://github.com/Lightning-AI/lightning/issues/15938))
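The changelog entry above summarizes the user-facing behavior: after `Fabric.setup()`, calling `.zero_grad(set_to_none=...)` on the returned optimizer works the same way regardless of which strategy backs it. A minimal usage sketch, assuming the `lightning_fabric` package from this repository; the model and optimizer below are illustrative and not part of the commit:

```python
import torch
from lightning_fabric import Fabric

fabric = Fabric(accelerator="cpu")  # the point of this change is that strategy="deepspeed" behaves the same
model = torch.nn.Linear(4, 4)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
model, optimizer = fabric.setup(model, optimizer)

output = model(torch.randn(2, 4))
fabric.backward(output.sum())
optimizer.step()
# PyTorch's `set_to_none` keyword is accepted by the wrapped optimizer even when the
# underlying implementation (e.g. DeepSpeed's DeepSpeedZeroOptimizer) names it differently.
optimizer.zero_grad(set_to_none=True)
```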

src/lightning_fabric/wrappers.py

Lines changed: 15 additions & 1 deletion

@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import inspect
 from typing import Any, Callable, Dict, Generator, Iterator, Mapping, Optional, overload, TypeVar, Union
 
 import torch
@@ -44,7 +45,9 @@ def __init__(self, optimizer: Optimizer, strategy: Strategy) -> None:
         """
         # `__del__` is skipped in case the optimizer has implemented custom destructor logic which we would
         # not want to call on destruction of the `_FabricOptimizer
-        self.__dict__ = {k: v for k, v in optimizer.__dict__.items() if k not in ("state_dict", "step", "__del__")}
+        self.__dict__ = {
+            k: v for k, v in optimizer.__dict__.items() if k not in ("state_dict", "step", "zero_grad", "__del__")
+        }
         self.__class__ = type("Fabric" + optimizer.__class__.__name__, (self.__class__, optimizer.__class__), {})
         self._optimizer = optimizer
         self._strategy = strategy
@@ -68,6 +71,10 @@ def step(self, closure: Optional[Callable] = None) -> Any:
             **kwargs,
         )
 
+    def zero_grad(self, **kwargs: Any) -> None:
+        kwargs = _process_optimizer_zero_grad_kwargs(self.optimizer, kwargs)
+        self.optimizer.zero_grad(**kwargs)
+
 
 class _FabricModule(_DeviceDtypeModuleMixin):
     def __init__(
@@ -175,3 +182,10 @@ def __iter__(self) -> Union[Iterator[Any], Generator[Any, None, None]]:
 
         for item in iterator:
             yield move_data_to_device(item, self._device)
+
+
+def _process_optimizer_zero_grad_kwargs(optimizer: Optimizer, kwargs: Dict[str, Any]) -> Dict[str, Any]:
+    if "set_to_none" in kwargs and "set_grads_to_None" in inspect.signature(optimizer.zero_grad).parameters:
+        # Some optimizers out there, for example DeepSpeedZeroOptimizer, use a different name than PyTorch
+        kwargs["set_grads_to_None"] = kwargs.pop("set_to_none")
+    return kwargs
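The heart of the change is `_process_optimizer_zero_grad_kwargs`: it inspects the wrapped optimizer's `zero_grad` signature and renames PyTorch's `set_to_none` to DeepSpeed's `set_grads_to_None` when that is the spelling the optimizer expects. A standalone sketch of the same technique, with a made-up `LegacyOptimizer` standing in for something like `DeepSpeedZeroOptimizer`:

```python
import inspect
from typing import Any, Dict


class LegacyOptimizer:
    """Illustrative stand-in for an optimizer whose ``zero_grad`` uses a non-standard keyword."""

    def zero_grad(self, set_grads_to_None: bool = False) -> None:
        print(f"zero_grad(set_grads_to_None={set_grads_to_None})")


def translate_zero_grad_kwargs(optimizer: Any, kwargs: Dict[str, Any]) -> Dict[str, Any]:
    # Rename PyTorch's `set_to_none` if the target signature only understands `set_grads_to_None`.
    params = inspect.signature(optimizer.zero_grad).parameters
    if "set_to_none" in kwargs and "set_grads_to_None" in params:
        kwargs["set_grads_to_None"] = kwargs.pop("set_to_none")
    return kwargs


optimizer = LegacyOptimizer()
optimizer.zero_grad(**translate_zero_grad_kwargs(optimizer, {"set_to_none": True}))
# prints: zero_grad(set_grads_to_None=True)
```

Because the check is based on the callable's signature rather than on the optimizer's type, any optimizer that uses the alternative spelling is handled without importing DeepSpeed.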

tests/tests_fabric/test_wrappers.py

Lines changed: 32 additions & 0 deletions

@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from unittest import mock
 from unittest.mock import call, Mock
 
 import pytest
@@ -291,3 +292,34 @@ def test_lite_optimizer_steps():
     lite_optimizer = _FabricOptimizer(optimizer=optimizer, strategy=strategy)
     lite_optimizer.step()
     strategy.optimizer_step.assert_called_once_with(strategy.model)
+
+
+def test_fabric_optimizer_zero_grad_kwargs():
+    """Test that Fabric can adapt the `.zero_grad()` arguments to the underlying optimizer."""
+
+    # Test PyTorch's standard `.zero_grad()` signature
+    with mock.patch("torch.optim.SGD.zero_grad") as zero_grad_mock:
+        optimizer = torch.optim.SGD(torch.nn.Linear(1, 1).parameters(), 0.1)
+        fabric_optimizer = _FabricOptimizer(optimizer=optimizer, strategy=Mock())
+        fabric_optimizer.zero_grad()
+        zero_grad_mock.assert_called_with()
+        fabric_optimizer.zero_grad(set_to_none=False)
+        zero_grad_mock.assert_called_with(set_to_none=False)
+        fabric_optimizer.zero_grad(set_to_none=True)
+        zero_grad_mock.assert_called_with(set_to_none=True)
+
+    # Test weird `.zero_grad()` signatures from other libraries
+    custom_zero_grad = Mock()
+
+    class CustomSGD(torch.optim.SGD):
+        def zero_grad(self, set_grads_to_None=False):
+            custom_zero_grad(set_grads_to_None=set_grads_to_None)
+
+    optimizer = CustomSGD(torch.nn.Linear(1, 1).parameters(), 0.1)
+    fabric_optimizer = _FabricOptimizer(optimizer=optimizer, strategy=Mock())
+    fabric_optimizer.zero_grad()
+    custom_zero_grad.assert_called_with(set_grads_to_None=False)
+    fabric_optimizer.zero_grad(set_to_none=False)
+    custom_zero_grad.assert_called_with(set_grads_to_None=False)
+    fabric_optimizer.zero_grad(set_to_none=True)
+    custom_zero_grad.assert_called_with(set_grads_to_None=True)
