Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/lightning/pytorch/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Changed

- Change `wandb` default x-axis to `tensorboard`'s `global_step` when `sync_tensorboard=True` ([#20611](https://github.com/Lightning-AI/pytorch-lightning/pull/20611))

### Removed

### Fixed
Expand Down
9 changes: 6 additions & 3 deletions src/lightning/pytorch/loggers/wandb.py
Original file line number Diff line number Diff line change
Expand Up @@ -410,8 +410,11 @@ def experiment(self) -> Union["Run", "RunDisabled"]:
if isinstance(self._experiment, (Run, RunDisabled)) and getattr(
self._experiment, "define_metric", None
):
self._experiment.define_metric("trainer/global_step")
self._experiment.define_metric("*", step_metric="trainer/global_step", step_sync=True)
if self._wandb_init.get("sync_tensorboard"):
self._experiment.define_metric("*", step_metric="global_step")
else:
self._experiment.define_metric("trainer/global_step")
self._experiment.define_metric("*", step_metric="trainer/global_step", step_sync=True)

return self._experiment

Expand All @@ -434,7 +437,7 @@ def log_metrics(self, metrics: Mapping[str, float], step: Optional[int] = None)
assert rank_zero_only.rank == 0, "experiment tried to log from global_rank != 0"

metrics = _add_prefix(metrics, self._prefix, self.LOGGER_JOIN_CHAR)
if step is not None:
if step is not None and not self._wandb_init.get("sync_tensorboard"):
self.experiment.log(dict(metrics, **{"trainer/global_step": step}))
else:
self.experiment.log(metrics)
Expand Down
1 change: 1 addition & 0 deletions tests/tests_pytorch/loggers/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ class RunType: # to make isinstance checks pass
watch=Mock(),
log_artifact=Mock(),
use_artifact=Mock(),
define_metric=Mock(),
id="run_id",
)

Expand Down
18 changes: 18 additions & 0 deletions tests/tests_pytorch/loggers/test_wandb.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,24 @@ def test_wandb_logger_init(wandb_mock):
assert logger.version == wandb_mock.init().id


def test_wandb_logger_sync_tensorboard(wandb_mock):
logger = WandbLogger(sync_tensorboard=True)
wandb_mock.run = None
logger.experiment

# test that tensorboard's global_step is set as the default x-axis if sync_tensorboard=True
wandb_mock.init.return_value.define_metric.assert_called_once_with("*", step_metric="global_step")


def test_wandb_logger_sync_tensorboard_log_metrics(wandb_mock):
logger = WandbLogger(sync_tensorboard=True)
metrics = {"loss": 1e-3, "accuracy": 0.99}
logger.log_metrics(metrics)

# test that trainer/global_step is not added to the logged metrics if sync_tensorboard=True
wandb_mock.run.log.assert_called_once_with(metrics)


def test_wandb_logger_init_before_spawn(wandb_mock):
logger = WandbLogger()
assert logger._experiment is None
Expand Down
Loading