diff --git a/src/lightning/pytorch/CHANGELOG.md b/src/lightning/pytorch/CHANGELOG.md
index 8bc8e45989f77..d6b0ef04c759e 100644
--- a/src/lightning/pytorch/CHANGELOG.md
+++ b/src/lightning/pytorch/CHANGELOG.md
@@ -16,6 +16,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Fix CSVLogger logging hyperparameter at every write which increase latency ([#20594](https://github.com/Lightning-AI/pytorch-lightning/pull/20594))
 
+- Always call `WandbLogger.experiment` first in `_call_setup_hook` to ensure `tensorboard` logs can sync to `wandb` ([#20610](https://github.com/Lightning-AI/pytorch-lightning/pull/20610))
+
 
 ## [2.5.0] - 2024-12-19
 
diff --git a/src/lightning/pytorch/trainer/call.py b/src/lightning/pytorch/trainer/call.py
index 012d1a2152aa3..b5354eb2b08dd 100644
--- a/src/lightning/pytorch/trainer/call.py
+++ b/src/lightning/pytorch/trainer/call.py
@@ -21,6 +21,7 @@
 import lightning.pytorch as pl
 from lightning.fabric.utilities.device_dtype_mixin import _DeviceDtypeModuleMixin
 from lightning.pytorch.callbacks import Checkpoint, EarlyStopping
+from lightning.pytorch.loggers import WandbLogger
 from lightning.pytorch.strategies.launchers import _SubprocessScriptLauncher
 from lightning.pytorch.trainer.connectors.signal_connector import _get_sigkill_signal
 from lightning.pytorch.trainer.states import TrainerStatus
@@ -91,8 +92,12 @@ def _call_setup_hook(trainer: "pl.Trainer") -> None:
         if isinstance(module, _DeviceDtypeModuleMixin):
             module._device = trainer.strategy.root_device
 
+    # wandb.init must be called before any tensorboard writers are created in order to sync tensorboard logs to wandb:
+    # https://github.com/wandb/wandb/issues/1782#issuecomment-779161203
+    loggers = sorted(trainer.loggers, key=lambda logger: not isinstance(logger, WandbLogger))
+
     # Trigger lazy creation of experiment in loggers so loggers have their metadata available
-    for logger in trainer.loggers:
+    for logger in loggers:
         if hasattr(logger, "experiment"):
             _ = logger.experiment
 
diff --git a/tests/tests_pytorch/loggers/test_wandb.py b/tests/tests_pytorch/loggers/test_wandb.py
index f3d82b0582be2..52ad03bd994b4 100644
--- a/tests/tests_pytorch/loggers/test_wandb.py
+++ b/tests/tests_pytorch/loggers/test_wandb.py
@@ -24,7 +24,7 @@
 from lightning.pytorch.callbacks import ModelCheckpoint
 from lightning.pytorch.cli import LightningCLI
 from lightning.pytorch.demos.boring_classes import BoringModel
-from lightning.pytorch.loggers import WandbLogger
+from lightning.pytorch.loggers import TensorBoardLogger, WandbLogger
 from lightning.pytorch.utilities.exceptions import MisconfigurationException
 
 from tests_pytorch.test_cli import _xfail_python_ge_3_11_9
@@ -133,6 +133,43 @@ def test_wandb_logger_init_before_spawn(wandb_mock):
     assert logger._experiment is not None
 
 
+def test_wandb_logger_experiment_called_first(wandb_mock, tmp_path):
+    wandb_experiment_called = False
+
+    def tensorboard_experiment_side_effect() -> mock.MagicMock:
+        nonlocal wandb_experiment_called
+        assert wandb_experiment_called
+        return mock.MagicMock()
+
+    def wandb_experiment_side_effect() -> mock.MagicMock:
+        nonlocal wandb_experiment_called
+        wandb_experiment_called = True
+        return mock.MagicMock()
+
+    with (
+        mock.patch.object(
+            TensorBoardLogger,
+            "experiment",
+            new_callable=lambda: mock.PropertyMock(side_effect=tensorboard_experiment_side_effect),
+        ),
+        mock.patch.object(
+            WandbLogger,
+            "experiment",
+            new_callable=lambda: mock.PropertyMock(side_effect=wandb_experiment_side_effect),
+        ),
+    ):
+        model = BoringModel()
+        trainer = Trainer(
+            default_root_dir=tmp_path,
+            log_every_n_steps=1,
+            limit_train_batches=0,
+            limit_val_batches=0,
+            max_steps=1,
+            logger=[TensorBoardLogger(tmp_path), WandbLogger(save_dir=tmp_path)],
+        )
+        trainer.fit(model)
+
+
 def test_wandb_pickle(wandb_mock, tmp_path):
     """Verify that pickling trainer with wandb logger works.
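For reference, here is a minimal sketch (not part of the diff) of the boolean-key sort that `_call_setup_hook` now uses. `False` sorts before `True`, so `WandbLogger` instances move to the front, and because `sorted` is stable, the remaining loggers keep their original relative order. The `CSVLogger` and the `"logs"` directory are illustrative assumptions; the sketch assumes `lightning` and `wandb` are installed.

```python
# Sketch of the boolean-key sort from _call_setup_hook: not isinstance(...)
# is False for WandbLogger instances, and False sorts before True, so they
# move to the front. sorted() is stable, so TensorBoardLogger and CSVLogger
# keep their original relative order.
from lightning.pytorch.loggers import CSVLogger, TensorBoardLogger, WandbLogger

loggers = [TensorBoardLogger("logs"), CSVLogger("logs"), WandbLogger(save_dir="logs")]
ordered = sorted(loggers, key=lambda logger: not isinstance(logger, WandbLogger))

assert isinstance(ordered[0], WandbLogger)
assert isinstance(ordered[1], TensorBoardLogger)
assert isinstance(ordered[2], CSVLogger)
```

Touching each `logger.experiment` in this order triggers `wandb.init` before any TensorBoard writer is created, which is the precondition described in the wandb issue linked in the diff for syncing TensorBoard logs to wandb.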