diff --git a/docs/source-fabric/guide/callbacks.rst b/docs/source-fabric/guide/callbacks.rst index 2a9a0e482d3ea..079ec5f27a6c2 100644 --- a/docs/source-fabric/guide/callbacks.rst +++ b/docs/source-fabric/guide/callbacks.rst @@ -83,6 +83,30 @@ The :meth:`~lightning.fabric.fabric.Fabric.call` calls the callback objects in t Not all objects registered via ``Fabric(callbacks=...)`` must implement a method with the given name. The ones that have a matching method name will get called. +The different callbacks can have different method signatures. Fabric automatically filters keyword arguments based on +each callback's function signature, allowing callbacks with different signatures to work together seamlessly. + +.. code-block:: python + + class TrainingMetricsCallback: + def on_train_epoch_end(self, train_loss): + print(f"Training loss: {train_loss:.4f}") + + class ValidationMetricsCallback: + def on_train_epoch_end(self, val_accuracy): + print(f"Validation accuracy: {val_accuracy:.4f}") + + class ComprehensiveCallback: + def on_train_epoch_end(self, epoch, **kwargs): + print(f"Epoch {epoch} complete with metrics: {kwargs}") + + fabric = Fabric( + callbacks=[TrainingMetricsCallback(), ValidationMetricsCallback(), ComprehensiveCallback()] + ) + + # Each callback receives only the arguments it can handle + fabric.call("on_train_epoch_end", epoch=5, train_loss=0.1, val_accuracy=0.95, learning_rate=0.001) + ---- diff --git a/src/lightning/fabric/CHANGELOG.md b/src/lightning/fabric/CHANGELOG.md index 0e1cc944a3492..d22a72676bc1d 100644 --- a/src/lightning/fabric/CHANGELOG.md +++ b/src/lightning/fabric/CHANGELOG.md @@ -9,7 +9,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Added -- +- Added kwargs-filtering for `Fabric.call` to support different callback method signatures ([#21258](https://github.com/Lightning-AI/pytorch-lightning/pull/21258)) ### Removed diff --git a/src/lightning/fabric/fabric.py b/src/lightning/fabric/fabric.py index 288c355a4ebf2..7306cabc9472e 100644 --- a/src/lightning/fabric/fabric.py +++ b/src/lightning/fabric/fabric.py @@ -985,6 +985,37 @@ def train_function(fabric): ) return self._wrap_and_launch(function, self, *args, **kwargs) + def _filter_kwargs_for_callback(self, method: Callable, kwargs: dict[str, Any]) -> dict[str, Any]: + """Filter keyword arguments to only include those that match the callback method's signature. + + Args: + method: The callback method to inspect + kwargs: The keyword arguments to filter + + Returns: + A filtered dictionary of keyword arguments that match the method's signature + + """ + try: + sig = inspect.signature(method) + except (ValueError, TypeError): + # If we can't inspect the signature, pass all kwargs to maintain backward compatibility + return kwargs + + filtered_kwargs = {} + for name, param in sig.parameters.items(): + # Skip 'self' parameter for instance methods + if name == "self": + continue + # If the method accepts **kwargs, pass all original kwargs directly + if param.kind == inspect.Parameter.VAR_KEYWORD: + return kwargs + # If the parameter exists in the incoming kwargs, add it to filtered_kwargs + if name in kwargs: + filtered_kwargs[name] = kwargs[name] + + return filtered_kwargs + def call(self, hook_name: str, *args: Any, **kwargs: Any) -> None: r"""Trigger the callback methods with the given name and arguments. @@ -994,7 +1025,9 @@ def call(self, hook_name: str, *args: Any, **kwargs: Any) -> None: Args: hook_name: The name of the callback method. *args: Optional positional arguments that get passed down to the callback method. - **kwargs: Optional keyword arguments that get passed down to the callback method. + **kwargs: Optional keyword arguments that get passed down to the callback method. Keyword arguments + that are not present in the callback's signature will be filtered out automatically, allowing + callbacks to have different signatures for the same hook. Example:: @@ -1016,13 +1049,8 @@ def on_train_epoch_end(self, results): ) continue - method(*args, **kwargs) - - # TODO(fabric): handle the following signatures - # method(self, fabric|trainer, x, y=1) - # method(self, fabric|trainer, *args, x, y=1) - # method(self, *args, y=1) - # method(self, *args, **kwargs) + filtered_kwargs = self._filter_kwargs_for_callback(method, kwargs) + method(*args, **filtered_kwargs) def log(self, name: str, value: Any, step: Optional[int] = None) -> None: """Log a scalar to all loggers that were added to Fabric. diff --git a/tests/tests_fabric/test_fabric.py b/tests/tests_fabric/test_fabric.py index dc0203dc067e3..8bfc5002de6a4 100644 --- a/tests/tests_fabric/test_fabric.py +++ b/tests/tests_fabric/test_fabric.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import inspect import os import warnings from contextlib import nullcontext @@ -20,7 +21,6 @@ import pytest import torch -import torch.distributed import torch.nn.functional from lightning_utilities.test.warning import no_warning_call from torch import nn @@ -1294,3 +1294,56 @@ def test_verify_launch_called(): fabric.launch() assert fabric._launched fabric._validate_launched() + + +def test_callback_kwargs_filtering(): + """Test that callbacks receive only the kwargs they can handle based on their signature.""" + + class CallbackWithLimitedKwargs: + def on_train_epoch_end(self, epoch: int): + self.epoch = epoch + + class CallbackWithVarKeywords: + def on_train_epoch_end(self, epoch: int, **kwargs): + self.epoch = epoch + self.kwargs = kwargs + + class CallbackWithNoParams: + def on_train_epoch_end(self): + self.called = True + + callback1 = CallbackWithLimitedKwargs() + callback2 = CallbackWithVarKeywords() + callback3 = CallbackWithNoParams() + fabric = Fabric(callbacks=[callback1, callback2, callback3]) + fabric.call("on_train_epoch_end", epoch=5, loss=0.1, metrics={"acc": 0.9}) + + assert callback1.epoch == 5 + assert not hasattr(callback1, "loss") + assert callback2.epoch == 5 + assert callback2.kwargs == {"loss": 0.1, "metrics": {"acc": 0.9}} + assert callback3.called is True + + +def test_callback_kwargs_filtering_signature_inspection_failure(): + """Test behavior when signature inspection fails - should fallback to passing all kwargs.""" + callback = Mock() + fabric = Fabric(callbacks=[callback]) + original_signature = inspect.signature + + def mock_signature(obj): + if hasattr(obj, "_mock_name") or hasattr(obj, "_mock_new_name"): + raise ValueError("Cannot inspect mock signature") + return original_signature(obj) + + # Temporarily replace signature function in fabric module + import lightning.fabric.fabric + + lightning.fabric.fabric.inspect.signature = mock_signature + + try: + # Should still work by passing all kwargs when signature inspection fails + fabric.call("on_test_hook", arg1="value1", arg2="value2") + callback.on_test_hook.assert_called_with(arg1="value1", arg2="value2") + finally: + lightning.fabric.fabric.inspect.signature = original_signature