
Type annotation for BasePredictionWriter subclass #20356

@saiden89

Description


Bug description

Subclassing BasePredictionWriter for custom functionality and passing a list of instances of the subclass to Trainer(callbacks=...) makes Pylance report an incorrect argument type (reportArgumentType).

What version are you seeing the problem on?

v2.4

How to reproduce the bug

from __future__ import annotations

from pathlib import Path
from typing import TYPE_CHECKING, Any, Literal

import lightning as L
from lightning.pytorch.callbacks import BasePredictionWriter

if TYPE_CHECKING:
    import polars as pl
    from lightning.pytorch import LightningModule, Trainer
    from torch import Tensor


class ParquetWriter(BasePredictionWriter):
    """
    Callback for writing predictions to Parquet files.

    Parameters
    ----------
    output_dir
        The directory where the parquet files will be written.
    write_interval
        The interval at which the predictions will be written.
    """

    def __init__(self, output_dir: str, write_interval: Literal["batch"]) -> None:
        super().__init__(write_interval)
        self.output_dir = Path(output_dir)

    def write_on_batch_end(
        self,
        trainer: Trainer,
        pl_module: LightningModule,  # noqa: ARG002
        prediction: pl.DataFrame,
        batch_indices: Tensor,  # noqa: ARG002
        batch: dict[str, Any],  # noqa: ARG002
        batch_idx: int,
        dataloader_idx: int,  # noqa: ARG002
    ) -> None:
        """Write the prediction to a parquet file."""
        prediction.write_parquet(
            self.output_dir / f"{trainer.global_rank}_{batch_idx}.parquet",
        )


callbacks = [
    ParquetWriter(
        output_dir="/tmp",
        write_interval="batch",
    ),
]
trainer = L.Trainer(
    callbacks=callbacks,  # <- Pylance(reportArgumentType)
)

Error messages and logs

Argument of type "list[ParquetWriter]" cannot be assigned to parameter "callbacks" of type "List[Callback] | Callback | None" in function "__init__"
  Type "list[ParquetWriter]" is not assignable to type "List[Callback] | Callback | None"
    "list[ParquetWriter]" is not assignable to "List[Callback]"
      Type parameter "_T@list" is invariant, but "ParquetWriter" is not the same as "Callback"
      Consider switching from "list" to "Sequence" which is covariant
    "list[ParquetWriter]" is not assignable to "Callback"
    "list[ParquetWriter]" is not assignable to "None"
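The message comes down to list invariance: a list[ParquetWriter] is not a list[Callback], even though ParquetWriter is a Callback. The sketch below illustrates this without Lightning installed. Callback, ParquetWriter, trainer_init_list, trainer_init_seq and add_default are hypothetical stand-ins, not Lightning code. trainer_init_list only mirrors the parameter type quoted in the message. The sketch shows why the checker rejects the call, two caller-side workarounds, and how a Sequence-typed parameter, as Pylance's hint suggests, would accept the list unchanged.

```python
from __future__ import annotations

from collections.abc import Sequence


class Callback:
    """Hypothetical stand-in for Lightning's Callback base class."""


class ParquetWriter(Callback):
    """Hypothetical stand-in for the ParquetWriter subclass in the report."""


def trainer_init_list(callbacks: list[Callback] | Callback | None = None) -> int:
    """Hypothetical function with the parameter type quoted by Pylance; returns the callback count."""
    if callbacks is None:
        return 0
    if isinstance(callbacks, Callback):
        return 1
    return len(callbacks)


def trainer_init_seq(callbacks: Sequence[Callback] | Callback | None = None) -> int:
    """Hypothetical variant using the covariant Sequence type, as Pylance's hint suggests."""
    if callbacks is None:
        return 0
    if isinstance(callbacks, Callback):
        return 1
    return len(callbacks)


def add_default(callbacks: list[Callback]) -> None:
    """A callee typed with list[Callback] may legally append any Callback."""
    callbacks.append(Callback())


writers = [ParquetWriter()]  # inferred as list[ParquetWriter]
# trainer_init_list(writers)  # rejected by the type checker: list is invariant

# Why invariance matters: if this call were allowed, a plain Callback would end up
# in a list the caller believes holds only ParquetWriter objects.
mixed = [ParquetWriter()]
add_default(mixed)  # type: ignore[arg-type]  # suppressed only to show the hazard

# Workaround 1: declare the variable with the base element type.
annotated: list[Callback] = [ParquetWriter()]
n_annotated = trainer_init_list(annotated)

# Workaround 2: pass the list literal inline so its type is inferred from the parameter.
n_inline = trainer_init_list([ParquetWriter()])

# A Sequence parameter is read-only and covariant, so list[ParquetWriter] is accepted as-is.
n_seq = trainer_init_seq(writers)
```

Applied to the report, workaround 1 would mean annotating the variable as callbacks: list[Callback] = [...], with Callback imported from lightning.pytorch.callbacks.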

Environment

Current environment
  • CUDA:
    • GPU: None
    • available: False
    • version: None
  • Lightning:
    • lightning: 2.4.0
    • lightning-utilities: 0.11.7
    • pytorch-lightning: 2.4.0
    • torch: 2.4.1
    • torchaudio: 2.4.1
    • torchmetrics: 1.4.2
    • torchvision: 0.19.1
  • Packages:
    • aenum: 3.1.12
    • aiohappyeyeballs: 2.4.3
    • aiohttp: 3.10.8
    • aiosignal: 1.3.1
    • altair: 5.4.1
    • annotated-types: 0.7.0
    • antlr4-python3-runtime: 4.9.3
    • anyio: 4.6.0
    • appnope: 0.1.4
    • argon2-cffi: 23.1.0
    • argon2-cffi-bindings: 21.2.0
    • arrow: 1.3.0
    • asttokens: 2.4.1
    • async-lru: 2.0.4
    • attrs: 24.2.0
    • autocommand: 2.2.2
    • babel: 2.16.0
    • backports.tarfile: 1.2.0
    • beautifulsoup4: 4.12.3
    • bitsandbytes: 0.42.0
    • bleach: 6.1.0
    • certifi: 2024.8.30
    • cffi: 1.17.1
    • charset-normalizer: 3.3.2
    • click: 8.1.7
    • comm: 0.2.2
    • contourpy: 1.3.0
    • crispron: 3.0
    • cycler: 0.12.1
    • datasets: 3.0.1
    • debugpy: 1.8.6
    • decorator: 5.1.1
    • defusedxml: 0.7.1
    • dill: 0.3.8
    • docker-pycreds: 0.4.0
    • docstring-parser: 0.16
    • euporie: 2.8.3
    • executing: 2.1.0
    • fastjsonschema: 2.20.0
    • filelock: 3.16.1
    • flatlatex: 0.15
    • fonttools: 4.54.1
    • fqdn: 1.5.1
    • frozenlist: 1.4.1
    • fsspec: 2024.6.1
    • gitdb: 4.0.11
    • gitpython: 3.1.43
    • h11: 0.14.0
    • httpcore: 1.0.6
    • httpx: 0.27.2
    • huggingface-hub: 0.25.1
    • hydra-core: 1.3.2
    • idna: 3.10
    • imagesize: 1.4.1
    • importlib-metadata: 8.0.0
    • importlib-resources: 6.4.5
    • inflect: 7.3.1
    • ipykernel: 6.29.5
    • ipython: 8.28.0
    • ipywidgets: 8.1.5
    • isoduration: 20.11.0
    • itables: 2.2.2
    • jaraco.collections: 5.1.0
    • jaraco.context: 5.3.0
    • jaraco.functools: 4.0.1
    • jaraco.text: 3.12.1
    • jedi: 0.19.1
    • jinja2: 3.1.4
    • joblib: 1.4.2
    • json5: 0.9.25
    • jsonargparse: 4.33.1
    • jsonpointer: 3.0.0
    • jsonschema: 4.23.0
    • jsonschema-specifications: 2023.12.1
    • jupyter-client: 8.6.3
    • jupyter-core: 5.7.2
    • jupyter-events: 0.10.0
    • jupyter-lsp: 2.2.5
    • jupyter-server: 2.14.2
    • jupyter-server-terminals: 0.5.3
    • jupyterlab: 4.2.5
    • jupyterlab-pygments: 0.3.0
    • jupyterlab-server: 2.27.3
    • jupyterlab-widgets: 3.0.13
    • jupytext: 1.16.4
    • kiwisolver: 1.4.7
    • lightning: 2.4.0
    • lightning-utilities: 0.11.7
    • linkify-it-py: 1.0.3
    • markdown-it-py: 2.2.0
    • markupsafe: 2.1.5
    • matplotlib: 3.9.2
    • matplotlib-inline: 0.1.7
    • mdit-py-plugins: 0.3.5
    • mdurl: 0.1.2
    • mistune: 3.0.2
    • more-itertools: 10.3.0
    • mpmath: 1.3.0
    • multidict: 6.1.0
    • multiprocess: 0.70.16
    • narwhals: 1.9.0
    • nbclient: 0.10.0
    • nbconvert: 7.16.4
    • nbformat: 5.10.4
    • nest-asyncio: 1.6.0
    • networkx: 3.3
    • notebook-shim: 0.2.4
    • numpy: 2.1.1
    • omegaconf: 2.3.0
    • overrides: 7.7.0
    • packaging: 24.1
    • pandas: 2.2.3
    • pandas-stubs: 2.2.2.240909
    • pandocfilters: 1.5.1
    • parso: 0.8.4
    • pexpect: 4.9.0
    • pillow: 10.4.0
    • pip: 24.2
    • platformdirs: 3.11.0
    • plotly: 5.24.1
    • polars: 1.9.0
    • prometheus-client: 0.21.0
    • prompt-toolkit: 3.0.48
    • protobuf: 5.28.2
    • psutil: 6.0.0
    • ptyprocess: 0.7.0
    • pure-eval: 0.2.3
    • pyarrow: 17.0.0
    • pycparser: 2.22
    • pydantic: 2.9.2
    • pydantic-core: 2.23.4
    • pygments: 2.18.0
    • pyparsing: 3.1.4
    • pyperclip: 1.9.0
    • python-dateutil: 2.9.0.post0
    • python-json-logger: 2.0.7
    • pytorch-lightning: 2.4.0
    • pytz: 2024.2
    • pyyaml: 6.0.2
    • pyzmq: 26.2.0
    • referencing: 0.35.1
    • regex: 2024.9.11
    • requests: 2.32.3
    • rfc3339-validator: 0.1.4
    • rfc3986-validator: 0.1.1
    • rich: 13.9.2
    • rpds-py: 0.20.0
    • safetensors: 0.4.5
    • scikit-learn: 1.5.2
    • scipy: 1.14.1
    • seaborn: 0.13.2
    • send2trash: 1.8.3
    • sentry-sdk: 2.15.0
    • setproctitle: 1.3.3
    • setuptools: 75.1.0
    • six: 1.16.0
    • sixelcrop: 0.1.8
    • smmap: 5.0.1
    • sniffio: 1.3.1
    • soupsieve: 2.6
    • stack-data: 0.6.3
    • sympy: 1.13.3
    • tenacity: 9.0.0
    • tensorboardx: 2.6.2.2
    • terminado: 0.18.1
    • threadpoolctl: 3.5.0
    • timg: 1.1.6
    • tinycss2: 1.3.0
    • tokenizers: 0.20.1
    • tomli: 2.0.1
    • torch: 2.4.1
    • torchaudio: 2.4.1
    • torchmetrics: 1.4.2
    • torchvision: 0.19.1
    • tornado: 6.4.1
    • tqdm: 4.66.5
    • traitlets: 5.14.3
    • transformers: 4.45.2
    • typeguard: 4.3.0
    • types-python-dateutil: 2.9.0.20241003
    • types-pytz: 2024.2.0.20241003
    • typeshed-client: 2.7.0
    • typing-extensions: 4.12.2
    • tzdata: 2024.2
    • uc-micro-py: 1.0.3
    • universal-pathlib: 0.2.5
    • uri-template: 1.3.0
    • urllib3: 2.2.3
    • wandb: 0.18.3
    • wcwidth: 0.2.13
    • webcolors: 24.8.0
    • webencodings: 0.5.1
    • websocket-client: 1.8.0
    • wheel: 0.44.0
    • widgetsnbextension: 4.0.13
    • xxhash: 3.5.0
    • yarl: 1.13.1
    • zipp: 3.19.2
  • System:
    • OS: Darwin
    • architecture:
      • 64bit
    • processor: arm
    • python: 3.12.6
    • release: 24.0.0
    • version: Darwin Kernel Version 24.0.0: Tue Sep 24 23:39:07 PDT 2024; root:xnu-11215.1.12~1/RELEASE_ARM64_T6000

More info

No response
