Skip to content
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion src/lightning/pytorch/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Added

-
- Added support for general mappings being returned from `training_step` when using manual optimization ([#21011](https://github.com/Lightning-AI/pytorch-lightning/pull/21011))



### Changed
Expand Down
3 changes: 2 additions & 1 deletion src/lightning/pytorch/loops/optimization/manual.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import OrderedDict
from collections.abc import Mapping
from contextlib import suppress
from dataclasses import dataclass, field
from typing import Any
Expand Down Expand Up @@ -45,7 +46,7 @@ class ManualResult(OutputResult):
@classmethod
def from_training_step_output(cls, training_step_output: STEP_OUTPUT) -> "ManualResult":
extra = {}
if isinstance(training_step_output, dict):
if isinstance(training_step_output, Mapping):
extra = training_step_output.copy()
elif isinstance(training_step_output, Tensor):
extra = {"loss": training_step_output}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -304,8 +304,28 @@ def on_train_epoch_end(self, *_, **__):
trainer.fit(model)


class CustomMapping(collections.abc.Mapping):
"""A custom implementation of Mapping for testing purposes."""

def __init__(self, *args, **kwargs):
self._store = dict(*args, **kwargs)

def __getitem__(self, key):
return self._store[key]

def __iter__(self):
return iter(self._store)

def __len__(self):
return len(self._store)

def __repr__(self):
return f"{self.__class__.__name__}({self._store})"


@RunIf(min_cuda_gpus=1)
def test_multiple_optimizers_step(tmp_path):
@pytest.mark.parametrize("dicttype", [dict, CustomMapping])
def test_multiple_optimizers_step(tmp_path, dicttype):
"""Tests that `step` works with several optimizers."""

class TestModel(ManualOptModel):
Expand Down Expand Up @@ -335,7 +355,7 @@ def training_step(self, batch, batch_idx):
opt_b.step()
opt_b.zero_grad()

return {"loss1": loss_1.detach(), "loss2": loss_2.detach()}
return dicttype(loss1=loss_1.detach(), loss2=loss_2.detach())

# sister test: tests/plugins/test_amp_plugins.py::test_amp_gradient_unscale
def on_after_backward(self) -> None:
Expand Down
Loading