Skip to content

Commit dbc6e96

Browse files
committed
add filtering of kwargs
1 parent 56908b3 commit dbc6e96

File tree

1 file changed

+36
-8
lines changed

1 file changed

+36
-8
lines changed

src/lightning/fabric/fabric.py

Lines changed: 36 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -985,6 +985,37 @@ def train_function(fabric):
985985
)
986986
return self._wrap_and_launch(function, self, *args, **kwargs)
987987

988+
def _filter_kwargs_for_callback(self, method: Callable, kwargs: dict[str, Any]) -> dict[str, Any]:
989+
"""Filter keyword arguments to only include those that match the callback method's signature.
990+
991+
Args:
992+
method: The callback method to inspect
993+
kwargs: The keyword arguments to filter
994+
995+
Returns:
996+
A filtered dictionary of keyword arguments that match the method's signature
997+
998+
"""
999+
try:
1000+
sig = inspect.signature(method)
1001+
except (ValueError, TypeError):
1002+
# If we can't inspect the signature, pass all kwargs to maintain backward compatibility
1003+
return kwargs
1004+
1005+
filtered_kwargs = {}
1006+
for name, param in sig.parameters.items():
1007+
# Skip 'self' parameter for instance methods
1008+
if name == "self":
1009+
continue
1010+
# If the method accepts **kwargs, pass all original kwargs directly
1011+
if param.kind == inspect.Parameter.VAR_KEYWORD:
1012+
return kwargs
1013+
# If the parameter exists in the incoming kwargs, add it to filtered_kwargs
1014+
if name in kwargs:
1015+
filtered_kwargs[name] = kwargs[name]
1016+
1017+
return filtered_kwargs
1018+
9881019
def call(self, hook_name: str, *args: Any, **kwargs: Any) -> None:
9891020
r"""Trigger the callback methods with the given name and arguments.
9901021
@@ -994,7 +1025,9 @@ def call(self, hook_name: str, *args: Any, **kwargs: Any) -> None:
9941025
Args:
9951026
hook_name: The name of the callback method.
9961027
*args: Optional positional arguments that get passed down to the callback method.
997-
**kwargs: Optional keyword arguments that get passed down to the callback method.
1028+
**kwargs: Optional keyword arguments that get passed down to the callback method. Keyword arguments
1029+
that are not present in the callback's signature will be filtered out automatically, allowing
1030+
callbacks to have different signatures for the same hook.
9981031
9991032
Example::
10001033
@@ -1016,13 +1049,8 @@ def on_train_epoch_end(self, results):
10161049
)
10171050
continue
10181051

1019-
method(*args, **kwargs)
1020-
1021-
# TODO(fabric): handle the following signatures
1022-
# method(self, fabric|trainer, x, y=1)
1023-
# method(self, fabric|trainer, *args, x, y=1)
1024-
# method(self, *args, y=1)
1025-
# method(self, *args, **kwargs)
1052+
filtered_kwargs = self._filter_kwargs_for_callback(method, kwargs)
1053+
method(*args, **filtered_kwargs)
10261054

10271055
def log(self, name: str, value: Any, step: Optional[int] = None) -> None:
10281056
"""Log a scalar to all loggers that were added to Fabric.

0 commit comments

Comments
 (0)