Skip to content

Commit 2e1482e

Browse files
Matthew Hoffman authored and Borda committed
fix: always call WandbLogger.experiment first in _call_setup_hook to ensure tensorboard logs can sync to wandb (#20610)
(cherry picked from commit 93d707d)
1 parent 4c94b4e commit 2e1482e

File tree

3 files changed

+46
-2
lines changed

3 files changed

+46
-2
lines changed

src/lightning/pytorch/CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -16,6 +16,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
1616

1717
- Fix CSVLogger logging hyperparameters at every write, which increases latency ([#20594](https://github.com/Lightning-AI/pytorch-lightning/pull/20594))
1818

19+
- Always call `WandbLogger.experiment` first in `_call_setup_hook` to ensure `tensorboard` logs can sync to `wandb` ([#20610](https://github.com/Lightning-AI/pytorch-lightning/pull/20610))
20+
1921

2022
## [2.5.0] - 2024-12-19
2123

src/lightning/pytorch/trainer/call.py

Lines changed: 6 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -21,6 +21,7 @@
2121
import lightning.pytorch as pl
2222
from lightning.fabric.utilities.device_dtype_mixin import _DeviceDtypeModuleMixin
2323
from lightning.pytorch.callbacks import Checkpoint, EarlyStopping
24+
from lightning.pytorch.loggers import WandbLogger
2425
from lightning.pytorch.strategies.launchers import _SubprocessScriptLauncher
2526
from lightning.pytorch.trainer.connectors.signal_connector import _get_sigkill_signal
2627
from lightning.pytorch.trainer.states import TrainerStatus
@@ -91,8 +92,12 @@ def _call_setup_hook(trainer: "pl.Trainer") -> None:
9192
if isinstance(module, _DeviceDtypeModuleMixin):
9293
module._device = trainer.strategy.root_device
9394

95+
# wandb.init must be called before any tensorboard writers are created in order to sync tensorboard logs to wandb:
96+
# https://github.com/wandb/wandb/issues/1782#issuecomment-779161203
97+
loggers = sorted(trainer.loggers, key=lambda logger: not isinstance(logger, WandbLogger))
98+
9499
# Trigger lazy creation of experiment in loggers so loggers have their metadata available
95-
for logger in trainer.loggers:
100+
for logger in loggers:
96101
if hasattr(logger, "experiment"):
97102
_ = logger.experiment
98103

tests/tests_pytorch/loggers/test_wandb.py

Lines changed: 38 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -24,7 +24,7 @@
2424
from lightning.pytorch.callbacks import ModelCheckpoint
2525
from lightning.pytorch.cli import LightningCLI
2626
from lightning.pytorch.demos.boring_classes import BoringModel
27-
from lightning.pytorch.loggers import WandbLogger
27+
from lightning.pytorch.loggers import TensorBoardLogger, WandbLogger
2828
from lightning.pytorch.utilities.exceptions import MisconfigurationException
2929
from tests_pytorch.test_cli import _xfail_python_ge_3_11_9
3030

@@ -133,6 +133,43 @@ def test_wandb_logger_init_before_spawn(wandb_mock):
133133
assert logger._experiment is not None
134134

135135

136+
def test_wandb_logger_experiment_called_first(wandb_mock, tmp_path):
    """Verify that ``WandbLogger.experiment`` is accessed before ``TensorBoardLogger.experiment`` during ``fit``.

    The loggers are deliberately passed to the ``Trainer`` with the TensorBoard logger first, so the test
    only passes if the trainer reorders experiment creation to put wandb ahead of tensorboard.

    Both ``experiment`` properties are replaced by ``PropertyMock`` objects whose side effects record the
    order in which they are accessed.

    """
    # Names of the patched ``experiment`` properties, in the order they were accessed.
    access_order = []

    def tensorboard_experiment_side_effect() -> mock.MagicMock:
        # Fail right at the offending access if tensorboard is touched before wandb.
        assert "wandb" in access_order
        access_order.append("tensorboard")
        return mock.MagicMock()

    def wandb_experiment_side_effect() -> mock.MagicMock:
        access_order.append("wandb")
        return mock.MagicMock()

    with (
        mock.patch.object(
            TensorBoardLogger,
            "experiment",
            new_callable=lambda: mock.PropertyMock(side_effect=tensorboard_experiment_side_effect),
        ),
        mock.patch.object(
            WandbLogger,
            "experiment",
            new_callable=lambda: mock.PropertyMock(side_effect=wandb_experiment_side_effect),
        ),
    ):
        model = BoringModel()
        trainer = Trainer(
            default_root_dir=tmp_path,
            log_every_n_steps=1,
            limit_train_batches=0,
            limit_val_batches=0,
            max_steps=1,
            logger=[TensorBoardLogger(tmp_path), WandbLogger(save_dir=tmp_path)],
        )
        trainer.fit(model)

    # Guard against a vacuous pass: the ordering check above only runs if the
    # properties are actually accessed while fitting.
    assert "wandb" in access_order
    assert "tensorboard" in access_order
171+
172+
136173
def test_wandb_pickle(wandb_mock, tmp_path):
137174
"""Verify that pickling trainer with wandb logger works.
138175

0 commit comments

Comments
 (0)