@@ -985,6 +985,37 @@ def train_function(fabric):
985985 )
986986 return self ._wrap_and_launch (function , self , * args , ** kwargs )
987987
988+ def _filter_kwargs_for_callback (self , method : Callable , kwargs : dict [str , Any ]) -> dict [str , Any ]:
989+ """Filter keyword arguments to only include those that match the callback method's signature.
990+
991+ Args:
992+ method: The callback method to inspect
993+ kwargs: The keyword arguments to filter
994+
995+ Returns:
996+ A filtered dictionary of keyword arguments that match the method's signature
997+
998+ """
999+ try :
1000+ sig = inspect .signature (method )
1001+ except (ValueError , TypeError ):
1002+ # If we can't inspect the signature, pass all kwargs to maintain backward compatibility
1003+ return kwargs
1004+
1005+ filtered_kwargs = {}
1006+ for name , param in sig .parameters .items ():
1007+ # Skip 'self' parameter for instance methods
1008+ if name == "self" :
1009+ continue
1010+ # If the method accepts **kwargs, pass all original kwargs directly
1011+ if param .kind == inspect .Parameter .VAR_KEYWORD :
1012+ return kwargs
1013+ # If the parameter exists in the incoming kwargs, add it to filtered_kwargs
1014+ if name in kwargs :
1015+ filtered_kwargs [name ] = kwargs [name ]
1016+
1017+ return filtered_kwargs
1018+
9881019 def call (self , hook_name : str , * args : Any , ** kwargs : Any ) -> None :
9891020 r"""Trigger the callback methods with the given name and arguments.
9901021
@@ -994,7 +1025,9 @@ def call(self, hook_name: str, *args: Any, **kwargs: Any) -> None:
9941025 Args:
9951026 hook_name: The name of the callback method.
9961027 *args: Optional positional arguments that get passed down to the callback method.
997- **kwargs: Optional keyword arguments that get passed down to the callback method.
1028+ **kwargs: Optional keyword arguments that get passed down to the callback method. Keyword arguments
1029+ that are not present in the callback's signature will be filtered out automatically, allowing
1030+ callbacks to have different signatures for the same hook.
9981031
9991032 Example::
10001033
@@ -1016,13 +1049,8 @@ def on_train_epoch_end(self, results):
10161049 )
10171050 continue
10181051
1019- method (* args , ** kwargs )
1020-
1021- # TODO(fabric): handle the following signatures
1022- # method(self, fabric|trainer, x, y=1)
1023- # method(self, fabric|trainer, *args, x, y=1)
1024- # method(self, *args, y=1)
1025- # method(self, *args, **kwargs)
1052+ filtered_kwargs = self ._filter_kwargs_for_callback (method , kwargs )
1053+ method (* args , ** filtered_kwargs )
10261054
10271055 def log (self , name : str , value : Any , step : Optional [int ] = None ) -> None :
10281056 """Log a scalar to all loggers that were added to Fabric.
0 commit comments