|
24 | 24 | from lightning.pytorch.callbacks import ModelCheckpoint
|
25 | 25 | from lightning.pytorch.cli import LightningCLI
|
26 | 26 | from lightning.pytorch.demos.boring_classes import BoringModel
|
27 |
| -from lightning.pytorch.loggers import WandbLogger |
| 27 | +from lightning.pytorch.loggers import TensorBoardLogger, WandbLogger |
28 | 28 | from lightning.pytorch.utilities.exceptions import MisconfigurationException
|
29 | 29 | from tests_pytorch.test_cli import _xfail_python_ge_3_11_9
|
30 | 30 |
|
@@ -133,6 +133,43 @@ def test_wandb_logger_init_before_spawn(wandb_mock):
|
133 | 133 | assert logger._experiment is not None
|
134 | 134 |
|
135 | 135 |
|
| 136 | +def test_wandb_logger_experiment_called_first(wandb_mock, tmp_path): |
| 137 | + wandb_experiment_called = False |
| 138 | + |
| 139 | + def tensorboard_experiment_side_effect() -> mock.MagicMock: |
| 140 | + nonlocal wandb_experiment_called |
| 141 | + assert wandb_experiment_called |
| 142 | + return mock.MagicMock() |
| 143 | + |
| 144 | + def wandb_experiment_side_effect() -> mock.MagicMock: |
| 145 | + nonlocal wandb_experiment_called |
| 146 | + wandb_experiment_called = True |
| 147 | + return mock.MagicMock() |
| 148 | + |
| 149 | + with ( |
| 150 | + mock.patch.object( |
| 151 | + TensorBoardLogger, |
| 152 | + "experiment", |
| 153 | + new_callable=lambda: mock.PropertyMock(side_effect=tensorboard_experiment_side_effect), |
| 154 | + ), |
| 155 | + mock.patch.object( |
| 156 | + WandbLogger, |
| 157 | + "experiment", |
| 158 | + new_callable=lambda: mock.PropertyMock(side_effect=wandb_experiment_side_effect), |
| 159 | + ), |
| 160 | + ): |
| 161 | + model = BoringModel() |
| 162 | + trainer = Trainer( |
| 163 | + default_root_dir=tmp_path, |
| 164 | + log_every_n_steps=1, |
| 165 | + limit_train_batches=0, |
| 166 | + limit_val_batches=0, |
| 167 | + max_steps=1, |
| 168 | + logger=[TensorBoardLogger(tmp_path), WandbLogger(save_dir=tmp_path)], |
| 169 | + ) |
| 170 | + trainer.fit(model) |
| 171 | + |
| 172 | + |
136 | 173 | def test_wandb_pickle(wandb_mock, tmp_path):
|
137 | 174 | """Verify that pickling trainer with wandb logger works.
|
138 | 175 |
|
|
0 commit comments