Skip to content
2 changes: 2 additions & 0 deletions src/lightning/pytorch/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed edgecase when `max_trials` is reached in `Tuner.scale_batch_size` ([#21187](https://github.com/Lightning-AI/pytorch-lightning/pull/21187))


- Fixed case where `LightningCLI` could not be initialized with `trainer_default` containing callbacks ([#21192](https://github.com/Lightning-AI/pytorch-lightning/pull/21192))

---

## [2.5.5] - 2025-09-05
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -240,7 +240,7 @@ def _reorder_callbacks(callbacks: list[Callback]) -> list[Callback]:


def _validate_callbacks_list(callbacks: list[Callback]) -> None:
stateful_callbacks = [cb for cb in callbacks if is_overridden("state_dict", instance=cb)]
stateful_callbacks = [cb for cb in callbacks if is_overridden("state_dict", instance=cb, parent=Callback)]
seen_callbacks = set()
for callback in stateful_callbacks:
if callback.state_key in seen_callbacks:
Expand Down
26 changes: 26 additions & 0 deletions tests/tests_pytorch/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -1917,3 +1917,29 @@ def __init__(self, main_param: int = 1):
cli = LightningCLI(MainModule, run=False, parser_kwargs={"parser_mode": "jsonnet"})

assert cli.config["model"]["main_param"] == 2


def test_lightning_cli_callback_trainer_default(cleandir):
"""Check that callbacks passed as trainer_defaults are properly instantiated."""
with mock.patch("sys.argv", ["any.py"]):
cli = LightningCLI(
BoringModel,
BoringDataModule,
trainer_defaults={
"logger": {
"class_path": "lightning.pytorch.loggers.TensorBoardLogger",
"init_args": {
"save_dir": ".",
"name": "demo",
},
},
"callbacks": {
"class_path": "lightning.pytorch.callbacks.ModelCheckpoint",
"init_args": {
"monitor": "val_loss",
},
},
},
run=False,
)
assert any(isinstance(c, ModelCheckpoint) for c in cli.trainer.callbacks)
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
)
from lightning.pytorch.callbacks.batch_size_finder import BatchSizeFinder
from lightning.pytorch.demos.boring_classes import BoringModel
from lightning.pytorch.trainer.connectors.callback_connector import _CallbackConnector
from lightning.pytorch.trainer.connectors.callback_connector import _CallbackConnector, _validate_callbacks_list


@patch("lightning.pytorch.trainer.connectors.callback_connector._RICH_AVAILABLE", False)
Expand Down Expand Up @@ -323,3 +323,74 @@ def state_dict(self):

with pytest.raises(RuntimeError, match="Found more than one stateful callback of type `MockCallback`"):
Trainer(callbacks=[MockCallback(), MockCallback()])


def test_validate_callbacks_list_function():
"""Test the _validate_callbacks_list function directly with various scenarios."""

# Test with non-stateful callbacks
callback1 = Callback()
callback2 = Callback()
_validate_callbacks_list([callback1, callback2])

# Test with single stateful callback
class StatefulCallback(Callback):
def state_dict(self):
return {"state": 1}

stateful_cb = StatefulCallback()
_validate_callbacks_list([stateful_cb])

# Test with multiple stateful callbacks with unique state keys
class StatefulCallback1(Callback):
@property
def state_key(self):
return "unique_key_1"

def state_dict(self):
return {"state": 1}

class StatefulCallback2(Callback):
@property
def state_key(self):
return "unique_key_2"

def state_dict(self):
return {"state": 2}

stateful_cb1 = StatefulCallback1()
stateful_cb2 = StatefulCallback2()
_validate_callbacks_list([stateful_cb1, stateful_cb2])

# Test with multiple stateful callbacks with same state key
class ConflictingCallback(Callback):
@property
def state_key(self):
return "same_key"

def state_dict(self):
return {"state": 1}

conflicting_cb1 = ConflictingCallback()
conflicting_cb2 = ConflictingCallback()

with pytest.raises(RuntimeError, match="Found more than one stateful callback of type `ConflictingCallback`"):
_validate_callbacks_list([conflicting_cb1, conflicting_cb2])

# Test with mix of stateful and non-stateful callbacks where stateful ones conflict
with pytest.raises(RuntimeError, match="Found more than one stateful callback of type `ConflictingCallback`"):
_validate_callbacks_list([callback1, conflicting_cb1, callback2, conflicting_cb2])

# Test with different types of stateful callbacks that happen to have same state key
class AnotherConflictingCallback(Callback):
@property
def state_key(self):
return "same_key" # Same key as ConflictingCallback

def state_dict(self):
return {"state": 3}

another_conflicting_cb = AnotherConflictingCallback()

with pytest.raises(RuntimeError, match="Found more than one stateful callback"):
_validate_callbacks_list([conflicting_cb1, another_conflicting_cb])
Loading