### Bug description
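I implemented a custom PyTorch Lightning `Callback` and passed it to the `Trainer` via `LightningCLI`. However, the callback is not recognized as an instance of `Callback` inside `_validate_callbacks_list()`, which raises `ValueError: Expected a parent`. The callback subclasses `pytorch_lightning.callbacks.Callback` (old import style), while `LightningCLI` comes from `lightning.pytorch` (new import style). The two packages each define their own `Callback` class, so an instance of one is not an instance of the other.

After debugging, I found that the issue stems from `_check_mixed_imports()` in `lightning/pytorch/utilities/model_helpers.py`, which is meant to detect mixed `lightning.pytorch` and `pytorch_lightning` imports but does not catch this case.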
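To illustrate the failure mode without installing Lightning, here is a minimal, self-contained sketch. Assumption: `OldCallback` and `NewCallback` are plain stand-ins for `pytorch_lightning.callbacks.Callback` and `lightning.pytorch.callbacks.Callback`, created with `type()` and given a `__module__` that mimics each import path. It shows that two classes with the same name are still distinct objects for `isinstance`, and that the user's subclass reports its own module rather than either package.

```python
# Stand-ins only: NOT the real Lightning classes. Each mimics the import path
# used in the report by setting __module__ on a freshly created class object.
OldCallback = type("Callback", (), {"__module__": "pytorch_lightning.callbacks"})
NewCallback = type("Callback", (), {"__module__": "lightning.pytorch.callbacks"})


class MyCallback(OldCallback):
    """User callback built on the old-style base class, as in the report."""


cb = MyCallback()

# Same class name, different class objects: only the old-style check passes.
print(isinstance(cb, OldCallback))  # True
print(isinstance(cb, NewCallback))  # False

# type(cb) is defined in the user's own module (e.g. "__main__"), so a
# prefix check on type(cb).__module__ alone sees neither package name.
# The old-style package only shows up further up the MRO.
print(type(cb).__module__)
print([klass.__module__ for klass in type(cb).__mro__])
```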
The Lightning Studio template for this issue can be found here: https://lightning.ai/acme-ai/studios/lightningcli-fails-to-recognize-custom-callback-due-to-mixed-import-styles~01hpygektnhxpn0tgm1jk6z485
### What version are you seeing the problem on?
v2.5
### How to reproduce the bug
```python
import torch
from lightning.pytorch.demos.boring_classes import BoringModel as BoringLightningModule
from lightning.pytorch.demos.dummy_data import DummyDataModule as DummyDatamodule
from lightning.pytorch.cli import LightningCLI  # New import syntax
from pytorch_lightning.callbacks import Callback  # Old import syntax


class MyCallback(Callback):
    def __init__(self):
        super().__init__()
        # More code...


class MyCLI(LightningCLI):
    def add_arguments_to_parser(self, parser):
        super().add_arguments_to_parser(parser)
        parser.add_optimizer_args(torch.optim.SGD)


def cli_main():
    return MyCLI(
        BoringLightningModule,
        DummyDatamodule,
        # Pass an instance: the Trainer validates callback objects, not classes.
        trainer_defaults={"callbacks": [MyCallback()]},
    )


if __name__ == "__main__":
    torch.set_float32_matmul_precision("medium")
    cli = cli_main()
```
### Error messages and logs
```
  File "/opt/conda/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/callback_connector.py", line 228, in _validate_callbacks_list
    stateful_callbacks = [cb for cb in callbacks if is_overridden("state_dict", instance=cb)]
                         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/lightning/pytorch/utilities/model_helpers.py", line 42, in is_overridden
    raise ValueError("Expected a parent")
ValueError: Expected a parent
```
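If, as this report suggests, the mixed-import check only looks at the module of `type(instance)`, it will miss a callback defined in a user script, because that class lives in `__main__`. A minimal sketch of one possible alternative, under that assumption: walk the instance's MRO and raise a descriptive `TypeError` when any base class comes from the other import style. `import_style`, `check_mixed_imports_mro`, and the `trainer_style` parameter are hypothetical names for illustration, not Lightning's API, and `OldCallback` is a stand-in class.

```python
from typing import Optional


def import_style(module: str) -> Optional[str]:
    """Map a module path to 'pytorch_lightning', 'lightning.pytorch', or None."""
    for style in ("pytorch_lightning", "lightning.pytorch"):
        if module == style or module.startswith(style + "."):
            return style
    return None


def check_mixed_imports_mro(instance: object, trainer_style: str) -> None:
    """Hypothetical MRO-aware mixed-import check (not Lightning's code).

    Raises TypeError if any class in type(instance).__mro__ was defined under
    the import style other than `trainer_style`, the style the Trainer uses.
    """
    for klass in type(instance).__mro__:
        style = import_style(klass.__module__)
        if style is not None and style != trainer_style:
            raise TypeError(
                f"{type(instance).__qualname__} inherits from a `{style}` class, "
                f"but the Trainer uses `{trainer_style}`. Use a single import style."
            )


# Stand-in for pytorch_lightning.callbacks.Callback (not the real class).
OldCallback = type("Callback", (), {"__module__": "pytorch_lightning.callbacks"})


class MyCallback(OldCallback):
    """Defined in the user's script, so its own __module__ is not a Lightning package."""


try:
    check_mixed_imports_mro(MyCallback(), trainer_style="lightning.pytorch")
except TypeError as err:
    print(f"TypeError: {err}")
```

Walking the MRO catches the situation in this report because the old-style base class appears there even when the subclass itself is defined outside either package.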
### Environment
<details>
<summary>Current environment</summary>
#- PyTorch Lightning Version (e.g., 2.5.0):
#- PyTorch Version (e.g., 2.5):
#- Python version (e.g., 3.12):
#- OS (e.g., Linux):
#- CUDA/cuDNN version:
#- GPU models and configuration:
#- How you installed Lightning(`conda`, `pip`, source):
</details>
### More info
_No response_