Skip to content

Commit 1ddd690

Browse files
lightningforever and pre-commit-ci[bot]
authored and committed
Add Fabric internal hooks (#17759)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 5491e15 commit 1ddd690

File tree

3 files changed

+42
-5
lines changed

3 files changed

+42
-5
lines changed

src/lightning/fabric/fabric.py

Lines changed: 10 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -212,14 +212,19 @@ def setup(
212212
# Update the _DeviceDtypeModuleMixin's device parameter
213213
module.to(self.device if move_to_device else next(module.parameters(), torch.tensor(0)).device)
214214

215-
optimizers = [_FabricOptimizer(optimizer=optimizer, strategy=self._strategy) for optimizer in optimizers]
215+
optimizers = [
216+
_FabricOptimizer(optimizer=optimizer, strategy=self._strategy, callbacks=self._callbacks)
217+
for optimizer in optimizers
218+
]
216219

217220
self._models_setup += 1
218221

219222
if hasattr(original_module, "_fabric"): # this is probably a LightningModule
220223
original_module._fabric = self # type: ignore[assignment]
221224
original_module._fabric_optimizers = optimizers # type: ignore[assignment]
222225

226+
self.call("on_after_setup", fabric=self, module=module)
227+
223228
if optimizers:
224229
# join both types in a tuple for API convenience
225230
return (module, *optimizers)
@@ -276,7 +281,10 @@ def setup_optimizers(self, *optimizers: Optimizer) -> Union[_FabricOptimizer, Tu
276281
"""
277282
self._validate_setup_optimizers(optimizers)
278283
optimizers = [self._strategy.setup_optimizer(optimizer) for optimizer in optimizers]
279-
optimizers = [_FabricOptimizer(optimizer=optimizer, strategy=self._strategy) for optimizer in optimizers]
284+
optimizers = [
285+
_FabricOptimizer(optimizer=optimizer, strategy=self._strategy, callbacks=self._callbacks)
286+
for optimizer in optimizers
287+
]
280288
return optimizers[0] if len(optimizers) == 1 else tuple(optimizers)
281289

282290
def setup_dataloaders(

src/lightning/fabric/wrappers.py

Lines changed: 9 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -12,7 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414
import inspect
15-
from typing import Any, Callable, Dict, Generator, Iterator, Mapping, Optional, overload, TypeVar, Union
15+
from typing import Any, Callable, Dict, Generator, Iterator, List, Mapping, Optional, overload, TypeVar, Union
1616

1717
import torch
1818
from lightning_utilities import WarningCache
@@ -38,7 +38,7 @@
3838

3939

4040
class _FabricOptimizer:
41-
def __init__(self, optimizer: Optimizer, strategy: Strategy) -> None:
41+
def __init__(self, optimizer: Optimizer, strategy: Strategy, callbacks: Optional[List[Callable]] = None) -> None:
4242
"""FabricOptimizer is a thin wrapper around the :class:`~torch.optim.Optimizer` that delegates the
4343
optimizer step calls to the strategy plugin.
4444
@@ -54,6 +54,7 @@ def __init__(self, optimizer: Optimizer, strategy: Strategy) -> None:
5454
self.__class__ = type("Fabric" + optimizer.__class__.__name__, (self.__class__, optimizer.__class__), {})
5555
self._optimizer = optimizer
5656
self._strategy = strategy
57+
self._callbacks = callbacks or []
5758

5859
@property
5960
def optimizer(self) -> Optimizer:
@@ -69,10 +70,15 @@ def step(self, closure: Optional[Callable] = None) -> Any:
6970
optimizer = self._strategy.model
7071
else:
7172
optimizer = self.optimizer
72-
return self._strategy.optimizer_step(
73+
output = self._strategy.optimizer_step(
7374
optimizer,
7475
**kwargs,
7576
)
77+
for callback in self._callbacks:
78+
hook = getattr(callback, "on_after_optimizer_step", None)
79+
if callable(hook):
80+
hook(strategy=self._strategy, optimizer=optimizer)
81+
return output
7682

7783

7884
class _FabricModule(_DeviceDtypeModuleMixin):

tests/tests_fabric/test_fabric.py

Lines changed: 23 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -753,6 +753,29 @@ def test_call():
753753
assert not callback1.mock_calls
754754

755755

756+
def test_special_callbacks():
757+
"""Tests special callbacks that have hooks for internal Fabric events."""
758+
759+
class SpecialCallback:
760+
def on_after_optimizer_step(self, strategy, optimizer):
761+
pass
762+
763+
def on_after_setup(self, fabric, module):
764+
pass
765+
766+
callback = Mock(wraps=SpecialCallback())
767+
fabric = Fabric(accelerator="cpu", callbacks=[callback])
768+
769+
model = torch.nn.Linear(2, 2)
770+
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
771+
fabric_model, fabric_optimizer = fabric.setup(model, optimizer)
772+
callback.on_after_setup.assert_called_once_with(fabric=fabric, module=fabric_model)
773+
774+
model(torch.randn(2, 2)).sum().backward()
775+
fabric_optimizer.step()
776+
callback.on_after_optimizer_step.assert_called_once_with(strategy=fabric._strategy, optimizer=optimizer)
777+
778+
756779
def test_loggers_input():
757780
"""Test the various ways in which loggers can be registered with Fabric."""
758781
logger0 = Mock()

0 commit comments

Comments (0)