Bug description
Passing a list of BasePredictionWriter subclass instances to Trainer(callbacks=...) makes Pylance report an argument type error (reportArgumentType), even though every element of the list is a Callback.
What version are you seeing the problem on?
v2.4
How to reproduce the bug
from __future__ import annotations
from pathlib import Path
from typing import TYPE_CHECKING, Any, Literal
import lightning as L
from lightning.pytorch.callbacks import BasePredictionWriter
if TYPE_CHECKING:
    import polars as pl
    from lightning.pytorch import LightningModule, Trainer
    from torch import Tensor
class ParquetWriter(BasePredictionWriter):
    """
    Callback for writing predictions to Parquet files.
    Parameters
    ----------
    output_dir
        The directory where the Parquet files will be written.
    write_interval
        The interval at which the predictions will be written.
    """
    def __init__(self, output_dir: str, write_interval: Literal["batch"]) -> None:
        super().__init__(write_interval)
        self.output_dir = Path(output_dir)
    def write_on_batch_end(
        self,
        trainer: Trainer,
        pl_module: LightningModule,  # noqa: ARG002
        prediction: pl.DataFrame,
        batch_indices: Tensor,  # noqa: ARG002
        batch: dict[str, Any],  # noqa: ARG002
        batch_idx: int,
        dataloader_idx: int,  # noqa: ARG002
    ) -> None:
        """Write the prediction to a Parquet file."""
        # Separate rank and batch index so that, e.g., rank 1 / batch 12 and
        # rank 11 / batch 2 do not both map to "112.parquet".
        prediction.write_parquet(
            self.output_dir / f"{trainer.global_rank}_{batch_idx}.parquet",
        )
callbacks = [
    ParquetWriter(
        output_dir="/tmp",
        write_interval="batch",
    ),
]
trainer = L.Trainer(
    callbacks=callbacks,  # <- Pylance(reportArgumentType)
)
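A possible user-side workaround, pending any change to the Trainer annotation, is to annotate the list with the base class so the checker treats it as list[Callback] instead of inferring list[ParquetWriter]. To stay runnable without Lightning, this sketch uses hypothetical stand-ins (`Callback`, `ParquetWriter`, `make_trainer`); `make_trainer` mirrors only the `List[Callback] | Callback | None` annotation Pylance reports for `Trainer(callbacks=...)`, not Trainer's actual behavior. In real code the annotation would be `list[Callback]`, with `Callback` imported from `lightning.pytorch.callbacks`.

```python
from __future__ import annotations


class Callback:
    """Hypothetical stand-in for lightning.pytorch.callbacks.Callback."""


class ParquetWriter(Callback):
    """Hypothetical stand-in for the ParquetWriter subclass above."""

    def __init__(self, output_dir: str) -> None:
        self.output_dir = output_dir


def make_trainer(callbacks: list[Callback] | Callback | None = None) -> list[Callback]:
    """Hypothetical stand-in that mirrors only the reported `callbacks` annotation."""
    if callbacks is None:
        return []
    if isinstance(callbacks, Callback):
        return [callbacks]
    return list(callbacks)


# Unannotated: a checker infers list[ParquetWriter]. This runs, but Pylance
# flags the argument because list[ParquetWriter] is not list[Callback].
inferred = [ParquetWriter(output_dir="/tmp")]
make_trainer(callbacks=inferred)

# Annotated with the base class: the literal is checked as list[Callback],
# which matches the parameter type exactly, so no error is reported.
annotated: list[Callback] = [ParquetWriter(output_dir="/tmp")]
registered = make_trainer(callbacks=annotated)
print(len(registered), type(registered[0]).__name__)  # prints: 1 ParquetWriter
```

Passing a single callback instead of a list also matches the `Callback` arm of that union, though that only helps when there is exactly one callback.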
Error messages and logs
Argument of type "list[ParquetWriter]" cannot be assigned to parameter "callbacks" of type "List[Callback] | Callback | None" in function "__init__"
  Type "list[ParquetWriter]" is not assignable to type "List[Callback] | Callback | None"
    "list[ParquetWriter]" is not assignable to "List[Callback]"
      Type parameter "_T@list" is invariant, but "ParquetWriter" is not the same as "Callback"
      Consider switching from "list" to "Sequence" which is covariant
    "list[ParquetWriter]" is not assignable to "Callback"
    "list[ParquetWriter]" is not assignable to "None"
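The hint in the log follows from list being mutable. If a list[ParquetWriter] were accepted where list[Callback] is expected, the callee could append a plain Callback and the caller's list would silently stop holding only ParquetWriter instances; Sequence exposes no mutating methods, so checkers treat it as covariant. The sketch below illustrates both points with hypothetical stand-ins: `register_into` and `normalize_callbacks` are illustrative names, not Lightning functions, and annotating Trainer's parameter as `Sequence[Callback] | Callback | None` is one possible fix, not a confirmed change in Lightning.

```python
from __future__ import annotations

from collections.abc import Sequence


class Callback:
    """Hypothetical stand-in for lightning.pytorch.callbacks.Callback."""


class ParquetWriter(Callback):
    """Hypothetical stand-in for the subclass in the report."""


def register_into(callbacks: list[Callback]) -> None:
    """Well-typed for list[Callback]: any Callback may be appended."""
    callbacks.append(Callback())


def normalize_callbacks(
    callbacks: Sequence[Callback] | Callback | None,
) -> list[Callback]:
    """Read the argument only and return a fresh list owned by the callee.

    Sequence is read-only and therefore covariant, so a checker accepts a
    list[ParquetWriter] here; appending to the returned copy cannot corrupt
    the caller's list.
    """
    if callbacks is None:
        return []
    if isinstance(callbacks, Callback):
        return [callbacks]
    return list(callbacks)


writers = [ParquetWriter()]
# A checker rejects this call because list is invariant. At runtime it shows
# why: the caller's "list of ParquetWriter" now also holds a plain Callback.
register_into(writers)  # type: ignore
print([type(cb).__name__ for cb in writers])  # prints: ['ParquetWriter', 'Callback']

# Accepted by a checker: list[ParquetWriter] is a valid Sequence[Callback].
fresh = [ParquetWriter()]
internal = normalize_callbacks(fresh)
internal.append(Callback())  # mutates the copy only
print(len(fresh), len(internal))  # prints: 1 2
```

The design point is the copy: a Sequence-typed parameter is only sound if the implementation never mutates the object it receives, so any code that extends the callback list would need to work on something like `list(callbacks)`.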
Environment
Current environment
- CUDA:
- GPU: None
- available: False
- version: None
- Lightning:
- lightning: 2.4.0
- lightning-utilities: 0.11.7
- pytorch-lightning: 2.4.0
- torch: 2.4.1
- torchaudio: 2.4.1
- torchmetrics: 1.4.2
- torchvision: 0.19.1
- Packages:
- aenum: 3.1.12
- aiohappyeyeballs: 2.4.3
- aiohttp: 3.10.8
- aiosignal: 1.3.1
- altair: 5.4.1
- annotated-types: 0.7.0
- antlr4-python3-runtime: 4.9.3
- anyio: 4.6.0
- appnope: 0.1.4
- argon2-cffi: 23.1.0
- argon2-cffi-bindings: 21.2.0
- arrow: 1.3.0
- asttokens: 2.4.1
- async-lru: 2.0.4
- attrs: 24.2.0
- autocommand: 2.2.2
- babel: 2.16.0
- backports.tarfile: 1.2.0
- beautifulsoup4: 4.12.3
- bitsandbytes: 0.42.0
- bleach: 6.1.0
- certifi: 2024.8.30
- cffi: 1.17.1
- charset-normalizer: 3.3.2
- click: 8.1.7
- comm: 0.2.2
- contourpy: 1.3.0
- crispron: 3.0
- cycler: 0.12.1
- datasets: 3.0.1
- debugpy: 1.8.6
- decorator: 5.1.1
- defusedxml: 0.7.1
- dill: 0.3.8
- docker-pycreds: 0.4.0
- docstring-parser: 0.16
- euporie: 2.8.3
- executing: 2.1.0
- fastjsonschema: 2.20.0
- filelock: 3.16.1
- flatlatex: 0.15
- fonttools: 4.54.1
- fqdn: 1.5.1
- frozenlist: 1.4.1
- fsspec: 2024.6.1
- gitdb: 4.0.11
- gitpython: 3.1.43
- h11: 0.14.0
- httpcore: 1.0.6
- httpx: 0.27.2
- huggingface-hub: 0.25.1
- hydra-core: 1.3.2
- idna: 3.10
- imagesize: 1.4.1
- importlib-metadata: 8.0.0
- importlib-resources: 6.4.5
- inflect: 7.3.1
- ipykernel: 6.29.5
- ipython: 8.28.0
- ipywidgets: 8.1.5
- isoduration: 20.11.0
- itables: 2.2.2
- jaraco.collections: 5.1.0
- jaraco.context: 5.3.0
- jaraco.functools: 4.0.1
- jaraco.text: 3.12.1
- jedi: 0.19.1
- jinja2: 3.1.4
- joblib: 1.4.2
- json5: 0.9.25
- jsonargparse: 4.33.1
- jsonpointer: 3.0.0
- jsonschema: 4.23.0
- jsonschema-specifications: 2023.12.1
- jupyter-client: 8.6.3
- jupyter-core: 5.7.2
- jupyter-events: 0.10.0
- jupyter-lsp: 2.2.5
- jupyter-server: 2.14.2
- jupyter-server-terminals: 0.5.3
- jupyterlab: 4.2.5
- jupyterlab-pygments: 0.3.0
- jupyterlab-server: 2.27.3
- jupyterlab-widgets: 3.0.13
- jupytext: 1.16.4
- kiwisolver: 1.4.7
- lightning: 2.4.0
- lightning-utilities: 0.11.7
- linkify-it-py: 1.0.3
- markdown-it-py: 2.2.0
- markupsafe: 2.1.5
- matplotlib: 3.9.2
- matplotlib-inline: 0.1.7
- mdit-py-plugins: 0.3.5
- mdurl: 0.1.2
- mistune: 3.0.2
- more-itertools: 10.3.0
- mpmath: 1.3.0
- multidict: 6.1.0
- multiprocess: 0.70.16
- narwhals: 1.9.0
- nbclient: 0.10.0
- nbconvert: 7.16.4
- nbformat: 5.10.4
- nest-asyncio: 1.6.0
- networkx: 3.3
- notebook-shim: 0.2.4
- numpy: 2.1.1
- omegaconf: 2.3.0
- overrides: 7.7.0
- packaging: 24.1
- pandas: 2.2.3
- pandas-stubs: 2.2.2.240909
- pandocfilters: 1.5.1
- parso: 0.8.4
- pexpect: 4.9.0
- pillow: 10.4.0
- pip: 24.2
- platformdirs: 3.11.0
- plotly: 5.24.1
- polars: 1.9.0
- prometheus-client: 0.21.0
- prompt-toolkit: 3.0.48
- protobuf: 5.28.2
- psutil: 6.0.0
- ptyprocess: 0.7.0
- pure-eval: 0.2.3
- pyarrow: 17.0.0
- pycparser: 2.22
- pydantic: 2.9.2
- pydantic-core: 2.23.4
- pygments: 2.18.0
- pyparsing: 3.1.4
- pyperclip: 1.9.0
- python-dateutil: 2.9.0.post0
- python-json-logger: 2.0.7
- pytorch-lightning: 2.4.0
- pytz: 2024.2
- pyyaml: 6.0.2
- pyzmq: 26.2.0
- referencing: 0.35.1
- regex: 2024.9.11
- requests: 2.32.3
- rfc3339-validator: 0.1.4
- rfc3986-validator: 0.1.1
- rich: 13.9.2
- rpds-py: 0.20.0
- safetensors: 0.4.5
- scikit-learn: 1.5.2
- scipy: 1.14.1
- seaborn: 0.13.2
- send2trash: 1.8.3
- sentry-sdk: 2.15.0
- setproctitle: 1.3.3
- setuptools: 75.1.0
- six: 1.16.0
- sixelcrop: 0.1.8
- smmap: 5.0.1
- sniffio: 1.3.1
- soupsieve: 2.6
- stack-data: 0.6.3
- sympy: 1.13.3
- tenacity: 9.0.0
- tensorboardx: 2.6.2.2
- terminado: 0.18.1
- threadpoolctl: 3.5.0
- timg: 1.1.6
- tinycss2: 1.3.0
- tokenizers: 0.20.1
- tomli: 2.0.1
- torch: 2.4.1
- torchaudio: 2.4.1
- torchmetrics: 1.4.2
- torchvision: 0.19.1
- tornado: 6.4.1
- tqdm: 4.66.5
- traitlets: 5.14.3
- transformers: 4.45.2
- typeguard: 4.3.0
- types-python-dateutil: 2.9.0.20241003
- types-pytz: 2024.2.0.20241003
- typeshed-client: 2.7.0
- typing-extensions: 4.12.2
- tzdata: 2024.2
- uc-micro-py: 1.0.3
- universal-pathlib: 0.2.5
- uri-template: 1.3.0
- urllib3: 2.2.3
- wandb: 0.18.3
- wcwidth: 0.2.13
- webcolors: 24.8.0
- webencodings: 0.5.1
- websocket-client: 1.8.0
- wheel: 0.44.0
- widgetsnbextension: 4.0.13
- xxhash: 3.5.0
- yarl: 1.13.1
- zipp: 3.19.2
- System:
- OS: Darwin
- architecture: 64bit
- processor: arm
- python: 3.12.6
- release: 24.0.0
- version: Darwin Kernel Version 24.0.0: Tue Sep 24 23:39:07 PDT 2024; root:xnu-11215.1.12~1/RELEASE_ARM64_T6000