Description
Bug description
I implemented a custom PyTorch Lightning `Callback` and passed it to the Trainer via `LightningCLI`. However, the run failed because my callback was not recognized as an instance of `Callback` inside `_validate_callbacks_list()`.
After debugging, I found that the issue stems from `_check_mixed_imports()` in `lightning/pytorch/utilities/model_helpers.py`, which does not detect mixed imports of `lightning.pytorch` and `pytorch_lightning`.
What version are you seeing the problem on?
v2.5
How to reproduce the bug
```python
import torch
from lightning.pytorch.demos.boring_classes import BoringModel as BoringLightningModule
from lightning.pytorch.demos.dummy_data import DummyDataModule as DummyDatamodule
from lightning.pytorch.cli import LightningCLI  # New import syntax
from pytorch_lightning.callbacks import Callback  # Old import syntax


class MyCallback(Callback):
    def __init__(self):
        super().__init__()
        # More code...


class MyCLI(LightningCLI):
    def add_arguments_to_parser(self, parser):
        super().add_arguments_to_parser(parser)
        parser.add_optimizer_args(torch.optim.SGD)


def cli_main():
    return MyCLI(
        BoringLightningModule,
        DummyDatamodule,
        trainer_defaults={"callbacks": [MyCallback()]},  # pass an instance, not the class
    )


if __name__ == "__main__":
    torch.set_float32_matmul_precision("medium")
    cli = cli_main()
```

Error messages and logs
```
  File "/opt/conda/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/callback_connector.py", line 228, in _validate_callbacks_list
    stateful_callbacks = [cb for cb in callbacks if is_overridden("state_dict", instance=cb)]
                         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/lightning/pytorch/utilities/model_helpers.py", line 42, in is_overridden
    raise ValueError("Expected a parent")
ValueError: Expected a parent
```

This suggests that the callback is not recognized as a valid subclass of `Callback` inside `_validate_callbacks_list()`: `is_overridden()` cannot determine a parent class for it and raises `ValueError`.
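Why `isinstance` fails here can be illustrated without Lightning installed: `pytorch_lightning` and `lightning.pytorch` each ship their own `Callback` class, so a subclass of one is not an instance of the other. The sketch below uses two hypothetical stand-in classes and a hypothetical `find_parent()` helper that models only the `Callback` branch of the parent lookup in `is_overridden()`; it is a simplification, not Lightning's code.

```python
class NewCallback:
    """Hypothetical stand-in for lightning.pytorch.Callback."""


class OldCallback:
    """Hypothetical stand-in for pytorch_lightning.Callback: a separate class object."""


class MyCallback(OldCallback):
    """A user callback built on the old import path, as in the reproduction script."""


def find_parent(instance: object) -> type:
    """Simplified model of the Callback branch of the parent lookup in is_overridden()."""
    if isinstance(instance, NewCallback):
        return NewCallback
    raise ValueError("Expected a parent")


# isinstance() compares class objects, not class names, so the two Callback classes are unrelated
print(isinstance(MyCallback(), OldCallback))  # True
print(isinstance(MyCallback(), NewCallback))  # False

try:
    find_parent(MyCallback())
except ValueError as err:
    print(f"ValueError: {err}")  # ValueError: Expected a parent
```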
Environment
Current environment
#- PyTorch Lightning Version (e.g., 2.5.0): 2.5.0
#- PyTorch Version (e.g., 2.5): 2.6
#- Python version (e.g., 3.12): 3.11
#- OS (e.g., Linux): Linux Ubuntu
#- CUDA/cuDNN version: 12.6
#- How you installed Lightning(`conda`, `pip`, source): `pip`
More info
I traced the issue to _check_mixed_imports() in lightning/pytorch/utilities/model_helpers.py, which should detect mixed imports (e.g., pytorch_lightning vs. lightning.pytorch):
```python
def _check_mixed_imports(instance: object) -> None:
    old, new = "pytorch_" + "lightning", "lightning." + "pytorch"
    klass = type(instance)
    module = klass.__module__
    if module.startswith(old) and __name__.startswith(new):
        pass
    elif module.startswith(new) and __name__.startswith(old):
        old, new = new, old
    else:
        return
    raise TypeError(
        f"You passed a `{old}` object ({klass.__qualname__}) to a `{new}` Trainer. "
        "Please switch to a single import style."
    )
```

✅ Expected Behavior: This function should catch mismatched imports and raise a `TypeError`.
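❌ Observed Behavior: The function does not detect the mixed-import case. It inspects only `type(instance).__module__`, which for a callback subclass defined in my own script is `__main__` rather than a `pytorch_lightning` module, so it returns without raising. `is_overridden()` then fails with the less informative `ValueError: Expected a parent` instead of the intended `TypeError`.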
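The gap can be reproduced without Lightning installed. This sketch copies the logic of `_check_mixed_imports()`, passes the module-level `__name__` in as a hypothetical `trainer_module` argument, and uses a hypothetical stand-in class whose `__module__` is set by hand to mimic `pytorch_lightning`. A direct instance of the stand-in base class is caught, but a user subclass is not.

```python
OLD, NEW = "pytorch_lightning", "lightning.pytorch"
# Stands in for the real function's own __name__ (a lightning.pytorch module)
TRAINER_MODULE = "lightning.pytorch.utilities.model_helpers"


class OldCallback:
    """Hypothetical stand-in for pytorch_lightning.Callback."""


OldCallback.__module__ = "pytorch_lightning.callbacks.callback"


class MyCallback(OldCallback):
    """User subclass; its __module__ is '__main__' when defined in a script run directly."""


MyCallback.__module__ = "__main__"  # set explicitly so the sketch behaves the same however it is run


def check_mixed_imports(instance: object, trainer_module: str) -> None:
    """Same logic as _check_mixed_imports(), with __name__ replaced by `trainer_module`."""
    old, new = OLD, NEW
    klass = type(instance)
    module = klass.__module__  # only the instance's own class is inspected
    if module.startswith(old) and trainer_module.startswith(new):
        pass
    elif module.startswith(new) and trainer_module.startswith(old):
        old, new = new, old
    else:
        return
    raise TypeError(
        f"You passed a `{old}` object ({klass.__qualname__}) to a `{new}` Trainer. "
        "Please switch to a single import style."
    )


# A direct instance of the old base class is caught...
try:
    check_mixed_imports(OldCallback(), TRAINER_MODULE)
except TypeError as err:
    print(err)

# ...but a subclass defined in user code is not, because its own __module__ is "__main__"
print(check_mixed_imports(MyCallback(), TRAINER_MODULE))  # None
```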
🔧 Proposed Fix
Modify `_check_mixed_imports()` so that it reliably catches mixed imports and raises the `TypeError` in both directions:
```python
def _check_mixed_imports(instance: object) -> None:
    old, new = "pytorch_" + "lightning", "lightning." + "pytorch"
    klass = type(instance)
    module = klass.__module__
    if (module.startswith(old) and __name__.startswith(new)) or (
        module.startswith(new) and __name__.startswith(old)
    ):
        # Name the packages in the right order for whichever direction was detected
        if module.startswith(new):
            old, new = new, old
        raise TypeError(
            f"You passed a `{old}` object ({klass.__qualname__}) to a `{new}` Trainer. "
            "Please switch to a single import style."
        )
```
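As written, this proposal still reads only `type(instance).__module__`, so a subclass defined in a user script (module `__main__`) would still pass. Below is a minimal sketch of one way to close that gap, assuming it is acceptable to walk the instance's MRO and report the first class that comes from the other package. The Lightning classes are hypothetical stand-ins with hand-set `__module__` values, and the `__name__` the real function reads is passed in as a hypothetical `trainer_module` argument so the sketch runs without Lightning installed.

```python
OLD, NEW = "pytorch_lightning", "lightning.pytorch"


class OldCallback:
    """Hypothetical stand-in for pytorch_lightning.Callback."""


class NewCallback:
    """Hypothetical stand-in for lightning.pytorch.Callback."""


class MyOldCallback(OldCallback):
    """User subclass built on the old import path."""


class MyNewCallback(NewCallback):
    """User subclass built on the new import path."""


OldCallback.__module__ = "pytorch_lightning.callbacks.callback"
NewCallback.__module__ = "lightning.pytorch.callbacks.callback"
MyOldCallback.__module__ = MyNewCallback.__module__ = "__main__"


def check_mixed_imports(instance: object, trainer_module: str) -> None:
    """Raise TypeError if any class in type(instance).__mro__ comes from the other package."""
    klass = type(instance)
    # Walk the class and all of its bases, so a user subclass defined in __main__
    # is traced back to the package its Lightning base class was imported from.
    for base in klass.__mro__:
        module = base.__module__
        if module.startswith(OLD) and trainer_module.startswith(NEW):
            old, new = OLD, NEW
        elif module.startswith(NEW) and trainer_module.startswith(OLD):
            old, new = NEW, OLD
        else:
            continue
        raise TypeError(
            f"You passed a `{old}` object ({klass.__qualname__}) to a `{new}` Trainer. "
            "Please switch to a single import style."
        )


new_trainer = "lightning.pytorch.utilities.model_helpers"
old_trainer = "pytorch_lightning.utilities.model_helpers"

check_mixed_imports(MyNewCallback(), new_trainer)  # consistent imports: no error
try:
    check_mixed_imports(MyOldCallback(), new_trainer)
except TypeError as err:
    print(err)  # reported even though MyOldCallback itself lives in __main__
```

Because the message uses `klass.__qualname__`, it names the user's class rather than the Lightning base class that triggered the check.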