@@ -985,6 +985,37 @@ def train_function(fabric):
985
985
)
986
986
return self ._wrap_and_launch (function , self , * args , ** kwargs )
987
987
988
+ def _filter_kwargs_for_callback (self , method : Callable , kwargs : dict [str , Any ]) -> dict [str , Any ]:
989
+ """Filter keyword arguments to only include those that match the callback method's signature.
990
+
991
+ Args:
992
+ method: The callback method to inspect
993
+ kwargs: The keyword arguments to filter
994
+
995
+ Returns:
996
+ A filtered dictionary of keyword arguments that match the method's signature
997
+
998
+ """
999
+ try :
1000
+ sig = inspect .signature (method )
1001
+ except (ValueError , TypeError ):
1002
+ # If we can't inspect the signature, pass all kwargs to maintain backward compatibility
1003
+ return kwargs
1004
+
1005
+ filtered_kwargs = {}
1006
+ for name , param in sig .parameters .items ():
1007
+ # Skip 'self' parameter for instance methods
1008
+ if name == "self" :
1009
+ continue
1010
+ # If the method accepts **kwargs, pass all original kwargs directly
1011
+ if param .kind == inspect .Parameter .VAR_KEYWORD :
1012
+ return kwargs
1013
+ # If the parameter exists in the incoming kwargs, add it to filtered_kwargs
1014
+ if name in kwargs :
1015
+ filtered_kwargs [name ] = kwargs [name ]
1016
+
1017
+ return filtered_kwargs
1018
+
988
1019
def call (self , hook_name : str , * args : Any , ** kwargs : Any ) -> None :
989
1020
r"""Trigger the callback methods with the given name and arguments.
990
1021
@@ -994,7 +1025,9 @@ def call(self, hook_name: str, *args: Any, **kwargs: Any) -> None:
994
1025
Args:
995
1026
hook_name: The name of the callback method.
996
1027
*args: Optional positional arguments that get passed down to the callback method.
997
- **kwargs: Optional keyword arguments that get passed down to the callback method.
1028
+ **kwargs: Optional keyword arguments that get passed down to the callback method. Keyword arguments
1029
+ that are not present in the callback's signature will be filtered out automatically, allowing
1030
+ callbacks to have different signatures for the same hook.
998
1031
999
1032
Example::
1000
1033
@@ -1016,13 +1049,8 @@ def on_train_epoch_end(self, results):
1016
1049
)
1017
1050
continue
1018
1051
1019
- method (* args , ** kwargs )
1020
-
1021
- # TODO(fabric): handle the following signatures
1022
- # method(self, fabric|trainer, x, y=1)
1023
- # method(self, fabric|trainer, *args, x, y=1)
1024
- # method(self, *args, y=1)
1025
- # method(self, *args, **kwargs)
1052
+ filtered_kwargs = self ._filter_kwargs_for_callback (method , kwargs )
1053
+ method (* args , ** filtered_kwargs )
1026
1054
1027
1055
def log (self , name : str , value : Any , step : Optional [int ] = None ) -> None :
1028
1056
"""Log a scalar to all loggers that were added to Fabric.
0 commit comments