Skip to content

Commit 3d39824

Browse files
clumsyazzhipa
andauthored
fix: convert step to int when logging (#20830)
fix: convert step to int when logging (#20692) Co-authored-by: Alexander Zhipa <[email protected]>
1 parent 01ba7a1 commit 3d39824

File tree

3 files changed

+69
-7
lines changed

3 files changed

+69
-7
lines changed

src/lightning/pytorch/CHANGELOG.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
2424

2525
### Fixed
2626

27-
-
27+
- Fixed logger_connector has edge case where step can be a float ([#20692](https://github.com/Lightning-AI/pytorch-lightning/issues/20692))
2828

2929

3030
---

src/lightning/pytorch/trainer/connectors/logger_connector/logger_connector.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -106,12 +106,13 @@ def log_metrics(self, metrics: _OUT_DICT, step: Optional[int] = None) -> None:
106106
scalar_metrics = convert_tensors_to_scalars(metrics)
107107

108108
if step is None:
109-
step = scalar_metrics.pop("step", None)
110-
111-
if step is None:
112-
# added metrics for convenience
113-
scalar_metrics.setdefault("epoch", self.trainer.current_epoch)
114-
step = self.trainer.fit_loop.epoch_loop._batches_that_stepped
109+
step_metric = scalar_metrics.pop("step", None)
110+
if step_metric is not None:
111+
step = int(step_metric)
112+
else:
113+
# added metrics for convenience
114+
scalar_metrics.setdefault("epoch", self.trainer.current_epoch)
115+
step = self.trainer.fit_loop.epoch_loop._batches_that_stepped
115116

116117
# log actual metrics
117118
for logger in self.trainer.loggers:
Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
from unittest.mock import MagicMock, patch
2+
3+
from lightning.pytorch import Trainer
4+
from lightning.pytorch.loggers import Logger
5+
from lightning.pytorch.trainer.connectors.logger_connector import _LoggerConnector
6+
7+
8+
@patch("lightning.pytorch.trainer.connectors.logger_connector.logger_connector.convert_tensors_to_scalars")
9+
def test_uses_provided_step(mock_convert):
10+
"""Test that the LoggerConnector uses explicitly provided step to log metrics."""
11+
12+
trainer = MagicMock(spec=Trainer)
13+
trainer.loggers = [logger := MagicMock(spec=Logger)]
14+
connector = _LoggerConnector(trainer)
15+
mock_convert.return_value.pop.return_value = step = 42
16+
17+
connector.log_metrics((metrics := {"some_metric": 123}), step=step)
18+
19+
assert connector._logged_metrics == metrics
20+
mock_convert.assert_called_once_with(metrics)
21+
logger.log_metrics.assert_called_once_with(metrics=mock_convert.return_value, step=step)
22+
logger.save.assert_called_once_with()
23+
24+
25+
@patch("lightning.pytorch.trainer.connectors.logger_connector.logger_connector.convert_tensors_to_scalars")
26+
def test_uses_step_metric(mock_convert):
27+
"""Test that the LoggerConnector uses explicitly provided step metric to log metrics."""
28+
29+
trainer = MagicMock(spec=Trainer)
30+
trainer.loggers = [logger := MagicMock(spec=Logger)]
31+
connector = _LoggerConnector(trainer)
32+
mock_convert.return_value.pop.return_value = step = 42.0
33+
34+
metrics = {"some_metric": 123}
35+
connector.log_metrics(logged_metrics := {**metrics, "step": step})
36+
37+
assert connector._logged_metrics == logged_metrics
38+
mock_convert.assert_called_once_with(logged_metrics)
39+
logger.log_metrics.assert_called_once_with(metrics=mock_convert.return_value, step=int(step))
40+
logger.save.assert_called_once_with()
41+
42+
43+
@patch("lightning.pytorch.trainer.connectors.logger_connector.logger_connector.convert_tensors_to_scalars")
44+
def test_uses_batches_that_stepped(mock_convert):
45+
"""Test that the LoggerConnector uses implicitly provided batches_that_stepped to log metrics."""
46+
47+
trainer = MagicMock(spec=Trainer)
48+
trainer.fit_loop = MagicMock()
49+
trainer.loggers = [logger := MagicMock(spec=Logger)]
50+
connector = _LoggerConnector(trainer)
51+
mock_convert.return_value.pop.return_value = None
52+
53+
connector.log_metrics(metrics := {"some_metric": 123})
54+
55+
assert connector._logged_metrics == metrics
56+
mock_convert.assert_called_once_with(metrics)
57+
logger.log_metrics.assert_called_once_with(
58+
metrics=mock_convert.return_value, step=trainer.fit_loop.epoch_loop._batches_that_stepped
59+
)
60+
logger.save.assert_called_once_with()
61+
mock_convert.return_value.setdefault.assert_called_once_with("epoch", trainer.current_epoch)

0 commit comments

Comments
 (0)