Skip to content

Commit e87424a

Browse files
Fix Metric.state_dict (#5614)
* Fix Metric.state_dict * Update CHANGELOG.md * Update CHANGELOG.md * Detach tensors in a list if needed Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
1 parent 5d76b31 commit e87424a

File tree

3 files changed

+37
-4
lines changed

3 files changed

+37
-4
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
2929
- Fixed an error when logging a progress bar metric with a reserved name ([#5620](https://github.com/PyTorchLightning/pytorch-lightning/pull/5620))
3030

3131

32+
- Fixed `Metric` states not being included in the `state_dict` when the metric is a child module ([#5614](https://github.com/PyTorchLightning/pytorch-lightning/pull/5614))
3233

3334

3435
## [1.1.5] - 2021-01-19

pytorch_lightning/metrics/metric.py

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -284,11 +284,23 @@ def persistent(self, mode: bool = False):
284284
for key in self._persistent.keys():
285285
self._persistent[key] = mode
286286

287-
def state_dict(self, *args, **kwargs):
287+
def state_dict(self, destination=None, prefix='', keep_vars=False):
288+
destination = super().state_dict(
289+
destination=destination,
290+
prefix=prefix,
291+
keep_vars=keep_vars
292+
)
288293
# Register metric states to be part of the state_dict
289-
state_dict = super().state_dict()
290294
for key in self._defaults.keys():
291295
if self._persistent[key]:
292296
current_val = getattr(self, key)
293-
state_dict.update({key: current_val})
294-
return state_dict
297+
if not keep_vars:
298+
if torch.is_tensor(current_val):
299+
current_val = current_val.detach()
300+
elif isinstance(current_val, list):
301+
current_val = [
302+
cur_v.detach() if torch.is_tensor(cur_v) else cur_v
303+
for cur_v in current_val
304+
]
305+
destination[prefix + key] = current_val
306+
return destination

tests/metrics/test_metric.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import numpy as np
77
import pytest
88
import torch
9+
from torch import nn
910

1011
from pytorch_lightning.metrics.metric import Metric
1112

@@ -201,3 +202,22 @@ def test_state_dict(tmpdir):
201202
assert metric.state_dict() == OrderedDict(x=0)
202203
metric.persistent(False)
203204
assert metric.state_dict() == OrderedDict()
205+
206+
207+
def test_child_metric_state_dict():
208+
""" test that child metric states will be added to parent state dict """
209+
class TestModule(nn.Module):
210+
def __init__(self):
211+
super().__init__()
212+
self.metric = Dummy()
213+
self.metric.add_state('a', torch.tensor(0), persistent=True)
214+
self.metric.add_state('b', [], persistent=True)
215+
self.metric.register_buffer('c', torch.tensor(0))
216+
217+
module = TestModule()
218+
expected_state_dict = {
219+
'metric.a': torch.tensor(0),
220+
'metric.b': [],
221+
'metric.c': torch.tensor(0)
222+
}
223+
assert module.state_dict() == expected_state_dict

0 commit comments

Comments
 (0)