Commit 8cffc0f

Avoid in-place ops during logging result updates (#11401)

carmocca and rohitgr7 authored and committed
Co-authored-by: rohitgr7 <[email protected]>

1 parent 18e95e0 commit 8cffc0f
File tree: 3 files changed, +27 −4 lines

CHANGELOG.md
Lines changed: 2 additions & 1 deletion

@@ -4,12 +4,13 @@ All notable changes to this project will be documented in this file.
 
 The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
-## [1.5.9] - 2022-01-11
+## [1.5.9] - 2022-01-18
 
 ### Fixed
 
 - Pin sphinx-autodoc-typehints with <v1.15 ([#11400](https://github.com/PyTorchLightning/pytorch-lightning/pull/11400))
 - Skip testing with PyTorch 1.7 and Python 3.9 on Ubuntu ([#11217](https://github.com/PyTorchLightning/pytorch-lightning/pull/11217))
+- Fixed type promotion when tensors of higher category than float are logged ([#11401](https://github.com/PyTorchLightning/pytorch-lightning/pull/11401))
 
 
 ## [1.5.8] - 2022-01-05

pytorch_lightning/trainer/connectors/logger_connector/result.py
Lines changed: 5 additions & 3 deletions

@@ -217,6 +217,7 @@ def __init__(self, metadata: _Metadata, is_tensor: bool) -> None:
         # do not set a dtype in case the default dtype was changed
         self.add_state("value", torch.tensor(default), dist_reduce_fx=torch.sum)
         if self.meta.is_mean_reduction:
+            self.cumulated_batch_size: torch.Tensor
             self.add_state("cumulated_batch_size", torch.tensor(0), dist_reduce_fx=torch.sum)
 
     def update(self, value: _IN_METRIC, batch_size: int) -> None:
@@ -240,12 +241,13 @@ def update(self, value: _IN_METRIC, batch_size: int) -> None:
 
         # perform accumulation with reduction
         if self.meta.is_mean_reduction:
-            self.value += value.mean() * batch_size
-            self.cumulated_batch_size += batch_size
+            # do not use `+=` as it doesn't do type promotion
+            self.value = self.value + value.mean() * batch_size
+            self.cumulated_batch_size = self.cumulated_batch_size + batch_size
         elif self.meta.is_max_reduction or self.meta.is_min_reduction:
             self.value = self.meta.reduce_fx(self.value, value.mean())
         elif self.meta.is_sum_reduction:
-            self.value += value.mean()
+            self.value = self.value + value.mean()
         else:
            self.value = value
         self._forward_cache = value._forward_cache
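
For context (not part of the commit), a minimal standalone sketch of the PyTorch behaviour the new comment refers to; the tensor names here are illustrative only. An in-place `+=` keeps the accumulator's original dtype, while the out-of-place addition follows PyTorch's type-promotion rules, so accumulating a double-precision value into a float accumulator only widens the dtype with the out-of-place form:

    import torch

    acc = torch.tensor(0.0)                          # float32 accumulator, analogous to the metric's initial state
    update = torch.tensor(1.0, dtype=torch.double)   # a logged value of higher precision

    in_place = acc.clone()
    in_place += update                               # in-place op: dtype stays float32, no promotion
    assert in_place.dtype == torch.float32

    out_of_place = acc + update                      # out-of-place op: result is promoted
    assert out_of_place.dtype == torch.float64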

tests/core/test_metric_result_integration.py
Lines changed: 20 additions & 0 deletions

@@ -593,6 +593,26 @@ def test_metric_result_respects_dtype(floating_dtype):
     torch.set_default_dtype(torch.float)
 
 
+@pytest.mark.parametrize("reduce_fx", ("mean", sum))
+def test_metric_result_dtype_promotion(reduce_fx):
+    metadata = _Metadata("foo", "bar", reduce_fx=reduce_fx)
+    metadata.sync = _Sync()
+    rm = ResultMetric(metadata, is_tensor=True)
+    assert rm.value.dtype == torch.float
+
+    # log a double
+    rm.update(torch.tensor(0, dtype=torch.double), 1)
+    # `rm.value.dtype` is promoted
+    assert rm.value.dtype == torch.double
+    # log a float
+    rm.update(torch.tensor(0, dtype=torch.float), 1)
+    # the previous dtype stays
+    assert rm.value.dtype == torch.double
+
+    total = rm.compute()
+    assert total.dtype == torch.double
+
+
 @pytest.mark.parametrize(["reduce_fx", "expected"], [(max, -2), (min, 2)])
 def test_result_metric_max_min(reduce_fx, expected):
     metadata = _Metadata("foo", "bar", reduce_fx=reduce_fx)
