Skip to content

Commit e4c4e9a

Browse files
SkafteNicki, pre-commit-ci[bot], Borda
authored
Fix lightning cli crashing when trainer defaults contain callback (#21192)
* add parent
* add unittesting of function
* add testing in cli
* add changelog
* split tests
* linter

--------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Jirka B <[email protected]>
1 parent d73c50d commit e4c4e9a

File tree

4 files changed

+117
-10
lines changed

4 files changed

+117
-10
lines changed

src/lightning/pytorch/CHANGELOG.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,8 +37,12 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
3737
- Fixed edge case when `max_trials` is reached in `Tuner.scale_batch_size` ([#21187](https://github.com/Lightning-AI/pytorch-lightning/pull/21187))
3838

3939

40+
- Fixed case where `LightningCLI` could not be initialized with `trainer_defaults` containing callbacks ([#21192](https://github.com/Lightning-AI/pytorch-lightning/pull/21192))
41+
42+
4043
- Fixed missing reset when `ModelPruning` is applied with lottery ticket hypothesis ([#21191](https://github.com/Lightning-AI/pytorch-lightning/pull/21191))
4144

45+
4246
---
4347

4448
## [2.5.5] - 2025-09-05

src/lightning/pytorch/trainer/connectors/callback_connector.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -240,7 +240,7 @@ def _reorder_callbacks(callbacks: list[Callback]) -> list[Callback]:
240240

241241

242242
def _validate_callbacks_list(callbacks: list[Callback]) -> None:
243-
stateful_callbacks = [cb for cb in callbacks if is_overridden("state_dict", instance=cb)]
243+
stateful_callbacks = [cb for cb in callbacks if is_overridden("state_dict", instance=cb, parent=Callback)]
244244
seen_callbacks = set()
245245
for callback in stateful_callbacks:
246246
if callback.state_key in seen_callbacks:

tests/tests_pytorch/test_cli.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1917,3 +1917,29 @@ def __init__(self, main_param: int = 1):
19171917
cli = LightningCLI(MainModule, run=False, parser_kwargs={"parser_mode": "jsonnet"})
19181918

19191919
assert cli.config["model"]["main_param"] == 2
1920+
1921+
1922+
def test_lightning_cli_callback_trainer_default(cleandir):
1923+
"""Check that callbacks passed as trainer_defaults are properly instantiated."""
1924+
with mock.patch("sys.argv", ["any.py"]):
1925+
cli = LightningCLI(
1926+
BoringModel,
1927+
BoringDataModule,
1928+
trainer_defaults={
1929+
"logger": {
1930+
"class_path": "lightning.pytorch.loggers.TensorBoardLogger",
1931+
"init_args": {
1932+
"save_dir": ".",
1933+
"name": "demo",
1934+
},
1935+
},
1936+
"callbacks": {
1937+
"class_path": "lightning.pytorch.callbacks.ModelCheckpoint",
1938+
"init_args": {
1939+
"monitor": "val_loss",
1940+
},
1941+
},
1942+
},
1943+
run=False,
1944+
)
1945+
assert any(isinstance(c, ModelCheckpoint) for c in cli.trainer.callbacks)

tests/tests_pytorch/trainer/connectors/test_callback_connector.py

Lines changed: 86 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@
3232
)
3333
from lightning.pytorch.callbacks.batch_size_finder import BatchSizeFinder
3434
from lightning.pytorch.demos.boring_classes import BoringModel
35-
from lightning.pytorch.trainer.connectors.callback_connector import _CallbackConnector
35+
from lightning.pytorch.trainer.connectors.callback_connector import _CallbackConnector, _validate_callbacks_list
3636

3737

3838
@patch("lightning.pytorch.trainer.connectors.callback_connector._RICH_AVAILABLE", False)
@@ -104,12 +104,12 @@ def test_checkpoint_callbacks_are_last(tmp_path):
104104
]
105105

106106

107-
class StatefulCallback0(Callback):
107+
class StatefulCallbackContent0(Callback):
108108
def state_dict(self):
109109
return {"content0": 0}
110110

111111

112-
class StatefulCallback1(Callback):
112+
class StatefulCallbackContent1(Callback):
113113
def __init__(self, unique=None, other=None):
114114
self._unique = unique
115115
self._other = other
@@ -126,9 +126,9 @@ def test_all_callback_states_saved_before_checkpoint_callback(tmp_path):
126126
"""Test that all callback states get saved even if the ModelCheckpoint is not given as last and when there are
127127
multiple callbacks of the same type."""
128128

129-
callback0 = StatefulCallback0()
130-
callback1 = StatefulCallback1(unique="one")
131-
callback2 = StatefulCallback1(unique="two", other=2)
129+
callback0 = StatefulCallbackContent0()
130+
callback1 = StatefulCallbackContent1(unique="one")
131+
callback2 = StatefulCallbackContent1(unique="two", other=2)
132132
checkpoint_callback = ModelCheckpoint(dirpath=tmp_path, filename="all_states")
133133
model = BoringModel()
134134
trainer = Trainer(
@@ -147,9 +147,9 @@ def test_all_callback_states_saved_before_checkpoint_callback(tmp_path):
147147
trainer.fit(model)
148148

149149
ckpt = torch.load(str(tmp_path / "all_states.ckpt"), weights_only=True)
150-
state0 = ckpt["callbacks"]["StatefulCallback0"]
151-
state1 = ckpt["callbacks"]["StatefulCallback1{'unique': 'one'}"]
152-
state2 = ckpt["callbacks"]["StatefulCallback1{'unique': 'two'}"]
150+
state0 = ckpt["callbacks"]["StatefulCallbackContent0"]
151+
state1 = ckpt["callbacks"]["StatefulCallbackContent1{'unique': 'one'}"]
152+
state2 = ckpt["callbacks"]["StatefulCallbackContent1{'unique': 'two'}"]
153153
assert "content0" in state0
154154
assert state0["content0"] == 0
155155
assert "content1" in state1
@@ -323,3 +323,80 @@ def state_dict(self):
323323

324324
with pytest.raises(RuntimeError, match="Found more than one stateful callback of type `MockCallback`"):
325325
Trainer(callbacks=[MockCallback(), MockCallback()])
326+
327+
328+
# Test with single stateful callback
329+
class StatefulCallback(Callback):
330+
def state_dict(self):
331+
return {"state": 1}
332+
333+
334+
# Test with multiple stateful callbacks with unique state keys
335+
class StatefulCallback1(Callback):
336+
@property
337+
def state_key(self):
338+
return "unique_key_1"
339+
340+
def state_dict(self):
341+
return {"state": 1}
342+
343+
344+
class StatefulCallback2(Callback):
345+
@property
346+
def state_key(self):
347+
return "unique_key_2"
348+
349+
def state_dict(self):
350+
return {"state": 2}
351+
352+
353+
@pytest.mark.parametrize(
354+
("callbacks"),
355+
[
356+
[Callback(), Callback()],
357+
[StatefulCallback()],
358+
[StatefulCallback1(), StatefulCallback2()],
359+
],
360+
)
361+
def test_validate_callbacks_list_function(callbacks: list):
362+
"""Test the _validate_callbacks_list function directly with various scenarios."""
363+
_validate_callbacks_list(callbacks)
364+
365+
366+
# Test with multiple stateful callbacks with same state key
367+
class ConflictingCallback(Callback):
368+
@property
369+
def state_key(self):
370+
return "same_key"
371+
372+
def state_dict(self):
373+
return {"state": 1}
374+
375+
376+
# Test with different types of stateful callbacks that happen to have same state key
377+
class AnotherConflictingCallback(Callback):
378+
@property
379+
def state_key(self):
380+
return "same_key" # Same key as ConflictingCallback
381+
382+
def state_dict(self):
383+
return {"state": 3}
384+
385+
386+
@pytest.mark.parametrize(
387+
("callbacks", "match_msg"),
388+
[
389+
(
390+
[ConflictingCallback(), ConflictingCallback()],
391+
"Found more than one stateful callback of type `ConflictingCallback`",
392+
),
393+
(
394+
[ConflictingCallback(), Callback(), ConflictingCallback()],
395+
"Found more than one stateful callback of type `ConflictingCallback`",
396+
),
397+
([ConflictingCallback(), AnotherConflictingCallback()], "Found more than one stateful callback"),
398+
],
399+
)
400+
def test_raising_error_validate_callbacks_list_function(callbacks: list, match_msg: str):
401+
with pytest.raises(RuntimeError, match=match_msg):
402+
_validate_callbacks_list(callbacks)

0 commit comments

Comments
 (0)