Skip to content

Commit 8dbd103

Browse files
author
Matthew Hoffman
committed
fix: always call WandbLogger.experiment first in _call_setup_hook to ensure tensorboard logs sync to wandb
wandb/wandb#1782 (comment)
1 parent 1f5add3 commit 8dbd103

File tree

2 files changed

+44
-2
lines changed

2 files changed

+44
-2
lines changed

src/lightning/pytorch/trainer/call.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
 import lightning.pytorch as pl
 from lightning.fabric.utilities.device_dtype_mixin import _DeviceDtypeModuleMixin
 from lightning.pytorch.callbacks import Checkpoint, EarlyStopping
+from lightning.pytorch.loggers import WandbLogger
 from lightning.pytorch.strategies.launchers import _SubprocessScriptLauncher
 from lightning.pytorch.trainer.connectors.signal_connector import _get_sigkill_signal
 from lightning.pytorch.trainer.states import TrainerStatus

@@ -91,8 +92,12 @@ def _call_setup_hook(trainer: "pl.Trainer") -> None:
         if isinstance(module, _DeviceDtypeModuleMixin):
             module._device = trainer.strategy.root_device

+    # wandb.init must be called before any tensorboard writers are created in order to sync tensorboard logs to wandb:
+    # https://github.com/wandb/wandb/issues/1782#issuecomment-779161203
+    loggers = sorted(trainer.loggers, key=lambda logger: not isinstance(logger, WandbLogger))
+
     # Trigger lazy creation of experiment in loggers so loggers have their metadata available
-    for logger in trainer.loggers:
+    for logger in loggers:
         if hasattr(logger, "experiment"):
             _ = logger.experiment

tests/tests_pytorch/trainer/test_trainer.py

Lines changed: 38 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@
     RandomIterableDataset,
     RandomIterableDatasetWithLen,
 )
-from lightning.pytorch.loggers import TensorBoardLogger
+from lightning.pytorch.loggers import TensorBoardLogger, WandbLogger
 from lightning.pytorch.overrides.distributed import UnrepeatedDistributedSampler, _IndexBatchSamplerWrapper
 from lightning.pytorch.strategies import DDPStrategy, SingleDeviceStrategy
 from lightning.pytorch.strategies.launchers import _MultiProcessingLauncher, _SubprocessScriptLauncher
@@ -1271,6 +1271,43 @@ def training_step(self, *args, **kwargs):
     log_metrics_mock.assert_has_calls(expected_calls)


+def test_wandb_logger_experiment_called_first(tmp_path):
+    wandb_experiment_called = False
+
+    def tensorboard_experiment_side_effect() -> mock.MagicMock:
+        nonlocal wandb_experiment_called
+        assert wandb_experiment_called
+        return mock.MagicMock()
+
+    def wandb_experiment_side_effect() -> mock.MagicMock:
+        nonlocal wandb_experiment_called
+        wandb_experiment_called = True
+        return mock.MagicMock()
+
+    with (
+        mock.patch.object(
+            TensorBoardLogger,
+            "experiment",
+            new_callable=lambda: mock.PropertyMock(side_effect=tensorboard_experiment_side_effect),
+        ),
+        mock.patch.object(
+            WandbLogger,
+            "experiment",
+            new_callable=lambda: mock.PropertyMock(side_effect=wandb_experiment_side_effect),
+        ),
+    ):
+        model = BoringModel()
+        trainer = Trainer(
+            default_root_dir=tmp_path,
+            log_every_n_steps=1,
+            limit_train_batches=0,
+            limit_val_batches=0,
+            max_steps=1,
+            logger=[TensorBoardLogger(tmp_path), WandbLogger(save_dir=tmp_path)],
+        )
+        trainer.fit(model)
+
+
 class TestLightningDataModule(LightningDataModule):
     def __init__(self, dataloaders):
         super().__init__()

0 commit comments

Comments
 (0)