Closed
Labels: bug, needs triage, ver: 2.5.x
Description
Bug description
There is an incomplete type definition in the `log` method of `LightningModule`, which causes static type checkers such as pyright to raise errors:
```python
import pytorch_lightning as pl
from torch import Tensor


class MyModel(pl.LightningModule):
    def training_step(self, batch: int, batch_idx: int) -> Tensor:
        self.log("train_loss", 0.123)  # pyright raises: Type of "log" is partially unknown
        return Tensor([0.123])
```

This is caused by the `reduce_fx` parameter, which is typed as `Union[str, Callable]`.
The bare `Callable` should be parameterized with its argument and return types.
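A minimal sketch of the suggested fix, assuming for illustration that a reducer takes a list of floats and returns a float (the real `reduce_fx` operates on tensors). The names `ReduceFx`, `_NAMED_REDUCERS`, `resolve_reduce_fx`, and this simplified `log` are hypothetical and are not Lightning's API. A bare `Callable` is seen by pyright as `Callable[..., Unknown]`; spelling out the argument and return types gives the checker a fully known signature.

```python
from statistics import fmean
from typing import Callable, Dict, List, Union

# Hypothetical, fully parameterized reducer type. Floats stand in for tensors
# so the sketch runs without torch.
Reducer = Callable[[List[float]], float]
ReduceFx = Union[str, Reducer]

# Hypothetical registry mapping string names to concrete reducers.
_NAMED_REDUCERS: Dict[str, Reducer] = {
    "mean": fmean,
    "sum": sum,
    "max": max,
    "min": min,
}


def resolve_reduce_fx(reduce_fx: ReduceFx) -> Reducer:
    """Turn a string name or a user-supplied callable into a reducer."""
    if isinstance(reduce_fx, str):
        try:
            return _NAMED_REDUCERS[reduce_fx]
        except KeyError:
            raise ValueError(f"unknown reduce_fx: {reduce_fx!r}") from None
    # Already a callable with the declared signature.
    return reduce_fx


def log(name: str, values: List[float], reduce_fx: ReduceFx = "mean") -> Dict[str, float]:
    """Hypothetical simplified `log`: reduce the collected values under `name`."""
    reducer = resolve_reduce_fx(reduce_fx)
    return {name: reducer(values)}
```

For example, `log("acc", [0.5, 0.9], reduce_fx="max")` selects a built-in reducer by name, while `log("last", [1.0, 2.0], reduce_fx=lambda xs: xs[-1])` passes a custom callable that the checker can verify against `Callable[[List[float]], float]`.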
Thanks!
What version are you seeing the problem on?
v2.5
Reproduced in studio
No response
How to reproduce the bug
Error messages and logs
# Error messages and logs here please
Environment
Current environment
#- PyTorch Lightning Version (e.g., 2.5.0):
#- PyTorch Version (e.g., 2.5):
#- Python version (e.g., 3.12):
#- OS (e.g., Linux):
#- CUDA/cuDNN version:
#- GPU models and configuration:
#- How you installed Lightning(`conda`, `pip`, source):
More info
No response
rittik9