Skip to content

Commit 46874df

Browse files
authored
Improve type hint for reduce_fx in LightningModule.log (#20943)
1 parent a651975 commit 46874df

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

src/lightning/pytorch/core/module.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -381,7 +381,7 @@ def log(
381381
logger: Optional[bool] = None,
382382
on_step: Optional[bool] = None,
383383
on_epoch: Optional[bool] = None,
384-
reduce_fx: Union[str, Callable] = "mean",
384+
reduce_fx: Union[str, Callable[[Any], Any]] = "mean",
385385
enable_graph: bool = False,
386386
sync_dist: bool = False,
387387
sync_dist_group: Optional[Any] = None,
@@ -546,7 +546,7 @@ def log_dict(
546546
logger: Optional[bool] = None,
547547
on_step: Optional[bool] = None,
548548
on_epoch: Optional[bool] = None,
549-
reduce_fx: Union[str, Callable] = "mean",
549+
reduce_fx: Union[str, Callable[[Any], Any]] = "mean",
550550
enable_graph: bool = False,
551551
sync_dist: bool = False,
552552
sync_dist_group: Optional[Any] = None,

0 commit comments

Comments
 (0)