
typing of batch_indices in BasePredictionWriter should probably not be optional #16049

@noamsgl


### Bug description

The `batch_indices` parameter of `BasePredictionWriter` is typed as optional, and it should probably not be.

I implemented a custom prediction writer that extends `BasePredictionWriter`. In every use case I'm currently aware of, `batch_indices` is not `None`. So when I fully type my implementation with a required `Sequence[Any]` parameter, mypy reports an error, because my override narrows a parameter that the base class declares as `Optional`.
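mypy reports this because of its override-compatibility check: an overriding method must accept every argument the base method accepts, so a parameter declared `Optional[...]` in the base class cannot be made non-optional in a subclass. Until the base annotation changes, one possible workaround is to keep the `Optional` annotation and narrow it with an explicit `None` check. The sketch below shows that pattern under stated assumptions: `_StandInWriter` and `StrictIndexWriter` are hypothetical classes (not part of Lightning) that mirror only the shape of `write_on_epoch_end`, with `trainer` and `pl_module` omitted so the snippet runs without `pytorch_lightning`.

```python
from typing import Any, List, Optional, Sequence


class _StandInWriter:
    """Hypothetical stand-in for BasePredictionWriter (trainer/pl_module omitted)."""

    def write_on_epoch_end(
        self, predictions: Sequence[Any], batch_indices: Optional[Sequence[Any]]
    ) -> None:
        raise NotImplementedError


class StrictIndexWriter(_StandInWriter):
    """Hypothetical subclass that keeps the base signature and narrows inside."""

    def __init__(self) -> None:
        self.batch_sizes: List[int] = []

    # Annotating `batch_indices: Sequence[Any]` here would narrow the base
    # parameter type, which mypy reports as an incompatible override.
    def write_on_epoch_end(
        self, predictions: Sequence[Any], batch_indices: Optional[Sequence[Any]]
    ) -> None:
        if batch_indices is None:
            raise ValueError("batch_indices is required to write indexed predictions")
        # Past the check, mypy treats batch_indices as Sequence[Any]
        self.batch_sizes = [len(indices) for indices in batch_indices]
```

The runtime check costs one comparison and makes a missing `batch_indices` fail loudly instead of surfacing later as a `TypeError` on `None`.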


### How to reproduce the bug

```python
import logging
import os
from typing import Any, Sequence

import pandas as pd
import pytorch_lightning as pl
from pytorch_lightning.callbacks import BasePredictionWriter


class InferenceResultsCSVWriter(BasePredictionWriter):
    """
    A class which writes the model's predictions to CSV during inference time
    """

    def __init__(self, output_dir, write_interval):
        super().__init__(write_interval)
        self.output_dir = output_dir

    def write_on_batch_end(self, trainer, pl_module, prediction, batch_indices, batch, batch_idx, dataloader_idx):
        raise NotImplementedError("Writing predictions on batch end is not yet implemented")

    # mypy flags this override: the base class declares `batch_indices` as
    # Optional, and an override may not narrow a parameter type.
    def write_on_epoch_end(
        self,
        trainer: pl.Trainer,
        pl_module: pl.LightningModule,
        predictions: Sequence[Any],
        batch_indices: Sequence[Any],
    ) -> None:
        logging.info("Writing predictions to CSV")
        predictions_dfs = []
        # One DataFrame per batch of the first prediction dataloader
        for batch_idx in range(len(predictions[0])):
            predictions_df = pd.DataFrame(
                {"index": batch_indices[0][batch_idx], "pred": predictions[0][batch_idx]}
            )
            predictions_dfs.append(predictions_df)
        predictions_df = pd.concat(predictions_dfs)
        predictions_df.to_csv(os.path.join(self.output_dir, "predictions.csv"), index=False)
```
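The loop in `write_on_epoch_end` pairs each sample index with its prediction, one batch at a time. As a dependency-free illustration of that pairing, the sketch below writes the same `index,pred` rows with the standard-library `csv` module instead of pandas. It assumes the nesting the reproducer relies on (`predictions[dataloader][batch][sample]` and `batch_indices[dataloader][batch][sample]`) and that predictions are already plain Python values rather than tensors; `write_predictions_csv` is a hypothetical helper, not a Lightning API. Because it is a standalone function rather than an override, it can require `batch_indices: Sequence[Any]` without any conflict with the base class.

```python
import csv
import io
from typing import Any, Sequence, TextIO


def write_predictions_csv(
    out: TextIO,
    predictions: Sequence[Any],
    batch_indices: Sequence[Any],
    dataloader_idx: int = 0,
) -> int:
    """Write a header plus one `index,pred` row per sample; return the data-row count."""
    writer = csv.writer(out)
    writer.writerow(["index", "pred"])
    n_rows = 0
    # Walk the batches of one dataloader, pairing each batch's indices with its predictions
    for indices, preds in zip(batch_indices[dataloader_idx], predictions[dataloader_idx]):
        for index, pred in zip(indices, preds):
            writer.writerow([index, pred])
            n_rows += 1
    return n_rows


if __name__ == "__main__":
    buffer = io.StringIO()
    write_predictions_csv(buffer, [[[0.9, 0.1], [0.5]]], [[[3, 7], [2]]])
    print(buffer.getvalue())
```

Called from an override that has already ruled out `None`, the helper receives a non-optional sequence, so mypy accepts the call.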


### Error messages and logs

### Environment

* CUDA:
	- GPU:
		- GeForce RTX 3080
	- available:         True
	- version:           11.3
* Lightning:
	- lightning-utilities: 0.4.2
	- pytorch-lightning: 1.8.4.post0
	- torch:             1.12.1+cu113
	- torchmetrics:      0.11.0
* Packages:
	- absl-py:           1.3.0
	- aiohttp:           3.8.3
	- aiosignal:         1.3.1
	- alabaster:         0.7.12
	- alembic:           1.8.1
	- antlr4-python3-runtime: 4.9.3
	- argh:              0.26.2
	- arrow:             1.2.3
	- asttokens:         2.2.1
	- async-timeout:     4.0.2
	- attrs:             22.1.0
	- autopage:          0.5.1
	- babel:             2.11.0
	- backcall:          0.2.0
	- binaryornot:       0.4.4
	- black:             22.12.0
	- bokeh:             3.0.3
	- boto3:             1.26.27
	- boto3-stubs:       1.26.27
	- botocore:          1.29.27
	- botocore-stubs:    1.29.27
	- brotli:            1.0.9
	- build:             0.9.0
	- cached-property:   1.5.2
	- cachetools:        5.2.0
	- cerberus:          1.3.4
	- certifi:           2022.12.7
	- cffi:              1.15.1
	- cfgv:              3.3.1
	- chardet:           5.1.0
	- charset-normalizer: 2.1.1
	- click:             8.1.3
	- cliff:             4.1.0
	- cloudpickle:       2.2.0
	- cmaes:             0.9.0
	- cmd2:              2.4.2
	- colorama:          0.4.6
	- colorlog:          6.7.0
	- contourpy:         1.0.6
	- cookiecutter:      2.1.1
	- coverage:          6.5.0
	- cruft:             2.11.1
	- cryptography:      38.0.4
	- cssselect2:        0.7.0
	- cycler:            0.11.0
	- darglint:          1.8.1
	- dask:              2022.12.0
	- decorator:         5.1.1
	- distlib:           0.3.6
	- distributed:       2022.12.0
	- docker-pycreds:    0.4.0
	- docutils:          0.17.1
	- exceptiongroup:    1.0.4
	- execnet:           1.9.0
	- executing:         1.2.0
	- fastcore:          1.5.27
	- filelock:          3.8.2
	- flake8:            6.0.0
	- flake8-annotations: 2.9.1
	- flake8-black:      0.3.5
	- flake8-bugbear:    22.12.6
	- flake8-docstrings: 1.6.0
	- fonttools:         4.38.0
	- frozenlist:        1.3.3
	- fsspec:            2022.11.0
	- gitchangelog:      3.0.4
	- gitdb:             4.0.10
	- gitpython:         3.1.29
	- google-auth:       2.15.0
	- google-auth-oauthlib: 0.4.6
	- greenlet:          2.0.1
	- grpcio:            1.51.1
	- h5py:              3.7.0
	- hdf5storage:       0.1.18
	- heapdict:          1.0.1
	- html5lib:          1.1
	- hydra-core:        1.3.0
	- identify:          2.5.9
	- idna:              3.4
	- imagesize:         1.4.1
	- importlib-metadata: 4.13.0
	- importlib-resources: 5.10.1
	- iniconfig:         1.1.1
	- invoke:            1.7.3
	- ipython:           8.7.0
	- isort:             5.11.2
	- jedi:              0.18.2
	- jinja2:            3.1.2
	- jinja2-time:       0.2.0
	- jmespath:          1.0.1
	- joblib:            1.2.0
	- kiwisolver:        1.4.4
	- lightgbm:          3.3.3
	- lightning-utilities: 0.4.2
	- locket:            1.0.0
	- lovely-numpy:      0.2.2
	- lovely-tensors:    0.1.10
	- mako:              1.2.4
	- markdown:          3.4.1
	- markupsafe:        2.1.1
	- matplotlib:        3.3.4
	- matplotlib-inline: 0.1.6
	- mccabe:            0.7.0
	- mne:               0.23.4
	- moto:              4.0.11
	- msgpack:           1.0.4
	- multidict:         6.0.3
	- mypy:              0.991
	- mypy-boto3-ec2:    1.26.23
	- mypy-boto3-ecs:    1.26.22
	- mypy-boto3-iam:    1.26.0.post1
	- mypy-boto3-logs:   1.26.27
	- mypy-boto3-s3:     1.26.0.post1
	- mypy-extensions:   0.4.3
	- nodeenv:           1.7.0
	- numpy:             1.20.3
	- oauthlib:          3.2.2
	- omegaconf:         2.3.0
	- optuna:            3.0.4
	- packaging:         22.0
	- pandas:            1.3.5
	- parso:             0.8.3
	- partd:             1.3.0
	- pathspec:          0.10.3
	- pathtools:         0.1.2
	- pbr:               5.11.0
	- pep517:            0.13.0
	- pep8-naming:       0.13.2
	- pexpect:           4.8.0
	- pickleshare:       0.7.5
	- pillow:            9.3.0
	- pip:               22.3.1
	- pipenv:            2022.11.30
	- pipenv-setup-for-neurohelp: 3.4.6
	- pipfile:           0.0.2
	- platformdirs:      2.6.0
	- plette:            0.4.2
	- pluggy:            1.0.0
	- pre-commit:        2.20.0
	- prettytable:       3.5.0
	- promise:           2.3
	- prompt-toolkit:    3.0.36
	- protobuf:          3.20.1
	- psutil:            5.9.4
	- ptyprocess:        0.7.0
	- pure-eval:         0.2.2
	- py:                1.11.0
	- pyasn1:            0.4.8
	- pyasn1-modules:    0.2.8
	- pycodestyle:       2.10.0
	- pycparser:         2.21
	- pydantic:          1.10.2
	- pydocstyle:        6.1.1
	- pydyf:             0.5.0
	- pyflakes:          3.0.1
	- pygments:          2.13.0
	- pyparsing:         3.0.9
	- pyperclip:         1.8.2
	- pyphen:            0.13.2
	- pyproject-api:     1.2.1
	- pytest:            7.2.0
	- pytest-cov:        4.0.0
	- pytest-forked:     1.4.0
	- pytest-xdist:      3.1.0
	- python-dateutil:   2.8.2
	- python-dotenv:     0.19.2
	- python-slugify:    7.0.0
	- pytorch-lightning: 1.8.4.post0
	- pytz:              2022.6
	- pyyaml:            6.0
	- requests:          2.28.1
	- requests-oauthlib: 1.3.1
	- requirementslib:   2.2.1
	- responses:         0.22.0
	- rsa:               4.9
	- s3transfer:        0.6.0
	- scikit-learn:      1.2.0
	- scipy:             1.8.1
	- seaborn:           0.11.2
	- sentry-sdk:        1.11.1
	- setproctitle:      1.3.2
	- setuptools:        65.5.1
	- shortuuid:         1.0.11
	- six:               1.16.0
	- smmap:             5.0.0
	- snowballstemmer:   2.2.0
	- sortedcontainers:  2.4.0
	- sphinx:            5.3.0
	- sphinx-autodoc-typehints: 1.19.5
	- sphinx-rtd-theme:  1.1.1
	- sphinxcontrib-applehelp: 1.0.2
	- sphinxcontrib-devhelp: 1.0.2
	- sphinxcontrib-htmlhelp: 2.0.0
	- sphinxcontrib-jsmath: 1.0.1
	- sphinxcontrib-qthelp: 1.0.3
	- sphinxcontrib-serializinghtml: 1.1.5
	- sqlalchemy:        1.4.45
	- stack-data:        0.6.2
	- stevedore:         4.1.1
	- tblib:             1.7.0
	- tensorboard:       2.11.0
	- tensorboard-data-server: 0.6.1
	- tensorboard-plugin-wit: 1.8.1
	- tensorboardx:      2.5.1
	- text-unidecode:    1.3
	- threadpoolctl:     3.1.0
	- tinycss2:          1.2.1
	- toml:              0.10.2
	- tomli:             2.0.1
	- tomlkit:           0.11.6
	- toolz:             0.12.0
	- torch:             1.12.1+cu113
	- torchmetrics:      0.11.0
	- tornado:           6.2
	- tox:               4.0.9
	- tqdm:              4.64.1
	- traitlets:         5.7.1
	- typer:             0.6.1
	- types-awscrt:      0.16.1
	- types-s3transfer:  0.6.0.post5
	- types-six:         1.16.21.4
	- types-toml:        0.10.8.1
	- typing-extensions: 4.4.0
	- urllib3:           1.26.13
	- versioneer:        0.22
	- virtualenv:        20.17.1
	- virtualenv-clone:  0.5.7
	- vistir:            0.7.5
	- wandb:             0.13.6
	- watchdog:          2.2.0
	- wcwidth:           0.2.5
	- weasyprint:        57.1
	- webencodings:      0.5.1
	- werkzeug:          2.2.2
	- wheel:             0.38.4
	- xmltodict:         0.13.0
	- xyzservices:       2022.9.0
	- yarl:              1.8.2
	- zict:              2.2.0
	- zipp:              3.11.0
	- zopfli:            0.2.2
* System:
	- OS:                Linux
	- architecture:
		- 64bit
		- ELF
	- processor:         x86_64
	- python:            3.8.5
	- version:           #54~20.04.1-Ubuntu SMP Sat Mar 20 13:40:25 UTC 2021


### More info

_No response_
