
Incomplete typing in LightningModule.log method #20941

@gabriele-marino

Description


Bug description

There is an incomplete type definition in the log method of LightningModule, which causes static type checkers such as pyright to report errors.

import pytorch_lightning as pl
from torch import Tensor

class MyModel(pl.LightningModule):

    def training_step(self, batch: int, batch_idx: int) -> Tensor:
        self.log("train_loss", 0.123)  # pyright raises: Type of "log" is partially unknown 
        return Tensor([0.123])

This is caused by the reduce_fx parameter, which is typed as Union[str, Callable].
The bare Callable should be parameterized with its argument and return types; without them, pyright treats the callable's signature as unknown.
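A minimal sketch of the pattern being asked for, assuming a reducer is "a function that turns the logged values into one value". Everything below (`Reducer`, `reduce_logged`, the named reducers) is hypothetical and not Lightning's API. Plain floats stand in for tensors so the example runs without torch; in Lightning itself the analogous annotation would presumably involve `Tensor` (for example something like `Callable[[Tensor], Tensor]`), which is an assumption, not the library's confirmed signature.

```python
from typing import Callable, Sequence, Union

# Hypothetical alias (not Lightning's API): a reducer maps the logged values
# to a single value. Unlike a bare `Callable`, both the parameter and return
# types are spelled out, so a strict checker has nothing "unknown" to report.
Reducer = Callable[[Sequence[float]], float]


def _mean(values: Sequence[float]) -> float:
    return sum(values) / len(values)


def _max(values: Sequence[float]) -> float:
    return max(values)


# Named reducers, mirroring the idea of passing reduce_fx as a string.
_NAMED_REDUCERS: dict[str, Reducer] = {"mean": _mean, "max": _max}


def reduce_logged(values: Sequence[float], reduce_fx: Union[str, Reducer] = "mean") -> float:
    """Apply a named or user-supplied reducer to the logged values."""
    fn = _NAMED_REDUCERS[reduce_fx] if isinstance(reduce_fx, str) else reduce_fx
    return fn(values)
```

For example, `reduce_logged([1.0, 3.0])` averages the values, while `reduce_logged([1.0, 3.0], lambda v: v[0])` uses a custom callable. The `Union[str, Reducer]` shape keeps the string shortcut while giving the callable branch a fully known signature.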

Thanks!

What version are you seeing the problem on?

v2.5

Reproduced in studio

No response

How to reproduce the bug

Error messages and logs

# Error messages and logs here please

Environment

Current environment
#- PyTorch Lightning Version (e.g., 2.5.0):
#- PyTorch Version (e.g., 2.5):
#- Python version (e.g., 3.12):
#- OS (e.g., Linux):
#- CUDA/cuDNN version:
#- GPU models and configuration:
#- How you installed Lightning(`conda`, `pip`, source):
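A minimal standard-library sketch for filling in several of the fields above. The distribution names queried (`pytorch-lightning`, `lightning`, `torch`) are assumptions about how the packages were installed, and CUDA/cuDNN and GPU details are not covered.

```python
import platform
from importlib.metadata import PackageNotFoundError, version


def installed_version(dist: str) -> str:
    """Return the installed version of a distribution, or 'not installed'."""
    try:
        return version(dist)
    except PackageNotFoundError:
        return "not installed"


def environment_report() -> dict[str, str]:
    """Collect the template fields that the standard library can answer."""
    return {
        "PyTorch Lightning Version": installed_version("pytorch-lightning"),
        "Lightning Version": installed_version("lightning"),
        "PyTorch Version": installed_version("torch"),
        "Python version": platform.python_version(),
        "OS": f"{platform.system()} {platform.release()}",
    }


if __name__ == "__main__":
    # Print in the same "#- key: value" shape as the template.
    for key, value in environment_report().items():
        print(f"#- {key}: {value}")
```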

More info

No response

Metadata


Assignees

No one assigned

Labels

bug, needs triage, ver: 2.5.x

Type

No type

Projects

No projects

Milestone

No milestone

Relationships

None yet

Development

No branches or pull requests
