Labels: bug, checkpointing, lightningcli, ver: 2.5.x
Description
Bug description
When I try to resume training from a checkpoint with LightningCLI, parsing fails on the `_class_path` key.
What version are you seeing the problem on?
v2.5
Reproduced in studio
No response
How to reproduce the bug
First, run `python main.py --config config.yaml` until the checkpoint for the first epoch is saved. Then try to resume with `python main.py --config config.yaml --ckpt_path ...`, passing the path of the checkpoint saved in the first step.
```yaml
# config.yaml
data:
  class_path: MNISTDataModule
model:
  class_path: LitAutoEncoder
```
```python
# main.py
import os
from torch import optim, nn, utils
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor
import lightning as L
from lightning.pytorch import cli
from torch.utils.data import random_split, DataLoader
import torch
from torchvision import transforms


class LitAutoEncoder(L.LightningModule):
    """From https://lightning.ai/docs/pytorch/stable/starter/introduction.html"""

    def __init__(self, dim: int = 64):
        super().__init__()
        self.save_hyperparameters()
        self.encoder = nn.Sequential(
            nn.Linear(28 * 28, dim), nn.ReLU(), nn.Linear(dim, 3)
        )
        self.decoder = nn.Sequential(
            nn.Linear(3, dim), nn.ReLU(), nn.Linear(dim, 28 * 28)
        )

    def training_step(self, batch, batch_idx):
        x, _ = batch
        x = x.view(x.size(0), -1)
        z = self.encoder(x)
        x_hat = self.decoder(z)
        loss = nn.functional.mse_loss(x_hat, x)
        self.log("train_loss", loss)
        return loss

    def configure_optimizers(self):
        return optim.Adam(self.parameters(), lr=1e-3)


class MNISTDataModule(L.LightningDataModule):
    """From https://lightning.ai/docs/pytorch/stable/data/datamodule.html"""

    def __init__(self, data_dir: str = "./"):
        super().__init__()
        self.save_hyperparameters()
        self.data_dir = data_dir
        self.transform = transforms.Compose(
            [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
        )

    def prepare_data(self):
        # download
        MNIST(self.data_dir, train=True, download=True)
        MNIST(self.data_dir, train=False, download=True)

    def setup(self, stage: str):
        # Assign train/val datasets for use in dataloaders
        if stage == "fit":
            mnist_full = MNIST(self.data_dir, train=True, transform=self.transform)
            self.mnist_train, self.mnist_val = random_split(
                mnist_full, [55000, 5000], generator=torch.Generator().manual_seed(42)
            )
        # Assign test dataset for use in dataloader(s)
        if stage == "test":
            self.mnist_test = MNIST(
                self.data_dir, train=False, transform=self.transform
            )
        if stage == "predict":
            self.mnist_predict = MNIST(
                self.data_dir, train=False, transform=self.transform
            )

    def train_dataloader(self):
        return DataLoader(self.mnist_train, batch_size=32)

    def val_dataloader(self):
        return DataLoader(self.mnist_val, batch_size=32)

    def test_dataloader(self):
        return DataLoader(self.mnist_test, batch_size=32)

    def predict_dataloader(self):
        return DataLoader(self.mnist_predict, batch_size=32)


cli.LightningCLI()
```
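To see what the first run actually stored before resuming, the checkpoint can be opened directly. Lightning keeps the output of `save_hyperparameters()` under the `"hyper_parameters"` key of the checkpoint dict. The sketch below assumes a trusted checkpoint file on local disk; the helper names `extract_hparams` and `load_hparams` are hypothetical and are not part of Lightning.

```python
from typing import Any


def extract_hparams(ckpt: dict[str, Any]) -> dict[str, Any]:
    """Return a copy of the hyperparameters stored in a Lightning checkpoint dict.

    Hypothetical helper for inspection only; not a Lightning API.
    """
    # Lightning saves the save_hyperparameters() output under this key.
    return dict(ckpt.get("hyper_parameters", {}))


def load_hparams(path: str) -> dict[str, Any]:
    """Load a checkpoint file on CPU and return its stored hyperparameters."""
    import torch  # imported here so extract_hparams() works without torch installed

    # weights_only=False unpickles arbitrary objects: only use it on trusted files.
    ckpt = torch.load(path, map_location="cpu", weights_only=False)
    return extract_hparams(ckpt)
```

Calling `load_hparams(...)` with the checkpoint path from the first run should show whether `_class_path` is present. Judging from the error below, the stored dict appears to contain `'_class_path': '__main__.LitAutoEncoder'` alongside `'dim': 64`.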
Error messages and logs
```
error: Parser key "model":
  Does not validate against any of the Union subtypes
  Subtypes: [<class 'NoneType'>, <class 'lightning.pytorch.core.module.LightningModule'>]
  Errors:
    - Expected a <class 'NoneType'>
    - Problem with given class_path 'lightning.LightningModule':
        Validation failed: Key '_class_path' is not expected
  Given value type: <class 'dict'>
  Given value: {'_class_path': '__main__.LitAutoEncoder', 'dim': 64}
Parsing of ckpt_path hyperparameters failed!
```
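One reading of this message, which is an assumption drawn from the error text rather than from Lightning's source: the checkpoint hyperparameters reach the parser as a flat dict in which `_class_path` is an ordinary key. jsonargparse's subclass syntax uses a `class_path` key plus an `init_args` mapping, so with no `class_path` present the value is checked against the declared base `lightning.LightningModule`, which rejects `_class_path` as unexpected. The hypothetical helper `to_subclass_spec` below only illustrates the difference between the two shapes; it is not part of Lightning or jsonargparse.

```python
from typing import Any


def to_subclass_spec(hparams: dict[str, Any]) -> dict[str, Any]:
    """Rewrite a flat {'_class_path': ..., **kwargs} dict into the
    {'class_path': ..., 'init_args': {...}} shape used by jsonargparse.

    Hypothetical illustration only; not a Lightning or jsonargparse API.
    """
    hparams = dict(hparams)  # copy so the caller's dict is left untouched
    class_path = hparams.pop("_class_path", None)
    if class_path is None:
        return hparams  # no _class_path key: nothing to rewrite
    # Everything else in the flat dict becomes a constructor argument.
    return {"class_path": class_path, "init_args": hparams}


# The value from the error message, before and after the rewrite.
flat = {"_class_path": "__main__.LitAutoEncoder", "dim": 64}
spec = to_subclass_spec(flat)
```

Here `spec` becomes `{'class_path': '__main__.LitAutoEncoder', 'init_args': {'dim': 64}}`, the nested form a subclass-typed parser key normally accepts.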
Environment
Current environment
- CUDA:
- GPU:
- NVIDIA GeForce RTX 5090
- available: True
- version: 12.8
- Lightning:
- lightning: 2.5.5
- lightning-utilities: 0.15.2
- pytorch-lightning: 2.5.5
- torch: 2.8.0+cu128
- torchmetrics: 1.8.2
- torchvision: 0.23.0+cu128
- Packages:
- aiohappyeyeballs: 2.6.1
- aiohttp: 3.12.15
- aiosignal: 1.4.0
- alembic: 1.16.5
- annotated-types: 0.7.0
- antlr4-python3-runtime: 4.9.3
- anyio: 4.11.0
- async-timeout: 5.0.1
- attrs: 25.3.0
- authlib: 1.6.4
- bitsandbytes: 0.47.0
- blinker: 1.9.0
- cachetools: 5.5.2
- certifi: 2025.8.3
- cffi: 2.0.0
- charset-normalizer: 3.4.3
- click: 8.3.0
- cloudpickle: 3.1.1
- colorama: 0.4.6
- contourpy: 1.3.2
- cryptography: 45.0.7
- cycler: 0.12.1
- cyclopts: 3.24.0
- databricks-sdk: 0.67.0
- datasets: 4.1.1
- dill: 0.4.0
- dnspython: 2.8.0
- docker: 7.1.0
- docstring-parser: 0.17.0
- docutils: 0.22.2
- email-validator: 2.3.0
- exceptiongroup: 1.3.0
- fastapi: 0.117.1
- fastmcp: 2.12.3
- filelock: 3.19.1
- flask: 3.1.2
- fonttools: 4.60.0
- frozenlist: 1.7.0
- fsspec: 2025.9.0
- gitdb: 4.0.12
- gitpython: 3.1.45
- google-auth: 2.40.3
- graphene: 3.4.3
- graphql-core: 3.2.6
- graphql-relay: 3.2.0
- greenlet: 3.2.4
- h11: 0.16.0
- httpcore: 1.0.9
- httpx: 0.28.1
- httpx-sse: 0.4.1
- huggingface-hub: 0.35.1
- hydra-core: 1.3.2
- idna: 3.10
- importlib-metadata: 8.7.0
- importlib-resources: 6.5.2
- isodate: 0.7.2
- itsdangerous: 2.2.0
- jinja2: 3.1.6
- joblib: 1.5.2
- jsonargparse: 4.41.0
- jsonnet: 0.21.0
- jsonschema: 4.25.1
- jsonschema-path: 0.3.4
- jsonschema-specifications: 2025.9.1
- kiwisolver: 1.4.9
- lazy-object-proxy: 1.12.0
- lightning: 2.5.5
- lightning-utilities: 0.15.2
- mako: 1.3.10
- markdown-it-py: 4.0.0
- markupsafe: 3.0.2
- matplotlib: 3.10.6
- mcp: 1.15.0
- mdurl: 0.1.2
- mlflow: 3.4.0
- mlflow-skinny: 3.4.0
- mlflow-tracing: 3.4.0
- more-itertools: 10.8.0
- mpmath: 1.3.0
- multidict: 6.6.4
- multiprocess: 0.70.16
- networkx: 3.4.2
- numpy: 2.2.6
- omegaconf: 2.3.0
- openapi-core: 0.19.5
- openapi-pydantic: 0.5.1
- openapi-schema-validator: 0.6.3
- openapi-spec-validator: 0.7.2
- opentelemetry-api: 1.37.0
- opentelemetry-proto: 1.37.0
- opentelemetry-sdk: 1.37.0
- opentelemetry-semantic-conventions: 0.58b0
- packaging: 25.0
- pandas: 2.3.2
- parse: 1.20.2
- pathable: 0.4.4
- pillow: 11.3.0
- pip: 23.0.1
- propcache: 0.3.2
- protobuf: 6.32.1
- psutil: 7.1.0
- pyarrow: 21.0.0
- pyasn1: 0.6.1
- pyasn1-modules: 0.4.2
- pycparser: 2.23
- pydantic: 2.11.9
- pydantic-core: 2.33.2
- pydantic-settings: 2.11.0
- pygments: 2.19.2
- pyparsing: 3.2.5
- pyperclip: 1.10.0
- python-dateutil: 2.9.0.post0
- python-dotenv: 1.1.1
- python-multipart: 0.0.20
- pytorch-lightning: 2.5.5
- pytz: 2025.2
- pywin32: 311
- pyyaml: 6.0.2
- referencing: 0.36.2
- regex: 2025.9.18
- requests: 2.32.5
- rfc3339-validator: 0.1.4
- rich: 14.1.0
- rich-rst: 1.3.1
- rpds-py: 0.27.1
- rsa: 4.9.1
- safetensors: 0.6.2
- scikit-learn: 1.7.2
- scipy: 1.15.3
- setuptools: 65.5.0
- six: 1.17.0
- smmap: 5.0.2
- sniffio: 1.3.1
- sqlalchemy: 2.0.43
- sqlparse: 0.5.3
- sse-starlette: 3.0.2
- starlette: 0.48.0
- sympy: 1.14.0
- tensorboardx: 2.6.4
- threadpoolctl: 3.6.0
- tokenizers: 0.22.1
- tomli: 2.2.1
- torch: 2.8.0+cu128
- torchmetrics: 1.8.2
- torchvision: 0.23.0+cu128
- tqdm: 4.67.1
- transformers: 4.56.2
- triton-windows: 3.4.0.post20
- typeshed-client: 2.8.2
- typing-extensions: 4.15.0
- typing-inspection: 0.4.1
- tzdata: 2025.2
- urllib3: 2.5.0
- uvicorn: 0.37.0
- waitress: 3.0.2
- werkzeug: 3.1.1
- xxhash: 3.5.0
- yarl: 1.20.1
- zipp: 3.23.0
- System:
- OS: Windows
- architecture:
- 64bit
- WindowsPE
- processor: AMD64 Family 25 Model 97 Stepping 2, AuthenticAMD
- python: 3.10.11
- release: 10
- version: 10.0.19045
More info
No response