Error while merging hparams when using LightningCLI and YAML #20182

@cgebbe

Bug description

The minimal example below throws the error RuntimeError: Error while merging hparams: the keys ['_class_path'] are present in both the LightningModule's and LightningDataModule's hparams but have different values.

I thought this was supposed to work. I would really appreciate workaround tips (that also work with checkpointing) or a fix.
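To make the failure mode concrete, here is a minimal pure-Python sketch of what seems to be happening. It is a simplified model, not Lightning's actual code: the helper names `cli_hparams` and `merge_hparams` and the `log_datamodule` flag are hypothetical, and it assumes that the CLI stores each class under a `_class_path` key in the hparams and that the merge rejects shared keys with different values.

```python
# Simplified, hypothetical model of the hparams merge; not Lightning's actual code.


def cli_hparams(class_path, init_args):
    """Mimic how the CLI is assumed to record the class under `_class_path`."""
    return {"_class_path": class_path, **init_args}


def merge_hparams(module_hparams, datamodule_hparams, log_datamodule=True):
    """Merge two hparams dicts, rejecting shared keys with different values."""
    if not log_datamodule:
        # The datamodule opted out of logging, so there is nothing to merge.
        return dict(module_hparams)
    shared = module_hparams.keys() & datamodule_hparams.keys()
    conflicts = sorted(k for k in shared if module_hparams[k] != datamodule_hparams[k])
    if conflicts:
        raise RuntimeError(
            f"Error while merging hparams: the keys {conflicts} are present in both "
            "the LightningModule's and LightningDataModule's hparams but have different values."
        )
    return {**module_hparams, **datamodule_hparams}


# Illustrative values taken from the config and class defaults in this report.
model_hparams = cli_hparams("LitAutoEncoder", {"dim": 64})
data_hparams = cli_hparams("MNISTDataModule", {"data_dir": "./"})

try:
    merge_hparams(model_hparams, data_hparams)
    error_message = None
except RuntimeError as err:
    error_message = str(err)

print(error_message)

# Skipping the datamodule side avoids the collision in this model.
merged_without_datamodule = merge_hparams(model_hparams, data_hparams, log_datamodule=False)
print(merged_without_datamodule)
```

In this model any two different classes collide on `_class_path`, no matter what their own arguments are. In Lightning itself, `save_hyperparameters` accepts a `logger` argument, so calling `self.save_hyperparameters(logger=False)` in the datamodule might be one way to skip the merge while still keeping the hparams for checkpointing. That is an assumption and has not been verified against v2.4 here.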

What version are you seeing the problem on?

v2.4

How to reproduce the bug

Run the code below using, e.g., python main.py fit --config config.yaml

# config.yaml
data:
  class_path: MNISTDataModule

model:
  class_path: LitAutoEncoder

# main.py
import torch
from torch import nn, optim
from torch.utils.data import DataLoader, random_split
from torchvision import transforms
from torchvision.datasets import MNIST
import lightning as L
from lightning.pytorch import cli


class LitAutoEncoder(L.LightningModule):
    """From https://lightning.ai/docs/pytorch/stable/starter/introduction.html"""

    def __init__(self, dim: int = 64):
        super().__init__()
        self.save_hyperparameters()
        self.encoder = nn.Sequential(
            nn.Linear(28 * 28, dim), nn.ReLU(), nn.Linear(dim, 3)
        )
        self.decoder = nn.Sequential(
            nn.Linear(3, dim), nn.ReLU(), nn.Linear(dim, 28 * 28)
        )

    def training_step(self, batch, batch_idx):
        x, _ = batch
        x = x.view(x.size(0), -1)
        z = self.encoder(x)
        x_hat = self.decoder(z)
        loss = nn.functional.mse_loss(x_hat, x)
        self.log("train_loss", loss)
        return loss

    def configure_optimizers(self):
        return optim.Adam(self.parameters(), lr=1e-3)


class MNISTDataModule(L.LightningDataModule):
    """From https://lightning.ai/docs/pytorch/stable/data/datamodule.html"""

    def __init__(self, data_dir: str = "./"):
        super().__init__()
        self.save_hyperparameters()
        self.data_dir = data_dir
        self.transform = transforms.Compose(
            [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
        )

    def prepare_data(self):
        # download
        MNIST(self.data_dir, train=True, download=True)
        MNIST(self.data_dir, train=False, download=True)

    def setup(self, stage: str):
        # Assign train/val datasets for use in dataloaders
        if stage == "fit":
            mnist_full = MNIST(self.data_dir, train=True, transform=self.transform)
            self.mnist_train, self.mnist_val = random_split(
                mnist_full, [55000, 5000], generator=torch.Generator().manual_seed(42)
            )

        # Assign test dataset for use in dataloader(s)
        if stage == "test":
            self.mnist_test = MNIST(
                self.data_dir, train=False, transform=self.transform
            )

        if stage == "predict":
            self.mnist_predict = MNIST(
                self.data_dir, train=False, transform=self.transform
            )

    def train_dataloader(self):
        return DataLoader(self.mnist_train, batch_size=32)

    def val_dataloader(self):
        return DataLoader(self.mnist_val, batch_size=32)

    def test_dataloader(self):
        return DataLoader(self.mnist_test, batch_size=32)

    def predict_dataloader(self):
        return DataLoader(self.mnist_predict, batch_size=32)


cli.LightningCLI()

Error messages and logs

RuntimeError: Error while merging hparams: the keys ['_class_path'] are present in both the LightningModule's and LightningDataModule's hparams but have different values.

Environment

Current environment
  • CUDA:
    - GPU:
      - NVIDIA A10G
    - available: True
    - version: 12.1
  • Lightning:
    - lightning: 2.4.0
    - lightning-utilities: 0.11.4
    - pytorch-lightning: 2.3.3
    - torch: 2.3.1
    - torchdata: 0.7.1
    - torchmetrics: 1.0.3
    - torchsummary: 1.5.1
    - torchvision: 0.18.1
  • Packages:
    - absl-py: 2.1.0
    - affine: 2.4.0
    - aiobotocore: 2.13.0
    - aiohttp: 3.9.5
    - aioitertools: 0.11.0
    - aiosignal: 1.3.1
    - albucore: 0.0.12
    - albumentations: 1.4.11
    - altair: 5.3.0
    - annotated-types: 0.7.0
    - ansi2html: 1.9.1
    - ansicolors: 1.1.8
    - antlr4-python3-runtime: 4.9.3
    - anyio: 4.4.0
    - appdirs: 1.4.4
    - argon2-cffi: 23.1.0
    - argon2-cffi-bindings: 21.2.0
    - arrow: 1.3.0
    - asciitree: 0.3.3
    - assertpy: 1.1
    - asttokens: 2.4.1
    - astunparse: 1.6.3
    - async-lru: 2.0.4
    - attrs: 23.2.0
    - automation-api-gateway-client: 12.0.165
    - az-annotation-io: 0.18.8
    - az-cp-aws-utils: 1.2.5
    - az-cp-datadictionary-definitions: 0.2024.119.post1
    - az-cp-drclib: 0.1.0
    - az-cp-holoviews-compressed-rgb: 0.0.1rc3
    - az-cp-imagekit-ventana-bif: 0.2.2
    - az-cp-logging: 0.15.0
    - az-cp-ooportal: 1.4.7
    - az-cp-pathviz: 0.0.1rc7
    - az-cp-pita: 1.10.1
    - az-cp-predictino-container: 3.10.3
    - az-drc2polygons: 2.2.2
    - az-git-utils: 0.14.0
    - babel: 2.15.0
    - beautifulsoup4: 4.12.3
    - bleach: 6.1.0
    - blessed: 1.20.0
    - blinker: 1.8.2
    - bokeh: 3.4.1
    - boto3: 1.34.106
    - botocore: 1.34.106
    - bpython: 0.24
    - braceexpand: 0.1.7
    - cachetools: 5.3.3
    - cerberus: 1.3.5
    - certifi: 2024.2.2
    - cffi: 1.16.0
    - cfgv: 3.4.0
    - charset-normalizer: 3.3.2
    - click: 8.1.7
    - click-plugins: 1.1.1
    - cligj: 0.7.2
    - cloudpickle: 3.0.0
    - cmake: 3.25.0
    - colorama: 0.4.6
    - colorcet: 3.1.0
    - comm: 0.2.2
    - contourpy: 1.2.1
    - coreg-tools: 0.3.0
    - cubinlinker-cu11: 0.3.0.post2
    - cuda-python: 11.8.3
    - cudf-cu11: 24.6.1
    - cuml-cu11: 24.6.1
    - cupy-cuda11x: 13.2.0
    - curtsies: 0.4.2
    - cwcwidth: 0.1.9
    - cycler: 0.12.1
    - dash: 2.17.0
    - dash-core-components: 2.0.0
    - dash-html-components: 2.0.0
    - dash-table: 5.0.0
    - dask: 2024.5.1
    - dask-cuda: 24.6.0
    - dask-cudf-cu11: 24.6.1
    - dask-expr: 1.1.1
    - datashader: 0.16.2
    - debugpy: 1.8.1
    - decorator: 5.1.1
    - definiens-autocli: 4.5.3
    - definiens-ia-algorithms: 3.39.1
    - definiens-imagekit: 5.6.6
    - definiens-parallel: 1.3.0
    - defusedxml: 0.7.1
    - deprecated: 1.2.14
    - distlib: 0.3.8
    - distributed: 2024.5.1
    - distributed-ucxx-cu11: 0.38.0
    - dm-tree: 0.1.8
    - dnspython: 2.6.1
    - docker-pycreds: 0.4.0
    - docstring-parser: 0.16
    - ec2-metadata: 2.13.0
    - editor: 1.6.6
    - entrypoints: 0.4
    - eval-type-backport: 0.2.0
    - executing: 2.0.1
    - fancycompleter: 0.9.1
    - fasteners: 0.19
    - fastjsonschema: 2.19.1
    - fastrlock: 0.8.2
    - filelock: 3.14.0
    - fiona: 1.9.6
    - flask: 3.0.3
    - flatbuffers: 24.3.25
    - fonttools: 4.52.4
    - fqdn: 1.5.1
    - frozendict: 2.4.4
    - frozenlist: 1.4.1
    - fsspec: 2024.5.0
    - fvcore: 0.1.5.post20221221
    - gast: 0.6.0
    - geopandas: 1.0.1
    - geopolars: 0.1.0a4
    - gitdb: 4.0.11
    - gitpython: 3.1.43
    - google-auth: 2.32.0
    - google-auth-oauthlib: 1.0.0
    - google-pasta: 0.2.0
    - greenlet: 3.0.3
    - grpcio: 1.64.1
    - h11: 0.14.0
    - h5py: 3.11.0
    - holoviews: 1.19.0
    - httpcore: 1.0.5
    - httpx: 0.27.0
    - huggingface-hub: 0.24.0
    - hvplot: 0.10.0
    - identify: 2.5.36
    - idna: 3.7
    - imagecodecs: 2024.1.1
    - imageio: 2.34.1
    - importlib-metadata: 7.1.0
    - iniconfig: 2.0.0
    - inquirer: 3.2.4
    - ioda: 0.19.1
    - ioda-readout-service-client: 0.0.5
    - ioda-result-service-client: 0.6.1
    - iopath: 0.1.10
    - ipp: 2021.4.0
    - ipykernel: 6.29.4
    - ipython: 8.24.0
    - ipytree: 0.2.2
    - ipywidgets: 8.1.3
    - isoduration: 20.11.0
    - itables: 2.1.0
    - itsdangerous: 2.2.0
    - jax: 0.4.30
    - jaxlib: 0.4.30
    - jedi: 0.19.1
    - jinja2: 3.1.4
    - jmespath: 1.0.1
    - joblib: 1.4.2
    - json5: 0.9.25
    - jsonargparse: 4.29.0
    - jsonpointer: 2.4
    - jsonschema: 4.22.0
    - jsonschema-specifications: 2023.12.1
    - jupyter: 1.0.0
    - jupyter-bokeh: 4.0.5
    - jupyter-client: 8.6.2
    - jupyter-console: 6.6.3
    - jupyter-core: 5.7.2
    - jupyter-dash: 0.4.2
    - jupyter-events: 0.10.0
    - jupyter-lsp: 2.2.5
    - jupyter-server: 2.14.1
    - jupyter-server-terminals: 0.5.3
    - jupyterlab: 4.2.1
    - jupyterlab-pygments: 0.3.0
    - jupyterlab-server: 2.27.2
    - jupyterlab-vim: 4.1.3
    - jupyterlab-widgets: 3.0.11
    - kaleido: 0.2.1
    - keras: 2.12.0
    - kiwisolver: 1.4.5
    - lazy-loader: 0.4
    - leb128: 1.0.7
    - libclang: 18.1.1
    - libucx-cu11: 1.15.0.post1
    - lightning: 2.4.0
    - lightning-utilities: 0.11.4
    - linkify-it-py: 2.0.3
    - lit: 15.0.7
    - litdata: 0.2.16
    - llvmlite: 0.42.0
    - locket: 1.0.0
    - loguru: 0.7.2
    - lxml: 4.9.4
    - markdown: 3.6
    - markdown-it-py: 3.0.0
    - markupsafe: 2.1.5
    - matplotlib: 3.9.0
    - matplotlib-inline: 0.1.7
    - mdit-py-plugins: 0.4.1
    - mdurl: 0.1.2
    - mistune: 3.0.2
    - ml-dtypes: 0.4.0
    - mpmath: 1.3.0
    - msgpack: 1.0.8
    - multidict: 6.0.5
    - multipledispatch: 1.0.0
    - namex: 0.0.8
    - nbclient: 0.10.0
    - nbconvert: 7.16.4
    - nbformat: 5.10.4
    - nest-asyncio: 1.6.0
    - networkx: 3.3
    - newrelic: 8.11.0
    - nodeenv: 1.9.0
    - notebook: 7.2.0
    - notebook-cgebbe: 0.1.0
    - notebook-shim: 0.2.4
    - nrai-wrapper: 2.19.339839
    - numba: 0.59.1
    - numcodecs: 0.12.1
    - numpy: 1.26.4
    - nvidia-cublas-cu12: 12.1.3.1
    - nvidia-cuda-cupti-cu12: 12.1.105
    - nvidia-cuda-nvrtc-cu12: 12.1.105
    - nvidia-cuda-runtime-cu12: 12.1.105
    - nvidia-cudnn-cu12: 8.9.2.26
    - nvidia-cufft-cu12: 11.0.2.54
    - nvidia-curand-cu12: 10.3.2.106
    - nvidia-cusolver-cu12: 11.4.5.107
    - nvidia-cusparse-cu12: 12.1.0.106
    - nvidia-dali-cuda120: 1.39.0
    - nvidia-nccl-cu12: 2.20.5
    - nvidia-nvimgcodec-cu12: 0.2.0.7
    - nvidia-nvjitlink-cu12: 12.5.82
    - nvidia-nvtx-cu12: 12.1.105
    - nvtx: 0.2.10
    - oauthlib: 3.2.2
    - ome-types: 0.5.1.post1
    - omegaconf: 2.3.0
    - opencv-python-headless: 4.9.0.80
    - openslide-python: 1.3.1
    - opt-einsum: 3.3.0
    - optree: 0.12.1
    - outcome: 1.3.0.post0
    - overrides: 7.7.0
    - packaging: 24.0
    - pandas: 2.1.4
    - pandocfilters: 1.5.1
    - panel: 1.4.4
    - papermill: 2.6.0
    - param: 2.1.0
    - parso: 0.8.4
    - partd: 1.4.2
    - pdbpp: 0.10.3
    - pexpect: 4.9.0
    - pillow: 10.3.0
    - pip: 24.1.2
    - pivottablejs: 0.9.0
    - platformdirs: 4.2.2
    - plotly: 5.22.0
    - pluggy: 1.5.0
    - polars: 0.20.30
    - portal-test-utils: 8.100.0
    - portalocker: 2.10.1
    - pre-commit: 3.7.1
    - prometheus-client: 0.20.0
    - prompt-toolkit: 3.0.45
    - protobuf: 3.20.0
    - psutil: 5.9.8
    - psycopg2-binary: 2.9.9
    - ptpython: 3.0.27
    - ptxcompiler-cu11: 0.8.1.post1
    - ptyprocess: 0.7.0
    - pudb: 2024.1
    - pure-eval: 0.2.2
    - pyarrow: 16.1.0
    - pyasn1: 0.6.0
    - pyasn1-modules: 0.4.0
    - pycparser: 2.22
    - pyct: 0.5.0
    - pydantic: 2.7.3
    - pydantic-compat: 0.1.2
    - pydantic-core: 2.18.4
    - pydeck: 0.9.1
    - pygments: 2.18.0
    - pylibraft-cu11: 24.6.0
    - pymongo: 4.8.0
    - pynvml: 11.4.1
    - pyogrio: 0.9.0
    - pyotp: 2.9.0
    - pyparsing: 3.1.2
    - pypeln: 0.4.9
    - pyportal: 2.33.2
    - pyproj: 3.6.1
    - pyrepl: 0.9.0
    - pysocks: 1.7.1
    - pytest: 8.2.2
    - python-dateutil: 2.9.0.post0
    - python-dotenv: 1.0.1
    - python-gitlab: 3.15.0
    - python-json-logger: 2.0.7
    - python-magic: 0.4.27
    - python-on-whales: 0.71.0
    - pytorch-lightning: 2.3.3
    - pytz: 2024.1
    - pyviz-comms: 3.0.2
    - pyxdg: 0.28
    - pyyaml: 6.0.1
    - pyzmq: 26.0.3
    - qtconsole: 5.5.2
    - qtpy: 2.4.1
    - raft-dask-cu11: 24.6.0
    - rai: 2.4.1
    - rapids-dask-dependency: 24.6.0
    - rasterio: 1.3.10
    - ray: 2.31.0
    - rdi: 2.19.339839
    - readchar: 4.1.0
    - referencing: 0.35.1
    - requests: 2.32.2
    - requests-oauthlib: 2.0.0
    - requests-toolbelt: 1.0.0
    - retrying: 1.3.4
    - rfc3339-validator: 0.1.4
    - rfc3986-validator: 0.1.1
    - rich: 13.7.1
    - rmm-cu11: 24.6.0
    - rpds-py: 0.18.1
    - rsa: 4.9
    - rtree: 1.2.0
    - ruamel.yaml: 0.18.6
    - ruamel.yaml.clib: 0.2.8
    - ruff: 0.5.1
    - runs: 1.2.2
    - s3cmd: 2.4.0
    - s3fs: 2024.5.0
    - s3transfer: 0.10.1
    - safetensors: 0.4.3
    - scikit-image: 0.23.2
    - scikit-learn: 1.5.1
    - scipy: 1.13.1
    - seaborn: 0.13.2
    - segment-anything: 1.0
    - selenium: 4.22.0
    - semantic-segmentation: 0.1.0
    - send2trash: 1.8.3
    - sentry-sdk: 2.6.0
    - setproctitle: 1.3.3
    - setuptools: 63.4.3
    - shapely: 2.0.4
    - shellingham: 1.5.4
    - six: 1.16.0
    - smmap: 5.0.1
    - sniffio: 1.3.1
    - snuggs: 1.4.7
    - sortedcontainers: 2.4.0
    - soupsieve: 2.5
    - spatialpandas: 0.4.10
    - sqlalchemy: 2.0.31
    - stack-data: 0.6.3
    - stopit: 1.1.2
    - streamlit: 1.36.0
    - submitit: 1.5.1
    - sympy: 1.12
    - tabulate: 0.9.0
    - tblib: 3.0.0
    - tenacity: 8.3.0
    - tensorboard: 2.12.3
    - tensorboard-data-server: 0.7.2
    - tensorflow: 2.12.1
    - tensorflow-estimator: 2.12.0
    - tensorflow-io: 0.37.1
    - tensorflow-io-gcs-filesystem: 0.37.1
    - termcolor: 2.4.0
    - terminado: 0.18.1
    - threadpoolctl: 3.5.0
    - tifffile: 2024.5.22
    - timm: 1.0.7
    - tinycss2: 1.3.0
    - toml: 0.10.2
    - tomli: 2.0.1
    - toolz: 0.12.1
    - torch: 2.3.1
    - torchdata: 0.7.1
    - torchmetrics: 1.0.3
    - torchsummary: 1.5.1
    - torchvision: 0.18.1
    - tornado: 6.4
    - tqdm: 4.66.4
    - traitlets: 5.14.3
    - treelite: 4.1.2
    - trio: 0.25.1
    - trio-websocket: 0.11.1
    - triton: 2.3.1
    - typer: 0.12.3
    - types-python-dateutil: 2.9.0.20240316
    - typing-extensions: 4.12.2
    - typing-utils: 0.1.0
    - tzdata: 2024.1
    - uc-micro-py: 1.0.3
    - ucx-py-cu11: 0.38.0
    - ucxx-cu11: 0.38.0
    - uri-template: 1.3.0
    - urllib3: 2.0.7
    - urwid: 2.6.12
    - urwid-readline: 0.14
    - ventanamripy: 0.3
    - virtualenv: 20.26.2
    - wandb: 0.17.2
    - watchdog: 4.0.1
    - wcwidth: 0.2.13
    - webcolors: 1.13
    - webdataset: 0.2.93
    - webencodings: 0.5.1
    - websocket-client: 1.8.0
    - werkzeug: 3.0.3
    - wheel: 0.43.0
    - widgetsnbextension: 4.0.11
    - wmctrl: 0.5
    - wrapt: 1.14.1
    - wsproto: 1.2.0
    - xarray: 2024.6.0
    - xformers: 0.0.27
    - xmod: 1.8.1
    - xsdata: 24.3.1
    - xyzservices: 2024.6.0
    - yacs: 0.1.8
    - yarl: 1.9.4
    - zarr: 2.18.2
    - zict: 3.0.0
    - zipp: 3.19.1
  • System:
    - OS: Linux
    - architecture:
      - 64bit
      - ELF
    - processor: x86_64
    - python: 3.11.9
    - release: 5.15.0-1066-aws
    - version: #72~20.04.1-Ubuntu SMP Thu Jul 18 10:41:27 UTC 2024
