From c282521a5a0ef6bdf361d1ee455033418f54b104 Mon Sep 17 00:00:00 2001 From: Matthew Hoffman Date: Sun, 2 Mar 2025 22:46:43 -0600 Subject: [PATCH 1/2] refactor: Set tensorboard's global_step as the default wandb x-axis if sync_tensorboard=True there isn't a need to create another global_step if you are syncing from tensorboard --- src/lightning/pytorch/loggers/wandb.py | 9 ++++++--- tests/tests_pytorch/loggers/conftest.py | 1 + tests/tests_pytorch/loggers/test_wandb.py | 18 ++++++++++++++++++ 3 files changed, 25 insertions(+), 3 deletions(-) diff --git a/src/lightning/pytorch/loggers/wandb.py b/src/lightning/pytorch/loggers/wandb.py index 2429748f73179..0ea32b97c46d1 100644 --- a/src/lightning/pytorch/loggers/wandb.py +++ b/src/lightning/pytorch/loggers/wandb.py @@ -410,8 +410,11 @@ def experiment(self) -> Union["Run", "RunDisabled"]: if isinstance(self._experiment, (Run, RunDisabled)) and getattr( self._experiment, "define_metric", None ): - self._experiment.define_metric("trainer/global_step") - self._experiment.define_metric("*", step_metric="trainer/global_step", step_sync=True) + if self._wandb_init.get("sync_tensorboard"): + self._experiment.define_metric("*", step_metric="global_step") + else: + self._experiment.define_metric("trainer/global_step") + self._experiment.define_metric("*", step_metric="trainer/global_step", step_sync=True) return self._experiment @@ -434,7 +437,7 @@ def log_metrics(self, metrics: Mapping[str, float], step: Optional[int] = None) assert rank_zero_only.rank == 0, "experiment tried to log from global_rank != 0" metrics = _add_prefix(metrics, self._prefix, self.LOGGER_JOIN_CHAR) - if step is not None: + if step is not None and not self._wandb_init.get("sync_tensorboard"): self.experiment.log(dict(metrics, **{"trainer/global_step": step})) else: self.experiment.log(metrics) diff --git a/tests/tests_pytorch/loggers/conftest.py b/tests/tests_pytorch/loggers/conftest.py index ab1149ca9651a..033275a9fec62 100644 --- a/tests/tests_pytorch/loggers/conftest.py +++ b/tests/tests_pytorch/loggers/conftest.py @@ -55,6 +55,7 @@ class RunType: # to make isinstance checks pass watch=Mock(), log_artifact=Mock(), use_artifact=Mock(), + define_metric=Mock(), id="run_id", ) diff --git a/tests/tests_pytorch/loggers/test_wandb.py b/tests/tests_pytorch/loggers/test_wandb.py index f3d82b0582be2..90ba3794f8f35 100644 --- a/tests/tests_pytorch/loggers/test_wandb.py +++ b/tests/tests_pytorch/loggers/test_wandb.py @@ -126,6 +126,24 @@ def test_wandb_logger_init(wandb_mock): assert logger.version == wandb_mock.init().id +def test_wandb_logger_sync_tensorboard(wandb_mock): + logger = WandbLogger(sync_tensorboard=True) + wandb_mock.run = None + logger.experiment + + # test that tensorboard's global_step is set as the default x-axis if sync_tensorboard=True + wandb_mock.init.return_value.define_metric.assert_called_once_with("*", step_metric="global_step") + + +def test_wandb_logger_sync_tensorboard_log_metrics(wandb_mock): + logger = WandbLogger(sync_tensorboard=True) + metrics = {"loss": 1e-3, "accuracy": 0.99} + logger.log_metrics(metrics) + + # test that trainer/global_step is not added to the logged metrics if sync_tensorboard=True + wandb_mock.run.log.assert_called_once_with(metrics) + + def test_wandb_logger_init_before_spawn(wandb_mock): logger = WandbLogger() assert logger._experiment is None From 6bb8e5665c145d4387d09f1d7b1220e54ee55a9c Mon Sep 17 00:00:00 2001 From: Matthew Hoffman Date: Sun, 2 Mar 2025 23:06:59 -0600 Subject: [PATCH 2/2] chore: Update changelog for #20611 --- src/lightning/pytorch/CHANGELOG.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/lightning/pytorch/CHANGELOG.md b/src/lightning/pytorch/CHANGELOG.md index 8bc8e45989f77..aa960043d37d9 100644 --- a/src/lightning/pytorch/CHANGELOG.md +++ b/src/lightning/pytorch/CHANGELOG.md @@ -10,6 +10,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Changed +- Change `wandb` default x-axis to `tensorboard`'s `global_step` when `sync_tensorboard=True` ([#20611](https://github.com/Lightning-AI/pytorch-lightning/pull/20611)) + ### Removed ### Fixed