diff --git a/src/lightning/pytorch/CHANGELOG.md b/src/lightning/pytorch/CHANGELOG.md index ddc12a92e9f56..5744d5e5d80e5 100644 --- a/src/lightning/pytorch/CHANGELOG.md +++ b/src/lightning/pytorch/CHANGELOG.md @@ -37,8 +37,12 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed edgecase when `max_trials` is reached in `Tuner.scale_batch_size` ([#21187](https://github.com/Lightning-AI/pytorch-lightning/pull/21187)) +- Fixed case where `LightningCLI` could not be initialized with `trainer_default` containing callbacks ([#21192](https://github.com/Lightning-AI/pytorch-lightning/pull/21192)) + + - Fixed missing reset when `ModelPruning` is applied with lottery ticket hypothesis ([#21191](https://github.com/Lightning-AI/pytorch-lightning/pull/21191)) + --- ## [2.5.5] - 2025-09-05 diff --git a/src/lightning/pytorch/trainer/connectors/callback_connector.py b/src/lightning/pytorch/trainer/connectors/callback_connector.py index e95f196d9ae43..62dd49c26cc71 100644 --- a/src/lightning/pytorch/trainer/connectors/callback_connector.py +++ b/src/lightning/pytorch/trainer/connectors/callback_connector.py @@ -240,7 +240,7 @@ def _reorder_callbacks(callbacks: list[Callback]) -> list[Callback]: def _validate_callbacks_list(callbacks: list[Callback]) -> None: - stateful_callbacks = [cb for cb in callbacks if is_overridden("state_dict", instance=cb)] + stateful_callbacks = [cb for cb in callbacks if is_overridden("state_dict", instance=cb, parent=Callback)] seen_callbacks = set() for callback in stateful_callbacks: if callback.state_key in seen_callbacks: diff --git a/tests/tests_pytorch/test_cli.py b/tests/tests_pytorch/test_cli.py index 248852f4cf1f3..70aaac32d9661 100644 --- a/tests/tests_pytorch/test_cli.py +++ b/tests/tests_pytorch/test_cli.py @@ -1917,3 +1917,29 @@ def __init__(self, main_param: int = 1): cli = LightningCLI(MainModule, run=False, parser_kwargs={"parser_mode": "jsonnet"}) assert cli.config["model"]["main_param"] == 2 + + +def test_lightning_cli_callback_trainer_default(cleandir): + """Check that callbacks passed as trainer_defaults are properly instantiated.""" + with mock.patch("sys.argv", ["any.py"]): + cli = LightningCLI( + BoringModel, + BoringDataModule, + trainer_defaults={ + "logger": { + "class_path": "lightning.pytorch.loggers.TensorBoardLogger", + "init_args": { + "save_dir": ".", + "name": "demo", + }, + }, + "callbacks": { + "class_path": "lightning.pytorch.callbacks.ModelCheckpoint", + "init_args": { + "monitor": "val_loss", + }, + }, + }, + run=False, + ) + assert any(isinstance(c, ModelCheckpoint) for c in cli.trainer.callbacks) diff --git a/tests/tests_pytorch/trainer/connectors/test_callback_connector.py b/tests/tests_pytorch/trainer/connectors/test_callback_connector.py index 54fbd065fa919..bb8c365bb684c 100644 --- a/tests/tests_pytorch/trainer/connectors/test_callback_connector.py +++ b/tests/tests_pytorch/trainer/connectors/test_callback_connector.py @@ -32,7 +32,7 @@ ) from lightning.pytorch.callbacks.batch_size_finder import BatchSizeFinder from lightning.pytorch.demos.boring_classes import BoringModel -from lightning.pytorch.trainer.connectors.callback_connector import _CallbackConnector +from lightning.pytorch.trainer.connectors.callback_connector import _CallbackConnector, _validate_callbacks_list @patch("lightning.pytorch.trainer.connectors.callback_connector._RICH_AVAILABLE", False) @@ -104,12 +104,12 @@ def test_checkpoint_callbacks_are_last(tmp_path): ] -class StatefulCallback0(Callback): +class StatefulCallbackContent0(Callback): def state_dict(self): return {"content0": 0} -class StatefulCallback1(Callback): +class StatefulCallbackContent1(Callback): def __init__(self, unique=None, other=None): self._unique = unique self._other = other @@ -126,9 +126,9 @@ def test_all_callback_states_saved_before_checkpoint_callback(tmp_path): """Test that all callback states get saved even if the ModelCheckpoint is not given as last and when there are multiple callbacks of the same type.""" - callback0 = StatefulCallback0() - callback1 = StatefulCallback1(unique="one") - callback2 = StatefulCallback1(unique="two", other=2) + callback0 = StatefulCallbackContent0() + callback1 = StatefulCallbackContent1(unique="one") + callback2 = StatefulCallbackContent1(unique="two", other=2) checkpoint_callback = ModelCheckpoint(dirpath=tmp_path, filename="all_states") model = BoringModel() trainer = Trainer( @@ -147,9 +147,9 @@ def test_all_callback_states_saved_before_checkpoint_callback(tmp_path): trainer.fit(model) ckpt = torch.load(str(tmp_path / "all_states.ckpt"), weights_only=True) - state0 = ckpt["callbacks"]["StatefulCallback0"] - state1 = ckpt["callbacks"]["StatefulCallback1{'unique': 'one'}"] - state2 = ckpt["callbacks"]["StatefulCallback1{'unique': 'two'}"] + state0 = ckpt["callbacks"]["StatefulCallbackContent0"] + state1 = ckpt["callbacks"]["StatefulCallbackContent1{'unique': 'one'}"] + state2 = ckpt["callbacks"]["StatefulCallbackContent1{'unique': 'two'}"] assert "content0" in state0 assert state0["content0"] == 0 assert "content1" in state1 @@ -323,3 +323,80 @@ def state_dict(self): with pytest.raises(RuntimeError, match="Found more than one stateful callback of type `MockCallback`"): Trainer(callbacks=[MockCallback(), MockCallback()]) + + +# Test with single stateful callback +class StatefulCallback(Callback): + def state_dict(self): + return {"state": 1} + + +# Test with multiple stateful callbacks with unique state keys +class StatefulCallback1(Callback): + @property + def state_key(self): + return "unique_key_1" + + def state_dict(self): + return {"state": 1} + + +class StatefulCallback2(Callback): + @property + def state_key(self): + return "unique_key_2" + + def state_dict(self): + return {"state": 2} + + +@pytest.mark.parametrize( + ("callbacks"), + [ + [Callback(), Callback()], + [StatefulCallback()], + [StatefulCallback1(), StatefulCallback2()], + ], +) +def test_validate_callbacks_list_function(callbacks: list): + """Test the _validate_callbacks_list function directly with various scenarios.""" + _validate_callbacks_list(callbacks) + + +# Test with multiple stateful callbacks with same state key +class ConflictingCallback(Callback): + @property + def state_key(self): + return "same_key" + + def state_dict(self): + return {"state": 1} + + +# Test with different types of stateful callbacks that happen to have same state key +class AnotherConflictingCallback(Callback): + @property + def state_key(self): + return "same_key" # Same key as ConflictingCallback + + def state_dict(self): + return {"state": 3} + + +@pytest.mark.parametrize( + ("callbacks", "match_msg"), + [ + ( + [ConflictingCallback(), ConflictingCallback()], + "Found more than one stateful callback of type `ConflictingCallback`", + ), + ( + [ConflictingCallback(), Callback(), ConflictingCallback()], + "Found more than one stateful callback of type `ConflictingCallback`", + ), + ([ConflictingCallback(), AnotherConflictingCallback()], "Found more than one stateful callback"), + ], +) +def test_raising_error_validate_callbacks_list_function(callbacks: list, match_msg: str): + with pytest.raises(RuntimeError, match=match_msg): + _validate_callbacks_list(callbacks)