Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 24 additions & 0 deletions docs/source-fabric/guide/callbacks.rst
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,30 @@ The :meth:`~lightning.fabric.fabric.Fabric.call` calls the callback objects in t
Not all objects registered via ``Fabric(callbacks=...)`` must implement a method with the given name.
The ones that have a matching method name will get called.

The different callbacks can have different method signatures. Fabric automatically filters keyword arguments based on
each callback's function signature, allowing callbacks with different signatures to work together seamlessly.

.. code-block:: python

class TrainingMetricsCallback:
def on_train_epoch_end(self, train_loss):
print(f"Training loss: {train_loss:.4f}")

class ValidationMetricsCallback:
def on_train_epoch_end(self, val_accuracy):
print(f"Validation accuracy: {val_accuracy:.4f}")

class ComprehensiveCallback:
def on_train_epoch_end(self, epoch, **kwargs):
print(f"Epoch {epoch} complete with metrics: {kwargs}")

fabric = Fabric(
callbacks=[TrainingMetricsCallback(), ValidationMetricsCallback(), ComprehensiveCallback()]
)

# Each callback receives only the arguments it can handle
fabric.call("on_train_epoch_end", epoch=5, train_loss=0.1, val_accuracy=0.95, learning_rate=0.001)


----

Expand Down
2 changes: 1 addition & 1 deletion src/lightning/fabric/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Added

-
- Added kwargs-filtering for `Fabric.call` to support different callback method signatures ([#21258](https://github.com/Lightning-AI/pytorch-lightning/pull/21258))


### Removed
Expand Down
44 changes: 36 additions & 8 deletions src/lightning/fabric/fabric.py
Original file line number Diff line number Diff line change
Expand Up @@ -985,6 +985,37 @@ def train_function(fabric):
)
return self._wrap_and_launch(function, self, *args, **kwargs)

def _filter_kwargs_for_callback(self, method: Callable, kwargs: dict[str, Any]) -> dict[str, Any]:
"""Filter keyword arguments to only include those that match the callback method's signature.

Args:
method: The callback method to inspect
kwargs: The keyword arguments to filter

Returns:
A filtered dictionary of keyword arguments that match the method's signature

"""
try:
sig = inspect.signature(method)
except (ValueError, TypeError):
# If we can't inspect the signature, pass all kwargs to maintain backward compatibility
return kwargs

filtered_kwargs = {}
for name, param in sig.parameters.items():
# Skip 'self' parameter for instance methods
if name == "self":
continue
# If the method accepts **kwargs, pass all original kwargs directly
if param.kind == inspect.Parameter.VAR_KEYWORD:
return kwargs
# If the parameter exists in the incoming kwargs, add it to filtered_kwargs
if name in kwargs:
filtered_kwargs[name] = kwargs[name]

return filtered_kwargs

def call(self, hook_name: str, *args: Any, **kwargs: Any) -> None:
r"""Trigger the callback methods with the given name and arguments.

Expand All @@ -994,7 +1025,9 @@ def call(self, hook_name: str, *args: Any, **kwargs: Any) -> None:
Args:
hook_name: The name of the callback method.
*args: Optional positional arguments that get passed down to the callback method.
**kwargs: Optional keyword arguments that get passed down to the callback method.
**kwargs: Optional keyword arguments that get passed down to the callback method. Keyword arguments
that are not present in the callback's signature will be filtered out automatically, allowing
callbacks to have different signatures for the same hook.

Example::

Expand All @@ -1016,13 +1049,8 @@ def on_train_epoch_end(self, results):
)
continue

method(*args, **kwargs)

# TODO(fabric): handle the following signatures
# method(self, fabric|trainer, x, y=1)
# method(self, fabric|trainer, *args, x, y=1)
# method(self, *args, y=1)
# method(self, *args, **kwargs)
filtered_kwargs = self._filter_kwargs_for_callback(method, kwargs)
method(*args, **filtered_kwargs)

def log(self, name: str, value: Any, step: Optional[int] = None) -> None:
"""Log a scalar to all loggers that were added to Fabric.
Expand Down
55 changes: 54 additions & 1 deletion tests/tests_fabric/test_fabric.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
import os
import warnings
from contextlib import nullcontext
Expand All @@ -20,7 +21,6 @@

import pytest
import torch
import torch.distributed
import torch.nn.functional
from lightning_utilities.test.warning import no_warning_call
from torch import nn
Expand Down Expand Up @@ -1294,3 +1294,56 @@ def test_verify_launch_called():
fabric.launch()
assert fabric._launched
fabric._validate_launched()


def test_callback_kwargs_filtering():
"""Test that callbacks receive only the kwargs they can handle based on their signature."""

class CallbackWithLimitedKwargs:
def on_train_epoch_end(self, epoch: int):
self.epoch = epoch

class CallbackWithVarKeywords:
def on_train_epoch_end(self, epoch: int, **kwargs):
self.epoch = epoch
self.kwargs = kwargs

class CallbackWithNoParams:
def on_train_epoch_end(self):
self.called = True

callback1 = CallbackWithLimitedKwargs()
callback2 = CallbackWithVarKeywords()
callback3 = CallbackWithNoParams()
fabric = Fabric(callbacks=[callback1, callback2, callback3])
fabric.call("on_train_epoch_end", epoch=5, loss=0.1, metrics={"acc": 0.9})

assert callback1.epoch == 5
assert not hasattr(callback1, "loss")
assert callback2.epoch == 5
assert callback2.kwargs == {"loss": 0.1, "metrics": {"acc": 0.9}}
assert callback3.called is True


def test_callback_kwargs_filtering_signature_inspection_failure():
"""Test behavior when signature inspection fails - should fallback to passing all kwargs."""
callback = Mock()
fabric = Fabric(callbacks=[callback])
original_signature = inspect.signature

def mock_signature(obj):
if hasattr(obj, "_mock_name") or hasattr(obj, "_mock_new_name"):
raise ValueError("Cannot inspect mock signature")
return original_signature(obj)

# Temporarily replace signature function in fabric module
import lightning.fabric.fabric

lightning.fabric.fabric.inspect.signature = mock_signature

try:
# Should still work by passing all kwargs when signature inspection fails
fabric.call("on_test_hook", arg1="value1", arg2="value2")
callback.on_test_hook.assert_called_with(arg1="value1", arg2="value2")
finally:
lightning.fabric.fabric.inspect.signature = original_signature
Loading