diff --git a/src/lightning/pytorch/CHANGELOG.md b/src/lightning/pytorch/CHANGELOG.md index 5616defeffc8a..0095367e9187a 100644 --- a/src/lightning/pytorch/CHANGELOG.md +++ b/src/lightning/pytorch/CHANGELOG.md @@ -24,7 +24,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Fixed -- +- Fixed an edge case in `logger_connector` where the `step` value could be a float ([#20692](https://github.com/Lightning-AI/pytorch-lightning/issues/20692)) --- diff --git a/src/lightning/pytorch/trainer/connectors/logger_connector/logger_connector.py b/src/lightning/pytorch/trainer/connectors/logger_connector/logger_connector.py index ffc99a9772469..09addf5a5a58c 100644 --- a/src/lightning/pytorch/trainer/connectors/logger_connector/logger_connector.py +++ b/src/lightning/pytorch/trainer/connectors/logger_connector/logger_connector.py @@ -106,12 +106,13 @@ def log_metrics(self, metrics: _OUT_DICT, step: Optional[int] = None) -> None: scalar_metrics = convert_tensors_to_scalars(metrics) if step is None: - step = scalar_metrics.pop("step", None) - - if step is None: - # added metrics for convenience - scalar_metrics.setdefault("epoch", self.trainer.current_epoch) - step = self.trainer.fit_loop.epoch_loop._batches_that_stepped + step_metric = scalar_metrics.pop("step", None) + if step_metric is not None: + step = int(step_metric) + else: + # added metrics for convenience + scalar_metrics.setdefault("epoch", self.trainer.current_epoch) + step = self.trainer.fit_loop.epoch_loop._batches_that_stepped # log actual metrics for logger in self.trainer.loggers: diff --git a/tests/tests_pytorch/trainer/connectors/test_logger_connector.py b/tests/tests_pytorch/trainer/connectors/test_logger_connector.py new file mode 100644 index 0000000000000..7a89efd133235 --- /dev/null +++ b/tests/tests_pytorch/trainer/connectors/test_logger_connector.py @@ -0,0 +1,61 @@ +from unittest.mock import MagicMock, patch + +from lightning.pytorch import Trainer +from lightning.pytorch.loggers import 
Logger +from lightning.pytorch.trainer.connectors.logger_connector import _LoggerConnector + + +@patch("lightning.pytorch.trainer.connectors.logger_connector.logger_connector.convert_tensors_to_scalars") +def test_uses_provided_step(mock_convert): + """Test that the LoggerConnector uses explicitly provided step to log metrics.""" + + trainer = MagicMock(spec=Trainer) + trainer.loggers = [logger := MagicMock(spec=Logger)] + connector = _LoggerConnector(trainer) + mock_convert.return_value.pop.return_value = step = 42 + + connector.log_metrics((metrics := {"some_metric": 123}), step=step) + + assert connector._logged_metrics == metrics + mock_convert.assert_called_once_with(metrics) + logger.log_metrics.assert_called_once_with(metrics=mock_convert.return_value, step=step) + logger.save.assert_called_once_with() + + +@patch("lightning.pytorch.trainer.connectors.logger_connector.logger_connector.convert_tensors_to_scalars") +def test_uses_step_metric(mock_convert): + """Test that the LoggerConnector uses explicitly provided step metric to log metrics.""" + + trainer = MagicMock(spec=Trainer) + trainer.loggers = [logger := MagicMock(spec=Logger)] + connector = _LoggerConnector(trainer) + mock_convert.return_value.pop.return_value = step = 42.0 + + metrics = {"some_metric": 123} + connector.log_metrics(logged_metrics := {**metrics, "step": step}) + + assert connector._logged_metrics == logged_metrics + mock_convert.assert_called_once_with(logged_metrics) + logger.log_metrics.assert_called_once_with(metrics=mock_convert.return_value, step=int(step)) + logger.save.assert_called_once_with() + + +@patch("lightning.pytorch.trainer.connectors.logger_connector.logger_connector.convert_tensors_to_scalars") +def test_uses_batches_that_stepped(mock_convert): + """Test that the LoggerConnector uses implicitly provided batches_that_stepped to log metrics.""" + + trainer = MagicMock(spec=Trainer) + trainer.fit_loop = MagicMock() + trainer.loggers = [logger := MagicMock(spec=Logger)] 
+ connector = _LoggerConnector(trainer) + mock_convert.return_value.pop.return_value = None + + connector.log_metrics(metrics := {"some_metric": 123}) + + assert connector._logged_metrics == metrics + mock_convert.assert_called_once_with(metrics) + logger.log_metrics.assert_called_once_with( + metrics=mock_convert.return_value, step=trainer.fit_loop.epoch_loop._batches_that_stepped + ) + logger.save.assert_called_once_with() + mock_convert.return_value.setdefault.assert_called_once_with("epoch", trainer.current_epoch)