diff --git a/src/lightning/pytorch/CHANGELOG.md b/src/lightning/pytorch/CHANGELOG.md
index 65aef1ea3b306..1c4f56e8e1aaa 100644
--- a/src/lightning/pytorch/CHANGELOG.md
+++ b/src/lightning/pytorch/CHANGELOG.md
@@ -10,7 +10,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ### Added
 
--
+- Added support for general mappings being returned from `training_step` when using manual optimization ([#21011](https://github.com/Lightning-AI/pytorch-lightning/pull/21011))
+
 
 
 ### Changed
diff --git a/src/lightning/pytorch/loops/optimization/manual.py b/src/lightning/pytorch/loops/optimization/manual.py
index e1aabcbf42976..10bd5b8b1c666 100644
--- a/src/lightning/pytorch/loops/optimization/manual.py
+++ b/src/lightning/pytorch/loops/optimization/manual.py
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from collections import OrderedDict
+from collections.abc import Mapping
 from contextlib import suppress
 from dataclasses import dataclass, field
 from typing import Any
@@ -45,7 +46,7 @@ class ManualResult(OutputResult):
     @classmethod
     def from_training_step_output(cls, training_step_output: STEP_OUTPUT) -> "ManualResult":
         extra = {}
-        if isinstance(training_step_output, dict):
+        if isinstance(training_step_output, Mapping):
             extra = training_step_output.copy()
         elif isinstance(training_step_output, Tensor):
             extra = {"loss": training_step_output}
diff --git a/tests/tests_pytorch/trainer/optimization/test_manual_optimization.py b/tests/tests_pytorch/trainer/optimization/test_manual_optimization.py
index 3f89e1459298d..dd8042ecf2058 100644
--- a/tests/tests_pytorch/trainer/optimization/test_manual_optimization.py
+++ b/tests/tests_pytorch/trainer/optimization/test_manual_optimization.py
@@ -304,8 +304,36 @@ def on_train_epoch_end(self, *_, **__):
     trainer.fit(model)
 
 
+class CustomMapping(collections.abc.Mapping):
+    """A custom implementation of Mapping for testing purposes."""
+
+    def __init__(self, *args, **kwargs):
+        self._store = dict(*args, **kwargs)
+
+    def __getitem__(self, key):
+        return self._store[key]
+
+    def __iter__(self):
+        return iter(self._store)
+
+    def __len__(self):
+        return len(self._store)
+
+    def __repr__(self):
+        return f"{self.__class__.__name__}({self._store})"
+
+    def __copy__(self):
+        cls = self.__class__
+        new_obj = cls(self._store.copy())
+        return new_obj
+
+    def copy(self):
+        return self.__copy__()
+
+
 @RunIf(min_cuda_gpus=1)
-def test_multiple_optimizers_step(tmp_path):
+@pytest.mark.parametrize("dicttype", [dict, CustomMapping])
+def test_multiple_optimizers_step(tmp_path, dicttype):
     """Tests that `step` works with several optimizers."""
 
     class TestModel(ManualOptModel):
@@ -335,7 +363,7 @@ def training_step(self, batch, batch_idx):
             opt_b.step()
             opt_b.zero_grad()
 
-        return {"loss1": loss_1.detach(), "loss2": loss_2.detach()}
+        return dicttype(loss1=loss_1.detach(), loss2=loss_2.detach())
 
     # sister test: tests/plugins/test_amp_plugins.py::test_amp_gradient_unscale
     def on_after_backward(self) -> None:
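
Not part of the patch, but for context: a minimal usage sketch of what this change enables. Under manual optimization, `training_step` may now return any `collections.abc.Mapping`, not only a plain `dict`; the loop still calls `.copy()` on the returned object, so the mapping should provide one (as the `CustomMapping` test fixture above does). The `FrozenMetrics` and `ManualModule` names below are illustrative only.

```python
import torch
from collections.abc import Mapping
from lightning.pytorch import LightningModule


class FrozenMetrics(Mapping):
    """Illustrative read-only mapping; defines copy() because the manual loop copies the output."""

    def __init__(self, **kwargs):
        self._data = dict(kwargs)

    def __getitem__(self, key):
        return self._data[key]

    def __iter__(self):
        return iter(self._data)

    def __len__(self):
        return len(self._data)

    def copy(self):
        return FrozenMetrics(**self._data)


class ManualModule(LightningModule):
    def __init__(self):
        super().__init__()
        self.automatic_optimization = False  # manual optimization
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        opt = self.optimizers()
        loss = self.layer(batch).sum()
        self.manual_backward(loss)
        opt.step()
        opt.zero_grad()
        # returning a general Mapping (not just a dict) is accepted after this change
        return FrozenMetrics(loss=loss.detach())

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)
```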