diff --git a/src/lightning/pytorch/core/module.py b/src/lightning/pytorch/core/module.py index 7df0cb7757f81..836356cbeb79a 100644 --- a/src/lightning/pytorch/core/module.py +++ b/src/lightning/pytorch/core/module.py @@ -381,7 +381,7 @@ def log( logger: Optional[bool] = None, on_step: Optional[bool] = None, on_epoch: Optional[bool] = None, - reduce_fx: Union[str, Callable] = "mean", + reduce_fx: Union[str, Callable[[Any], Any]] = "mean", enable_graph: bool = False, sync_dist: bool = False, sync_dist_group: Optional[Any] = None, @@ -546,7 +546,7 @@ def log_dict( logger: Optional[bool] = None, on_step: Optional[bool] = None, on_epoch: Optional[bool] = None, - reduce_fx: Union[str, Callable] = "mean", + reduce_fx: Union[str, Callable[[Any], Any]] = "mean", enable_graph: bool = False, sync_dist: bool = False, sync_dist_group: Optional[Any] = None,