3 changes: 2 additions & 1 deletion src/lightning/pytorch/CHANGELOG.md
@@ -10,7 +10,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 ### Added
 
--
+- Added support for general mappings being returned from `training_step` when using manual optimization ([#21011](https://github.com/Lightning-AI/pytorch-lightning/pull/21011))
+
 
 
 
 ### Changed
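For context, this is roughly the use case the new changelog entry describes: under manual optimization, `training_step` can now return any `collections.abc.Mapping`, not only a plain `dict`. The sketch below is not part of the PR; `MetricsMapping`, `ManualMappingModel`, and the toy dataset are made-up names for illustration. The mapping provides a `copy()` method because the loop copies the step output, and it avoids the `loss` key, mirroring the test added further down.

```python
from collections.abc import Mapping

import torch
from torch.utils.data import DataLoader, TensorDataset

from lightning.pytorch import LightningModule, Trainer


class MetricsMapping(Mapping):
    """Hypothetical read-only mapping; `copy()` is needed because the loop copies the step output."""

    def __init__(self, **kwargs):
        self._store = dict(**kwargs)

    def __getitem__(self, key):
        return self._store[key]

    def __iter__(self):
        return iter(self._store)

    def __len__(self):
        return len(self._store)

    def copy(self):
        return MetricsMapping(**self._store)


class ManualMappingModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.automatic_optimization = False  # manual optimization
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        opt = self.optimizers()
        loss = self.layer(batch[0]).sum()
        opt.zero_grad()
        self.manual_backward(loss)
        opt.step()
        # With #21011, any Mapping is accepted here, not just a plain dict.
        return MetricsMapping(step_loss=loss.detach())

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)


if __name__ == "__main__":
    data = DataLoader(TensorDataset(torch.randn(64, 32)), batch_size=16)
    Trainer(max_epochs=1, logger=False, enable_checkpointing=False).fit(ManualMappingModel(), data)
```

The change that enables this is the single `isinstance` check in `manual.py`, shown next.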
3 changes: 2 additions & 1 deletion src/lightning/pytorch/loops/optimization/manual.py
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from collections import OrderedDict
+from collections.abc import Mapping
 from contextlib import suppress
 from dataclasses import dataclass, field
 from typing import Any
@@ -45,7 +46,7 @@ class ManualResult(OutputResult):
     @classmethod
     def from_training_step_output(cls, training_step_output: STEP_OUTPUT) -> "ManualResult":
         extra = {}
-        if isinstance(training_step_output, dict):
+        if isinstance(training_step_output, Mapping):
             extra = training_step_output.copy()
         elif isinstance(training_step_output, Tensor):
             extra = {"loss": training_step_output}
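The hunk above is the whole behavioral change: `from_training_step_output` now treats any `collections.abc.Mapping` the way it previously treated only a plain `dict`. Note that the unchanged line right after the check calls `training_step_output.copy()`, so the returned mapping still needs a `copy()` method. A minimal sketch of the effect, assuming the `ManualResult` import path matches the file above; the `FrozenMetrics` class is hypothetical, not from the PR:

```python
from collections.abc import Mapping

import torch

from lightning.pytorch.loops.optimization.manual import ManualResult


class FrozenMetrics(Mapping):
    """Hypothetical read-only mapping with the `copy()` method the loop relies on."""

    def __init__(self, **kwargs):
        self._store = dict(**kwargs)

    def __getitem__(self, key):
        return self._store[key]

    def __iter__(self):
        return iter(self._store)

    def __len__(self):
        return len(self._store)

    def copy(self):
        return FrozenMetrics(**self._store)


out = FrozenMetrics(loss1=torch.tensor(0.5), loss2=torch.tensor(1.5))
result = ManualResult.from_training_step_output(out)
print(dict(result.extra))  # {'loss1': tensor(0.5000), 'loss2': tensor(1.5000)}
```

The test diff below exercises exactly this path: the existing multi-optimizer manual-optimization test is parametrized over both `dict` and a custom `Mapping` implementation.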
@@ -304,8 +304,36 @@ def on_train_epoch_end(self, *_, **__):
     trainer.fit(model)
 
 
+class CustomMapping(collections.abc.Mapping):
+    """A custom implementation of Mapping for testing purposes."""
+
+    def __init__(self, *args, **kwargs):
+        self._store = dict(*args, **kwargs)
+
+    def __getitem__(self, key):
+        return self._store[key]
+
+    def __iter__(self):
+        return iter(self._store)
+
+    def __len__(self):
+        return len(self._store)
+
+    def __repr__(self):
+        return f"{self.__class__.__name__}({self._store})"
+
+    def __copy__(self):
+        cls = self.__class__
+        new_obj = cls(self._store.copy())
+        return new_obj
+
+    def copy(self):
+        return self.__copy__()
+
+
 @RunIf(min_cuda_gpus=1)
-def test_multiple_optimizers_step(tmp_path):
+@pytest.mark.parametrize("dicttype", [dict, CustomMapping])
+def test_multiple_optimizers_step(tmp_path, dicttype):
     """Tests that `step` works with several optimizers."""
 
     class TestModel(ManualOptModel):
@@ -335,7 +363,7 @@ def training_step(self, batch, batch_idx):
             opt_b.step()
             opt_b.zero_grad()
 
-            return {"loss1": loss_1.detach(), "loss2": loss_2.detach()}
+            return dicttype(loss1=loss_1.detach(), loss2=loss_2.detach())
 
         # sister test: tests/plugins/test_amp_plugins.py::test_amp_gradient_unscale
         def on_after_backward(self) -> None: