### Bug description

The typing of `batch_indices` in `BasePredictionWriter` should probably not be `Optional`.

I implemented a custom prediction writer extending `BasePredictionWriter`. In every use case I'm currently aware of, `batch_indices` should never be `None`. Thus, when I fully type my implementation with a required `Sequence[Any]` type, mypy complains.
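For context, mypy checks every overriding method against its supertype: a subclass may widen a parameter's type but not narrow it. Replacing a base annotation of `Optional[Sequence[Any]]` with `Sequence[Any]` is therefore reported as an incompatible override (mypy error code `override`). Until the base annotation changes, one workaround is to keep the `Optional` annotation and narrow it with an explicit `None` check. The sketch below does not import Lightning: `HypotheticalBaseWriter` and `CheckedWriter` are hypothetical stand-ins that only mirror the `Optional` parameter, and the signatures are trimmed (no `trainer`/`pl_module`) for brevity.

```python
from typing import Any, List, Optional, Sequence, Tuple


class HypotheticalBaseWriter:
    """Hypothetical stand-in for a base class whose hook accepts Optional indices."""

    def write_on_epoch_end(
        self, predictions: Sequence[Any], batch_indices: Optional[Sequence[Any]]
    ) -> None:
        raise NotImplementedError


class CheckedWriter(HypotheticalBaseWriter):
    def __init__(self) -> None:
        self.rows: List[Tuple[Any, Any]] = []

    # Same parameter type as the supertype, so mypy's override check passes.
    def write_on_epoch_end(
        self, predictions: Sequence[Any], batch_indices: Optional[Sequence[Any]]
    ) -> None:
        if batch_indices is None:
            # Fail loudly rather than write predictions that cannot be matched to samples.
            raise ValueError("batch_indices is None; cannot map predictions to samples")
        # After the None check, mypy narrows batch_indices to Sequence[Any].
        self.rows = list(zip(batch_indices, predictions))
```

The other common option is to keep the narrower `Sequence[Any]` annotation and silence the check with `# type: ignore[override]` on the `def` line, which hides the mismatch rather than handling the `None` case.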
### How to reproduce the bug
```python
import logging
import os
from typing import Any, Sequence

import pandas as pd
import pytorch_lightning as pl
from pytorch_lightning.callbacks import BasePredictionWriter


class InferenceResultsCSVWriter(BasePredictionWriter):
    """
    A class which writes the model's predictions to CSV during inference time
    """

    def __init__(self, output_dir, write_interval):
        super().__init__(write_interval)
        self.output_dir = output_dir

    def write_on_batch_end(self, trainer, pl_module, prediction, batch_indices, batch, batch_idx, dataloader_idx):
        raise NotImplementedError("Writing predictions on batch end is not yet implemented")

    # Annotating batch_indices as a required Sequence[Any] is what mypy complains about.
    def write_on_epoch_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule, predictions: Sequence[Any], batch_indices: Sequence[Any]):
        logging.info("Writing predictions to CSV")
        predictions_dfs = []
        # Index [0] selects the outputs of the first predict dataloader.
        for batch_idx in range(len(predictions[0])):
            predictions_df = pd.DataFrame({
                "index": batch_indices[0][batch_idx],
                "pred": predictions[0][batch_idx],
            })
            predictions_dfs.append(predictions_df)
        predictions_df = pd.concat(predictions_dfs)
        predictions_df.to_csv(os.path.join(self.output_dir, "predictions.csv"), index=False)
```
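The epoch-end logic above only reads the first dataloader (`predictions[0]`, `batch_indices[0]`). A minimal sketch of the same flattening for any number of dataloaders follows, assuming (as the reproduction does) that both arguments are nested per dataloader, then per batch, then per sample, and that entries are plain Python values (tensors would first need `.tolist()`). The function `predictions_to_frame` and its `dataloader_idx` column are hypothetical additions for illustration, not part of Lightning.

```python
from typing import Any, Sequence

import pandas as pd


def predictions_to_frame(
    predictions: Sequence[Sequence[Sequence[Any]]],
    batch_indices: Sequence[Sequence[Sequence[int]]],
) -> pd.DataFrame:
    """Flatten per-dataloader, per-batch predictions into one DataFrame.

    Assumed nesting: outer level = dataloader, middle = batch, inner = sample.
    """
    frames = []
    for dl_idx, (dl_preds, dl_indices) in enumerate(zip(predictions, batch_indices)):
        for batch_preds, batch_idx_list in zip(dl_preds, dl_indices):
            frames.append(
                pd.DataFrame(
                    {
                        "dataloader_idx": dl_idx,  # scalar, broadcast to every row
                        "index": list(batch_idx_list),
                        "pred": list(batch_preds),
                    }
                )
            )
    if not frames:
        # pd.concat raises on an empty list, so return an empty frame instead.
        return pd.DataFrame(columns=["dataloader_idx", "index", "pred"])
    # Concatenate once at the end and renumber rows 0..n-1.
    return pd.concat(frames, ignore_index=True)
```

Inside `write_on_epoch_end`, the result could then be written with `predictions_to_frame(predictions, batch_indices).to_csv(os.path.join(self.output_dir, "predictions.csv"), index=False)`. Collecting the per-batch frames and concatenating once avoids repeatedly copying a growing DataFrame inside the loop.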
### Error messages and logs
### Environment
* CUDA:
- GPU:
  - GeForce RTX 3080
- available: True
- version: 11.3
* Lightning:
- lightning-utilities: 0.4.2
- pytorch-lightning: 1.8.4.post0
- torch: 1.12.1+cu113
- torchmetrics: 0.11.0
* Packages:
- absl-py: 1.3.0
- aiohttp: 3.8.3
- aiosignal: 1.3.1
- alabaster: 0.7.12
- alembic: 1.8.1
- antlr4-python3-runtime: 4.9.3
- argh: 0.26.2
- arrow: 1.2.3
- asttokens: 2.2.1
- async-timeout: 4.0.2
- attrs: 22.1.0
- autopage: 0.5.1
- babel: 2.11.0
- backcall: 0.2.0
- binaryornot: 0.4.4
- black: 22.12.0
- bokeh: 3.0.3
- boto3: 1.26.27
- boto3-stubs: 1.26.27
- botocore: 1.29.27
- botocore-stubs: 1.29.27
- brotli: 1.0.9
- build: 0.9.0
- cached-property: 1.5.2
- cachetools: 5.2.0
- cerberus: 1.3.4
- certifi: 2022.12.7
- cffi: 1.15.1
- cfgv: 3.3.1
- chardet: 5.1.0
- charset-normalizer: 2.1.1
- click: 8.1.3
- cliff: 4.1.0
- cloudpickle: 2.2.0
- cmaes: 0.9.0
- cmd2: 2.4.2
- colorama: 0.4.6
- colorlog: 6.7.0
- contourpy: 1.0.6
- cookiecutter: 2.1.1
- coverage: 6.5.0
- cruft: 2.11.1
- cryptography: 38.0.4
- cssselect2: 0.7.0
- cycler: 0.11.0
- darglint: 1.8.1
- dask: 2022.12.0
- decorator: 5.1.1
- distlib: 0.3.6
- distributed: 2022.12.0
- docker-pycreds: 0.4.0
- docutils: 0.17.1
- exceptiongroup: 1.0.4
- execnet: 1.9.0
- executing: 1.2.0
- fastcore: 1.5.27
- filelock: 3.8.2
- flake8: 6.0.0
- flake8-annotations: 2.9.1
- flake8-black: 0.3.5
- flake8-bugbear: 22.12.6
- flake8-docstrings: 1.6.0
- fonttools: 4.38.0
- frozenlist: 1.3.3
- fsspec: 2022.11.0
- gitchangelog: 3.0.4
- gitdb: 4.0.10
- gitpython: 3.1.29
- google-auth: 2.15.0
- google-auth-oauthlib: 0.4.6
- greenlet: 2.0.1
- grpcio: 1.51.1
- h5py: 3.7.0
- hdf5storage: 0.1.18
- heapdict: 1.0.1
- html5lib: 1.1
- hydra-core: 1.3.0
- identify: 2.5.9
- idna: 3.4
- imagesize: 1.4.1
- importlib-metadata: 4.13.0
- importlib-resources: 5.10.1
- iniconfig: 1.1.1
- invoke: 1.7.3
- ipython: 8.7.0
- isort: 5.11.2
- jedi: 0.18.2
- jinja2: 3.1.2
- jinja2-time: 0.2.0
- jmespath: 1.0.1
- joblib: 1.2.0
- kiwisolver: 1.4.4
- lightgbm: 3.3.3
- lightning-utilities: 0.4.2
- locket: 1.0.0
- lovely-numpy: 0.2.2
- lovely-tensors: 0.1.10
- mako: 1.2.4
- markdown: 3.4.1
- markupsafe: 2.1.1
- matplotlib: 3.3.4
- matplotlib-inline: 0.1.6
- mccabe: 0.7.0
- mne: 0.23.4
- moto: 4.0.11
- msgpack: 1.0.4
- multidict: 6.0.3
- mypy: 0.991
- mypy-boto3-ec2: 1.26.23
- mypy-boto3-ecs: 1.26.22
- mypy-boto3-iam: 1.26.0.post1
- mypy-boto3-logs: 1.26.27
- mypy-boto3-s3: 1.26.0.post1
- mypy-extensions: 0.4.3
- nodeenv: 1.7.0
- numpy: 1.20.3
- oauthlib: 3.2.2
- omegaconf: 2.3.0
- optuna: 3.0.4
- packaging: 22.0
- pandas: 1.3.5
- parso: 0.8.3
- partd: 1.3.0
- pathspec: 0.10.3
- pathtools: 0.1.2
- pbr: 5.11.0
- pep517: 0.13.0
- pep8-naming: 0.13.2
- pexpect: 4.8.0
- pickleshare: 0.7.5
- pillow: 9.3.0
- pip: 22.3.1
- pipenv: 2022.11.30
- pipenv-setup-for-neurohelp: 3.4.6
- pipfile: 0.0.2
- platformdirs: 2.6.0
- plette: 0.4.2
- pluggy: 1.0.0
- pre-commit: 2.20.0
- prettytable: 3.5.0
- promise: 2.3
- prompt-toolkit: 3.0.36
- protobuf: 3.20.1
- psutil: 5.9.4
- ptyprocess: 0.7.0
- pure-eval: 0.2.2
- py: 1.11.0
- pyasn1: 0.4.8
- pyasn1-modules: 0.2.8
- pycodestyle: 2.10.0
- pycparser: 2.21
- pydantic: 1.10.2
- pydocstyle: 6.1.1
- pydyf: 0.5.0
- pyflakes: 3.0.1
- pygments: 2.13.0
- pyparsing: 3.0.9
- pyperclip: 1.8.2
- pyphen: 0.13.2
- pyproject-api: 1.2.1
- pytest: 7.2.0
- pytest-cov: 4.0.0
- pytest-forked: 1.4.0
- pytest-xdist: 3.1.0
- python-dateutil: 2.8.2
- python-dotenv: 0.19.2
- python-slugify: 7.0.0
- pytorch-lightning: 1.8.4.post0
- pytz: 2022.6
- pyyaml: 6.0
- requests: 2.28.1
- requests-oauthlib: 1.3.1
- requirementslib: 2.2.1
- responses: 0.22.0
- rsa: 4.9
- s3transfer: 0.6.0
- scikit-learn: 1.2.0
- scipy: 1.8.1
- seaborn: 0.11.2
- sentry-sdk: 1.11.1
- setproctitle: 1.3.2
- setuptools: 65.5.1
- shortuuid: 1.0.11
- six: 1.16.0
- smmap: 5.0.0
- snowballstemmer: 2.2.0
- sortedcontainers: 2.4.0
- sphinx: 5.3.0
- sphinx-autodoc-typehints: 1.19.5
- sphinx-rtd-theme: 1.1.1
- sphinxcontrib-applehelp: 1.0.2
- sphinxcontrib-devhelp: 1.0.2
- sphinxcontrib-htmlhelp: 2.0.0
- sphinxcontrib-jsmath: 1.0.1
- sphinxcontrib-qthelp: 1.0.3
- sphinxcontrib-serializinghtml: 1.1.5
- sqlalchemy: 1.4.45
- stack-data: 0.6.2
- stevedore: 4.1.1
- tblib: 1.7.0
- tensorboard: 2.11.0
- tensorboard-data-server: 0.6.1
- tensorboard-plugin-wit: 1.8.1
- tensorboardx: 2.5.1
- text-unidecode: 1.3
- threadpoolctl: 3.1.0
- tinycss2: 1.2.1
- toml: 0.10.2
- tomli: 2.0.1
- tomlkit: 0.11.6
- toolz: 0.12.0
- torch: 1.12.1+cu113
- torchmetrics: 0.11.0
- tornado: 6.2
- tox: 4.0.9
- tqdm: 4.64.1
- traitlets: 5.7.1
- typer: 0.6.1
- types-awscrt: 0.16.1
- types-s3transfer: 0.6.0.post5
- types-six: 1.16.21.4
- types-toml: 0.10.8.1
- typing-extensions: 4.4.0
- urllib3: 1.26.13
- versioneer: 0.22
- virtualenv: 20.17.1
- virtualenv-clone: 0.5.7
- vistir: 0.7.5
- wandb: 0.13.6
- watchdog: 2.2.0
- wcwidth: 0.2.5
- weasyprint: 57.1
- webencodings: 0.5.1
- werkzeug: 2.2.2
- wheel: 0.38.4
- xmltodict: 0.13.0
- xyzservices: 2022.9.0
- yarl: 1.8.2
- zict: 2.2.0
- zipp: 3.11.0
- zopfli: 0.2.2
* System:
- OS: Linux
- architecture:
  - 64bit
  - ELF
- processor: x86_64
- python: 3.8.5
- version: #54~20.04.1-Ubuntu SMP Sat Mar 20 13:40:25 UTC 2021
### More info
_No response_