Commit 791753b

SkafteNicki, pre-commit-ci[bot], bhimrazy, and deependujha authored
Allow training_step in manual optimization to return general mappings (#21011)
* change dict to mapping
* add a bit of testing
* changelog
* use custom mapping
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* Update tests/tests_pytorch/trainer/optimization/test_manual_optimization.py
  Co-authored-by: Bhimraj Yadav <[email protected]>
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* update

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Bhimraj Yadav <[email protected]>
Co-authored-by: Deependu Jha <[email protected]>
1 parent 8d847fd commit 791753b
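What the change enables, in brief: `training_step` under manual optimization may now return any `collections.abc.Mapping`, not just a plain `dict`. A minimal sketch of that usage (the model, layer sizes, and optimizer below are illustrative assumptions, not code from this commit):

```python
# Minimal sketch, assuming a toy model; not code from this commit.
from types import MappingProxyType

import torch
from lightning.pytorch import LightningModule


class SketchModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.automatic_optimization = False  # opt into manual optimization
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        opt = self.optimizers()
        loss = self.layer(batch).sum()
        opt.zero_grad()
        self.manual_backward(loss)
        opt.step()
        # A read-only mappingproxy is a Mapping but not a dict subclass;
        # with this commit it is accepted as a training_step return value.
        return MappingProxyType({"loss": loss.detach()})

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)
```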

File tree: 3 files changed (+34, -4 lines)


src/lightning/pytorch/CHANGELOG.md

Lines changed: 2 additions & 1 deletion
@@ -10,7 +10,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ### Added
 
--
+- Added support for general mappings being returned from `training_step` when using manual optimization ([#21011](https://github.com/Lightning-AI/pytorch-lightning/pull/21011))
+
 
 
 ### Changed

src/lightning/pytorch/loops/optimization/manual.py

Lines changed: 2 additions & 1 deletion
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from collections import OrderedDict
+from collections.abc import Mapping
 from contextlib import suppress
 from dataclasses import dataclass, field
 from typing import Any
@@ -45,7 +46,7 @@ class ManualResult(OutputResult):
     @classmethod
     def from_training_step_output(cls, training_step_output: STEP_OUTPUT) -> "ManualResult":
         extra = {}
-        if isinstance(training_step_output, dict):
+        if isinstance(training_step_output, Mapping):
             extra = training_step_output.copy()
         elif isinstance(training_step_output, Tensor):
             extra = {"loss": training_step_output}
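The substantive change is the widened `isinstance` check: `dict` becomes `collections.abc.Mapping`. Note that the next line still calls `training_step_output.copy()`, and the `Mapping` ABC does not define `copy()`, so a returned mapping must supply one itself; that is why the test class below implements both `copy` and `__copy__`. A quick plain-Python illustration of the widened check (no Lightning required):

```python
# Illustration of the widened isinstance check; not part of the commit.
from collections.abc import Mapping
from types import MappingProxyType

out = MappingProxyType({"loss1": 1.0, "loss2": 2.0})

print(isinstance(out, dict))     # False -- the old check rejected this output
print(isinstance(out, Mapping))  # True  -- the new check accepts it
print(out.copy())                # {'loss1': 1.0, 'loss2': 2.0}; .copy() is still required
```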

tests/tests_pytorch/trainer/optimization/test_manual_optimization.py

Lines changed: 30 additions & 2 deletions
@@ -304,8 +304,36 @@ def on_train_epoch_end(self, *_, **__):
     trainer.fit(model)
 
 
+class CustomMapping(collections.abc.Mapping):
+    """A custom implementation of Mapping for testing purposes."""
+
+    def __init__(self, *args, **kwargs):
+        self._store = dict(*args, **kwargs)
+
+    def __getitem__(self, key):
+        return self._store[key]
+
+    def __iter__(self):
+        return iter(self._store)
+
+    def __len__(self):
+        return len(self._store)
+
+    def __repr__(self):
+        return f"{self.__class__.__name__}({self._store})"
+
+    def __copy__(self):
+        cls = self.__class__
+        new_obj = cls(self._store.copy())
+        return new_obj
+
+    def copy(self):
+        return self.__copy__()
+
+
 @RunIf(min_cuda_gpus=1)
-def test_multiple_optimizers_step(tmp_path):
+@pytest.mark.parametrize("dicttype", [dict, CustomMapping])
+def test_multiple_optimizers_step(tmp_path, dicttype):
     """Tests that `step` works with several optimizers."""
 
     class TestModel(ManualOptModel):
@@ -335,7 +363,7 @@ def training_step(self, batch, batch_idx):
             opt_b.step()
             opt_b.zero_grad()
 
-            return {"loss1": loss_1.detach(), "loss2": loss_2.detach()}
+            return dicttype(loss1=loss_1.detach(), loss2=loss_2.detach())
 
         # sister test: tests/plugins/test_amp_plugins.py::test_amp_gradient_unscale
         def on_after_backward(self) -> None:
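The test now runs twice, once per `dicttype`, and both parametrizations construct equivalent outputs because `CustomMapping.__init__` forwards its arguments to `dict`. A small usage sketch, assuming the `CustomMapping` class from the diff above is in scope:

```python
# Both parametrizations behave the same from the loop's point of view.
for dicttype in (dict, CustomMapping):
    out = dicttype(loss1=1.0, loss2=2.0)
    assert out["loss1"] == 1.0 and len(out) == 2
    assert dict(out) == {"loss1": 1.0, "loss2": 2.0}
    print(dicttype.__name__, out.copy())  # both expose the .copy() the loop needs
```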
