Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/lightning/pytorch/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Fixed

-
- Fixed logger_connector has edge case where step can be a float ([#20692](https://github.com/Lightning-AI/pytorch-lightning/issues/20692))


---
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -106,12 +106,13 @@ def log_metrics(self, metrics: _OUT_DICT, step: Optional[int] = None) -> None:
scalar_metrics = convert_tensors_to_scalars(metrics)

if step is None:
step = scalar_metrics.pop("step", None)

if step is None:
# added metrics for convenience
scalar_metrics.setdefault("epoch", self.trainer.current_epoch)
step = self.trainer.fit_loop.epoch_loop._batches_that_stepped
step_metric = scalar_metrics.pop("step", None)
if step_metric is not None:
step = int(step_metric)
else:
# added metrics for convenience
scalar_metrics.setdefault("epoch", self.trainer.current_epoch)
step = self.trainer.fit_loop.epoch_loop._batches_that_stepped

# log actual metrics
for logger in self.trainer.loggers:
Expand Down
61 changes: 61 additions & 0 deletions tests/tests_pytorch/trainer/connectors/test_logger_connector.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
from unittest.mock import MagicMock, patch

from lightning.pytorch import Trainer
from lightning.pytorch.loggers import Logger
from lightning.pytorch.trainer.connectors.logger_connector import _LoggerConnector


@patch("lightning.pytorch.trainer.connectors.logger_connector.logger_connector.convert_tensors_to_scalars")
def test_uses_provided_step(mock_convert):
"""Test that the LoggerConnector uses explicitly provided step to log metrics."""

trainer = MagicMock(spec=Trainer)
trainer.loggers = [logger := MagicMock(spec=Logger)]
connector = _LoggerConnector(trainer)
mock_convert.return_value.pop.return_value = step = 42

connector.log_metrics((metrics := {"some_metric": 123}), step=step)

assert connector._logged_metrics == metrics
mock_convert.assert_called_once_with(metrics)
logger.log_metrics.assert_called_once_with(metrics=mock_convert.return_value, step=step)
logger.save.assert_called_once_with()


@patch("lightning.pytorch.trainer.connectors.logger_connector.logger_connector.convert_tensors_to_scalars")
def test_uses_step_metric(mock_convert):
"""Test that the LoggerConnector uses explicitly provided step metric to log metrics."""

trainer = MagicMock(spec=Trainer)
trainer.loggers = [logger := MagicMock(spec=Logger)]
connector = _LoggerConnector(trainer)
mock_convert.return_value.pop.return_value = step = 42.0

metrics = {"some_metric": 123}
connector.log_metrics(logged_metrics := {**metrics, "step": step})

assert connector._logged_metrics == logged_metrics
mock_convert.assert_called_once_with(logged_metrics)
logger.log_metrics.assert_called_once_with(metrics=mock_convert.return_value, step=int(step))
logger.save.assert_called_once_with()


@patch("lightning.pytorch.trainer.connectors.logger_connector.logger_connector.convert_tensors_to_scalars")
def test_uses_batches_that_stepped(mock_convert):
"""Test that the LoggerConnector uses implicitly provided batches_that_stepped to log metrics."""

trainer = MagicMock(spec=Trainer)
trainer.fit_loop = MagicMock()
trainer.loggers = [logger := MagicMock(spec=Logger)]
connector = _LoggerConnector(trainer)
mock_convert.return_value.pop.return_value = None

connector.log_metrics(metrics := {"some_metric": 123})

assert connector._logged_metrics == metrics
mock_convert.assert_called_once_with(metrics)
logger.log_metrics.assert_called_once_with(
metrics=mock_convert.return_value, step=trainer.fit_loop.epoch_loop._batches_that_stepped
)
logger.save.assert_called_once_with()
mock_convert.return_value.setdefault.assert_called_once_with("epoch", trainer.current_epoch)
Loading