Skip to content
4 changes: 4 additions & 0 deletions src/lightning/pytorch/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,12 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed edgecase when `max_trials` is reached in `Tuner.scale_batch_size` ([#21187](https://github.com/Lightning-AI/pytorch-lightning/pull/21187))


- Fixed case where `LightningCLI` could not be initialized with `trainer_default` containing callbacks ([#21192](https://github.com/Lightning-AI/pytorch-lightning/pull/21192))


- Fixed missing reset when `ModelPruning` is applied with lottery ticket hypothesis ([#21191](https://github.com/Lightning-AI/pytorch-lightning/pull/21191))


---

## [2.5.5] - 2025-09-05
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -240,7 +240,7 @@ def _reorder_callbacks(callbacks: list[Callback]) -> list[Callback]:


def _validate_callbacks_list(callbacks: list[Callback]) -> None:
stateful_callbacks = [cb for cb in callbacks if is_overridden("state_dict", instance=cb)]
stateful_callbacks = [cb for cb in callbacks if is_overridden("state_dict", instance=cb, parent=Callback)]
seen_callbacks = set()
for callback in stateful_callbacks:
if callback.state_key in seen_callbacks:
Expand Down
26 changes: 26 additions & 0 deletions tests/tests_pytorch/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -1917,3 +1917,29 @@ def __init__(self, main_param: int = 1):
cli = LightningCLI(MainModule, run=False, parser_kwargs={"parser_mode": "jsonnet"})

assert cli.config["model"]["main_param"] == 2


def test_lightning_cli_callback_trainer_default(cleandir):
"""Check that callbacks passed as trainer_defaults are properly instantiated."""
with mock.patch("sys.argv", ["any.py"]):
cli = LightningCLI(
BoringModel,
BoringDataModule,
trainer_defaults={
"logger": {
"class_path": "lightning.pytorch.loggers.TensorBoardLogger",
"init_args": {
"save_dir": ".",
"name": "demo",
},
},
"callbacks": {
"class_path": "lightning.pytorch.callbacks.ModelCheckpoint",
"init_args": {
"monitor": "val_loss",
},
},
},
run=False,
)
assert any(isinstance(c, ModelCheckpoint) for c in cli.trainer.callbacks)
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
)
from lightning.pytorch.callbacks.batch_size_finder import BatchSizeFinder
from lightning.pytorch.demos.boring_classes import BoringModel
from lightning.pytorch.trainer.connectors.callback_connector import _CallbackConnector
from lightning.pytorch.trainer.connectors.callback_connector import _CallbackConnector, _validate_callbacks_list


@patch("lightning.pytorch.trainer.connectors.callback_connector._RICH_AVAILABLE", False)
Expand Down Expand Up @@ -104,12 +104,12 @@ def test_checkpoint_callbacks_are_last(tmp_path):
]


class StatefulCallback0(Callback):
class StatefulCallbackContent0(Callback):
def state_dict(self):
return {"content0": 0}


class StatefulCallback1(Callback):
class StatefulCallbackContent1(Callback):
def __init__(self, unique=None, other=None):
self._unique = unique
self._other = other
Expand All @@ -126,9 +126,9 @@ def test_all_callback_states_saved_before_checkpoint_callback(tmp_path):
"""Test that all callback states get saved even if the ModelCheckpoint is not given as last and when there are
multiple callbacks of the same type."""

callback0 = StatefulCallback0()
callback1 = StatefulCallback1(unique="one")
callback2 = StatefulCallback1(unique="two", other=2)
callback0 = StatefulCallbackContent0()
callback1 = StatefulCallbackContent1(unique="one")
callback2 = StatefulCallbackContent1(unique="two", other=2)
checkpoint_callback = ModelCheckpoint(dirpath=tmp_path, filename="all_states")
model = BoringModel()
trainer = Trainer(
Expand All @@ -147,9 +147,9 @@ def test_all_callback_states_saved_before_checkpoint_callback(tmp_path):
trainer.fit(model)

ckpt = torch.load(str(tmp_path / "all_states.ckpt"), weights_only=True)
state0 = ckpt["callbacks"]["StatefulCallback0"]
state1 = ckpt["callbacks"]["StatefulCallback1{'unique': 'one'}"]
state2 = ckpt["callbacks"]["StatefulCallback1{'unique': 'two'}"]
state0 = ckpt["callbacks"]["StatefulCallbackContent0"]
state1 = ckpt["callbacks"]["StatefulCallbackContent1{'unique': 'one'}"]
state2 = ckpt["callbacks"]["StatefulCallbackContent1{'unique': 'two'}"]
assert "content0" in state0
assert state0["content0"] == 0
assert "content1" in state1
Expand Down Expand Up @@ -323,3 +323,80 @@ def state_dict(self):

with pytest.raises(RuntimeError, match="Found more than one stateful callback of type `MockCallback`"):
Trainer(callbacks=[MockCallback(), MockCallback()])


# Test with single stateful callback
class StatefulCallback(Callback):
def state_dict(self):
return {"state": 1}


# Test with multiple stateful callbacks with unique state keys
class StatefulCallback1(Callback):
@property
def state_key(self):
return "unique_key_1"

def state_dict(self):
return {"state": 1}


class StatefulCallback2(Callback):
@property
def state_key(self):
return "unique_key_2"

def state_dict(self):
return {"state": 2}


@pytest.mark.parametrize(
("callbacks"),
[
[Callback(), Callback()],
[StatefulCallback()],
[StatefulCallback1(), StatefulCallback2()],
],
)
def test_validate_callbacks_list_function(callbacks: list):
"""Test the _validate_callbacks_list function directly with various scenarios."""
_validate_callbacks_list(callbacks)


# Test with multiple stateful callbacks with same state key
class ConflictingCallback(Callback):
@property
def state_key(self):
return "same_key"

def state_dict(self):
return {"state": 1}


# Test with different types of stateful callbacks that happen to have same state key
class AnotherConflictingCallback(Callback):
@property
def state_key(self):
return "same_key" # Same key as ConflictingCallback

def state_dict(self):
return {"state": 3}


@pytest.mark.parametrize(
("callbacks", "match_msg"),
[
(
[ConflictingCallback(), ConflictingCallback()],
"Found more than one stateful callback of type `ConflictingCallback`",
),
(
[ConflictingCallback(), Callback(), ConflictingCallback()],
"Found more than one stateful callback of type `ConflictingCallback`",
),
([ConflictingCallback(), AnotherConflictingCallback()], "Found more than one stateful callback"),
],
)
def test_raising_error_validate_callbacks_list_function(callbacks: list, match_msg: str):
with pytest.raises(RuntimeError, match=match_msg):
_validate_callbacks_list(callbacks)
Loading