Closed
Labels: bug, needs triage, ver: 2.4.x
Description
Bug description
The minimal example below throws the error RuntimeError: Error while merging hparams: the keys ['_class_path'] are present in both the LightningModule's and LightningDataModule's hparams but have different values.
I thought this was supposed to work. I would really appreciate tips for a workaround (one that also works with checkpointing) or a fix.
What version are you seeing the problem on?
v2.4
How to reproduce the bug
Run the code below with, for example, python main.py --config config.yaml
# config.yaml
data:
  class_path: MNISTDataModule
model:
  class_path: LitAutoEncoder
# main.py
import os
from torch import optim, nn, utils
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor
import lightning as L
from lightning.pytorch import cli
from torch.utils.data import random_split, DataLoader
import torch
from torchvision import transforms


class LitAutoEncoder(L.LightningModule):
    """From https://lightning.ai/docs/pytorch/stable/starter/introduction.html"""

    def __init__(self, dim: int = 64):
        super().__init__()
        self.save_hyperparameters()
        self.encoder = nn.Sequential(
            nn.Linear(28 * 28, dim), nn.ReLU(), nn.Linear(dim, 3)
        )
        self.decoder = nn.Sequential(
            nn.Linear(3, dim), nn.ReLU(), nn.Linear(dim, 28 * 28)
        )

    def training_step(self, batch, batch_idx):
        x, _ = batch
        x = x.view(x.size(0), -1)
        z = self.encoder(x)
        x_hat = self.decoder(z)
        loss = nn.functional.mse_loss(x_hat, x)
        self.log("train_loss", loss)
        return loss

    def configure_optimizers(self):
        return optim.Adam(self.parameters(), lr=1e-3)


class MNISTDataModule(L.LightningDataModule):
    """From https://lightning.ai/docs/pytorch/stable/data/datamodule.html"""

    def __init__(self, data_dir: str = "./"):
        super().__init__()
        self.save_hyperparameters()
        self.data_dir = data_dir
        self.transform = transforms.Compose(
            [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
        )

    def prepare_data(self):
        # download
        MNIST(self.data_dir, train=True, download=True)
        MNIST(self.data_dir, train=False, download=True)

    def setup(self, stage: str):
        # Assign train/val datasets for use in dataloaders
        if stage == "fit":
            mnist_full = MNIST(self.data_dir, train=True, transform=self.transform)
            self.mnist_train, self.mnist_val = random_split(
                mnist_full, [55000, 5000], generator=torch.Generator().manual_seed(42)
            )

        # Assign test dataset for use in dataloader(s)
        if stage == "test":
            self.mnist_test = MNIST(
                self.data_dir, train=False, transform=self.transform
            )

        if stage == "predict":
            self.mnist_predict = MNIST(
                self.data_dir, train=False, transform=self.transform
            )

    def train_dataloader(self):
        return DataLoader(self.mnist_train, batch_size=32)

    def val_dataloader(self):
        return DataLoader(self.mnist_val, batch_size=32)

    def test_dataloader(self):
        return DataLoader(self.mnist_test, batch_size=32)

    def predict_dataloader(self):
        return DataLoader(self.mnist_predict, batch_size=32)


cli.LightningCLI()
Error messages and logs
RuntimeError: Error while merging hparams: the keys ['_class_path'] are present in both the LightningModule's and LightningDataModule's hparams but have different values.
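To make the failure mode concrete, here is a simplified, hypothetical sketch of the kind of consistency check that produces this message. It is not Lightning's actual code. It assumes that, when a class is instantiated through LightningCLI's class_path, its saved hparams include a `_class_path` entry naming that class; the exact values below are illustrative guesses. Since the model and the datamodule necessarily have different class paths, any check that requires overlapping keys to agree will fail. The `datamodule_logs_hparams` flag is a stand-in for excluding one side from hyperparameter logging. My understanding (unverified) is that `self.save_hyperparameters(logger=False)` in the datamodule does this in Lightning, which may be a workaround worth trying.

```python
from typing import Any, Dict


def merge_hparams(
    module_hparams: Dict[str, Any],
    datamodule_hparams: Dict[str, Any],
    datamodule_logs_hparams: bool = True,
) -> Dict[str, Any]:
    """Hypothetical, simplified merge of model and datamodule hparams for logging."""
    if not datamodule_logs_hparams:
        # The datamodule opted out of logging, so only the model's hparams are used
        # and no conflict check is needed.
        return dict(module_hparams)

    # Keys present on both sides must have equal values, otherwise the merge is ambiguous.
    conflicts = sorted(
        key
        for key in module_hparams.keys() & datamodule_hparams.keys()
        if module_hparams[key] != datamodule_hparams[key]
    )
    if conflicts:
        raise RuntimeError(
            f"Error while merging hparams: the keys {conflicts} are present in both "
            "the LightningModule's and LightningDataModule's hparams but have different values."
        )
    return {**module_hparams, **datamodule_hparams}


# Hypothetical hparams for the two classes from config.yaml; the `_class_path`
# values are assumptions about what LightningCLI records.
model_hparams = {"dim": 64, "_class_path": "__main__.LitAutoEncoder"}
data_hparams = {"data_dir": "./", "_class_path": "__main__.MNISTDataModule"}

try:
    merge_hparams(model_hparams, data_hparams)
except RuntimeError as err:
    print(err)

# With the datamodule excluded from logging, the merge no longer fails.
print(merge_hparams(model_hparams, data_hparams, datamodule_logs_hparams=False))
```

The conflicting key is `_class_path` even though the user-defined arguments (`dim`, `data_dir`) never overlap, which matches the error above.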
Environment
Current environment
- CUDA:
- GPU:
- NVIDIA A10G
- available: True
- version: 12.1
- Lightning:
- lightning: 2.4.0
- lightning-utilities: 0.11.4
- pytorch-lightning: 2.3.3
- torch: 2.3.1
- torchdata: 0.7.1
- torchmetrics: 1.0.3
- torchsummary: 1.5.1
- torchvision: 0.18.1
- Packages:
- absl-py: 2.1.0
- affine: 2.4.0
- aiobotocore: 2.13.0
- aiohttp: 3.9.5
- aioitertools: 0.11.0
- aiosignal: 1.3.1
- albucore: 0.0.12
- albumentations: 1.4.11
- altair: 5.3.0
- annotated-types: 0.7.0
- ansi2html: 1.9.1
- ansicolors: 1.1.8
- antlr4-python3-runtime: 4.9.3
- anyio: 4.4.0
- appdirs: 1.4.4
- argon2-cffi: 23.1.0
- argon2-cffi-bindings: 21.2.0
- arrow: 1.3.0
- asciitree: 0.3.3
- assertpy: 1.1
- asttokens: 2.4.1
- astunparse: 1.6.3
- async-lru: 2.0.4
- attrs: 23.2.0
- automation-api-gateway-client: 12.0.165
- az-annotation-io: 0.18.8
- az-cp-aws-utils: 1.2.5
- az-cp-datadictionary-definitions: 0.2024.119.post1
- az-cp-drclib: 0.1.0
- az-cp-holoviews-compressed-rgb: 0.0.1rc3
- az-cp-imagekit-ventana-bif: 0.2.2
- az-cp-logging: 0.15.0
- az-cp-ooportal: 1.4.7
- az-cp-pathviz: 0.0.1rc7
- az-cp-pita: 1.10.1
- az-cp-predictino-container: 3.10.3
- az-drc2polygons: 2.2.2
- az-git-utils: 0.14.0
- babel: 2.15.0
- beautifulsoup4: 4.12.3
- bleach: 6.1.0
- blessed: 1.20.0
- blinker: 1.8.2
- bokeh: 3.4.1
- boto3: 1.34.106
- botocore: 1.34.106
- bpython: 0.24
- braceexpand: 0.1.7
- cachetools: 5.3.3
- cerberus: 1.3.5
- certifi: 2024.2.2
- cffi: 1.16.0
- cfgv: 3.4.0
- charset-normalizer: 3.3.2
- click: 8.1.7
- click-plugins: 1.1.1
- cligj: 0.7.2
- cloudpickle: 3.0.0
- cmake: 3.25.0
- colorama: 0.4.6
- colorcet: 3.1.0
- comm: 0.2.2
- contourpy: 1.2.1
- coreg-tools: 0.3.0
- cubinlinker-cu11: 0.3.0.post2
- cuda-python: 11.8.3
- cudf-cu11: 24.6.1
- cuml-cu11: 24.6.1
- cupy-cuda11x: 13.2.0
- curtsies: 0.4.2
- cwcwidth: 0.1.9
- cycler: 0.12.1
- dash: 2.17.0
- dash-core-components: 2.0.0
- dash-html-components: 2.0.0
- dash-table: 5.0.0
- dask: 2024.5.1
- dask-cuda: 24.6.0
- dask-cudf-cu11: 24.6.1
- dask-expr: 1.1.1
- datashader: 0.16.2
- debugpy: 1.8.1
- decorator: 5.1.1
- definiens-autocli: 4.5.3
- definiens-ia-algorithms: 3.39.1
- definiens-imagekit: 5.6.6
- definiens-parallel: 1.3.0
- defusedxml: 0.7.1
- deprecated: 1.2.14
- distlib: 0.3.8
- distributed: 2024.5.1
- distributed-ucxx-cu11: 0.38.0
- dm-tree: 0.1.8
- dnspython: 2.6.1
- docker-pycreds: 0.4.0
- docstring-parser: 0.16
- ec2-metadata: 2.13.0
- editor: 1.6.6
- entrypoints: 0.4
- eval-type-backport: 0.2.0
- executing: 2.0.1
- fancycompleter: 0.9.1
- fasteners: 0.19
- fastjsonschema: 2.19.1
- fastrlock: 0.8.2
- filelock: 3.14.0
- fiona: 1.9.6
- flask: 3.0.3
- flatbuffers: 24.3.25
- fonttools: 4.52.4
- fqdn: 1.5.1
- frozendict: 2.4.4
- frozenlist: 1.4.1
- fsspec: 2024.5.0
- fvcore: 0.1.5.post20221221
- gast: 0.6.0
- geopandas: 1.0.1
- geopolars: 0.1.0a4
- gitdb: 4.0.11
- gitpython: 3.1.43
- google-auth: 2.32.0
- google-auth-oauthlib: 1.0.0
- google-pasta: 0.2.0
- greenlet: 3.0.3
- grpcio: 1.64.1
- h11: 0.14.0
- h5py: 3.11.0
- holoviews: 1.19.0
- httpcore: 1.0.5
- httpx: 0.27.0
- huggingface-hub: 0.24.0
- hvplot: 0.10.0
- identify: 2.5.36
- idna: 3.7
- imagecodecs: 2024.1.1
- imageio: 2.34.1
- importlib-metadata: 7.1.0
- iniconfig: 2.0.0
- inquirer: 3.2.4
- ioda: 0.19.1
- ioda-readout-service-client: 0.0.5
- ioda-result-service-client: 0.6.1
- iopath: 0.1.10
- ipp: 2021.4.0
- ipykernel: 6.29.4
- ipython: 8.24.0
- ipytree: 0.2.2
- ipywidgets: 8.1.3
- isoduration: 20.11.0
- itables: 2.1.0
- itsdangerous: 2.2.0
- jax: 0.4.30
- jaxlib: 0.4.30
- jedi: 0.19.1
- jinja2: 3.1.4
- jmespath: 1.0.1
- joblib: 1.4.2
- json5: 0.9.25
- jsonargparse: 4.29.0
- jsonpointer: 2.4
- jsonschema: 4.22.0
- jsonschema-specifications: 2023.12.1
- jupyter: 1.0.0
- jupyter-bokeh: 4.0.5
- jupyter-client: 8.6.2
- jupyter-console: 6.6.3
- jupyter-core: 5.7.2
- jupyter-dash: 0.4.2
- jupyter-events: 0.10.0
- jupyter-lsp: 2.2.5
- jupyter-server: 2.14.1
- jupyter-server-terminals: 0.5.3
- jupyterlab: 4.2.1
- jupyterlab-pygments: 0.3.0
- jupyterlab-server: 2.27.2
- jupyterlab-vim: 4.1.3
- jupyterlab-widgets: 3.0.11
- kaleido: 0.2.1
- keras: 2.12.0
- kiwisolver: 1.4.5
- lazy-loader: 0.4
- leb128: 1.0.7
- libclang: 18.1.1
- libucx-cu11: 1.15.0.post1
- lightning: 2.4.0
- lightning-utilities: 0.11.4
- linkify-it-py: 2.0.3
- lit: 15.0.7
- litdata: 0.2.16
- llvmlite: 0.42.0
- locket: 1.0.0
- loguru: 0.7.2
- lxml: 4.9.4
- markdown: 3.6
- markdown-it-py: 3.0.0
- markupsafe: 2.1.5
- matplotlib: 3.9.0
- matplotlib-inline: 0.1.7
- mdit-py-plugins: 0.4.1
- mdurl: 0.1.2
- mistune: 3.0.2
- ml-dtypes: 0.4.0
- mpmath: 1.3.0
- msgpack: 1.0.8
- multidict: 6.0.5
- multipledispatch: 1.0.0
- namex: 0.0.8
- nbclient: 0.10.0
- nbconvert: 7.16.4
- nbformat: 5.10.4
- nest-asyncio: 1.6.0
- networkx: 3.3
- newrelic: 8.11.0
- nodeenv: 1.9.0
- notebook: 7.2.0
- notebook-cgebbe: 0.1.0
- notebook-shim: 0.2.4
- nrai-wrapper: 2.19.339839
- numba: 0.59.1
- numcodecs: 0.12.1
- numpy: 1.26.4
- nvidia-cublas-cu12: 12.1.3.1
- nvidia-cuda-cupti-cu12: 12.1.105
- nvidia-cuda-nvrtc-cu12: 12.1.105
- nvidia-cuda-runtime-cu12: 12.1.105
- nvidia-cudnn-cu12: 8.9.2.26
- nvidia-cufft-cu12: 11.0.2.54
- nvidia-curand-cu12: 10.3.2.106
- nvidia-cusolver-cu12: 11.4.5.107
- nvidia-cusparse-cu12: 12.1.0.106
- nvidia-dali-cuda120: 1.39.0
- nvidia-nccl-cu12: 2.20.5
- nvidia-nvimgcodec-cu12: 0.2.0.7
- nvidia-nvjitlink-cu12: 12.5.82
- nvidia-nvtx-cu12: 12.1.105
- nvtx: 0.2.10
- oauthlib: 3.2.2
- ome-types: 0.5.1.post1
- omegaconf: 2.3.0
- opencv-python-headless: 4.9.0.80
- openslide-python: 1.3.1
- opt-einsum: 3.3.0
- optree: 0.12.1
- outcome: 1.3.0.post0
- overrides: 7.7.0
- packaging: 24.0
- pandas: 2.1.4
- pandocfilters: 1.5.1
- panel: 1.4.4
- papermill: 2.6.0
- param: 2.1.0
- parso: 0.8.4
- partd: 1.4.2
- pdbpp: 0.10.3
- pexpect: 4.9.0
- pillow: 10.3.0
- pip: 24.1.2
- pivottablejs: 0.9.0
- platformdirs: 4.2.2
- plotly: 5.22.0
- pluggy: 1.5.0
- polars: 0.20.30
- portal-test-utils: 8.100.0
- portalocker: 2.10.1
- pre-commit: 3.7.1
- prometheus-client: 0.20.0
- prompt-toolkit: 3.0.45
- protobuf: 3.20.0
- psutil: 5.9.8
- psycopg2-binary: 2.9.9
- ptpython: 3.0.27
- ptxcompiler-cu11: 0.8.1.post1
- ptyprocess: 0.7.0
- pudb: 2024.1
- pure-eval: 0.2.2
- pyarrow: 16.1.0
- pyasn1: 0.6.0
- pyasn1-modules: 0.4.0
- pycparser: 2.22
- pyct: 0.5.0
- pydantic: 2.7.3
- pydantic-compat: 0.1.2
- pydantic-core: 2.18.4
- pydeck: 0.9.1
- pygments: 2.18.0
- pylibraft-cu11: 24.6.0
- pymongo: 4.8.0
- pynvml: 11.4.1
- pyogrio: 0.9.0
- pyotp: 2.9.0
- pyparsing: 3.1.2
- pypeln: 0.4.9
- pyportal: 2.33.2
- pyproj: 3.6.1
- pyrepl: 0.9.0
- pysocks: 1.7.1
- pytest: 8.2.2
- python-dateutil: 2.9.0.post0
- python-dotenv: 1.0.1
- python-gitlab: 3.15.0
- python-json-logger: 2.0.7
- python-magic: 0.4.27
- python-on-whales: 0.71.0
- pytorch-lightning: 2.3.3
- pytz: 2024.1
- pyviz-comms: 3.0.2
- pyxdg: 0.28
- pyyaml: 6.0.1
- pyzmq: 26.0.3
- qtconsole: 5.5.2
- qtpy: 2.4.1
- raft-dask-cu11: 24.6.0
- rai: 2.4.1
- rapids-dask-dependency: 24.6.0
- rasterio: 1.3.10
- ray: 2.31.0
- rdi: 2.19.339839
- readchar: 4.1.0
- referencing: 0.35.1
- requests: 2.32.2
- requests-oauthlib: 2.0.0
- requests-toolbelt: 1.0.0
- retrying: 1.3.4
- rfc3339-validator: 0.1.4
- rfc3986-validator: 0.1.1
- rich: 13.7.1
- rmm-cu11: 24.6.0
- rpds-py: 0.18.1
- rsa: 4.9
- rtree: 1.2.0
- ruamel.yaml: 0.18.6
- ruamel.yaml.clib: 0.2.8
- ruff: 0.5.1
- runs: 1.2.2
- s3cmd: 2.4.0
- s3fs: 2024.5.0
- s3transfer: 0.10.1
- safetensors: 0.4.3
- scikit-image: 0.23.2
- scikit-learn: 1.5.1
- scipy: 1.13.1
- seaborn: 0.13.2
- segment-anything: 1.0
- selenium: 4.22.0
- semantic-segmentation: 0.1.0
- send2trash: 1.8.3
- sentry-sdk: 2.6.0
- setproctitle: 1.3.3
- setuptools: 63.4.3
- shapely: 2.0.4
- shellingham: 1.5.4
- six: 1.16.0
- smmap: 5.0.1
- sniffio: 1.3.1
- snuggs: 1.4.7
- sortedcontainers: 2.4.0
- soupsieve: 2.5
- spatialpandas: 0.4.10
- sqlalchemy: 2.0.31
- stack-data: 0.6.3
- stopit: 1.1.2
- streamlit: 1.36.0
- submitit: 1.5.1
- sympy: 1.12
- tabulate: 0.9.0
- tblib: 3.0.0
- tenacity: 8.3.0
- tensorboard: 2.12.3
- tensorboard-data-server: 0.7.2
- tensorflow: 2.12.1
- tensorflow-estimator: 2.12.0
- tensorflow-io: 0.37.1
- tensorflow-io-gcs-filesystem: 0.37.1
- termcolor: 2.4.0
- terminado: 0.18.1
- threadpoolctl: 3.5.0
- tifffile: 2024.5.22
- timm: 1.0.7
- tinycss2: 1.3.0
- toml: 0.10.2
- tomli: 2.0.1
- toolz: 0.12.1
- torch: 2.3.1
- torchdata: 0.7.1
- torchmetrics: 1.0.3
- torchsummary: 1.5.1
- torchvision: 0.18.1
- tornado: 6.4
- tqdm: 4.66.4
- traitlets: 5.14.3
- treelite: 4.1.2
- trio: 0.25.1
- trio-websocket: 0.11.1
- triton: 2.3.1
- typer: 0.12.3
- types-python-dateutil: 2.9.0.20240316
- typing-extensions: 4.12.2
- typing-utils: 0.1.0
- tzdata: 2024.1
- uc-micro-py: 1.0.3
- ucx-py-cu11: 0.38.0
- ucxx-cu11: 0.38.0
- uri-template: 1.3.0
- urllib3: 2.0.7
- urwid: 2.6.12
- urwid-readline: 0.14
- ventanamripy: 0.3
- virtualenv: 20.26.2
- wandb: 0.17.2
- watchdog: 4.0.1
- wcwidth: 0.2.13
- webcolors: 1.13
- webdataset: 0.2.93
- webencodings: 0.5.1
- websocket-client: 1.8.0
- werkzeug: 3.0.3
- wheel: 0.43.0
- widgetsnbextension: 4.0.11
- wmctrl: 0.5
- wrapt: 1.14.1
- wsproto: 1.2.0
- xarray: 2024.6.0
- xformers: 0.0.27
- xmod: 1.8.1
- xsdata: 24.3.1
- xyzservices: 2024.6.0
- yacs: 0.1.8
- yarl: 1.9.4
- zarr: 2.18.2
- zict: 3.0.0
- zipp: 3.19.1
- System:
- OS: Linux
- architecture:
- 64bit
- ELF
- processor: x86_64
- python: 3.11.9
- release: 5.15.0-1066-aws
- version: #72~20.04.1-Ubuntu SMP Thu Jul 18 10:41:27 UTC 2024
More info