diff --git a/src/lightning/pytorch/loggers/wandb.py b/src/lightning/pytorch/loggers/wandb.py
index 0ea32b97c46d1..5bef27192b127 100644
--- a/src/lightning/pytorch/loggers/wandb.py
+++ b/src/lightning/pytorch/loggers/wandb.py
@@ -322,7 +322,7 @@ def __init__(
         self._prefix = prefix
         self._experiment = experiment
         self._logged_model_time: dict[str, float] = {}
-        self._checkpoint_callback: Optional[ModelCheckpoint] = None
+        self._checkpoint_callbacks: dict[int, ModelCheckpoint] = {}

         # paths are processed as strings
         if save_dir is not None:
@@ -591,7 +591,7 @@ def after_save_checkpoint(self, checkpoint_callback: ModelCheckpoint) -> None:
         if self._log_model == "all" or self._log_model is True and checkpoint_callback.save_top_k == -1:
             self._scan_and_log_checkpoints(checkpoint_callback)
         elif self._log_model is True:
-            self._checkpoint_callback = checkpoint_callback
+            self._checkpoint_callbacks[id(checkpoint_callback)] = checkpoint_callback

     @staticmethod
     @rank_zero_only
@@ -644,8 +644,9 @@ def finalize(self, status: str) -> None:
             # Currently, checkpoints only get logged on success
             return
         # log checkpoints as artifacts
-        if self._checkpoint_callback and self._experiment is not None:
-            self._scan_and_log_checkpoints(self._checkpoint_callback)
+        if self._experiment is not None:
+            for checkpoint_callback in self._checkpoint_callbacks.values():
+                self._scan_and_log_checkpoints(checkpoint_callback)

     def _scan_and_log_checkpoints(self, checkpoint_callback: ModelCheckpoint) -> None:
         import wandb
diff --git a/tests/tests_pytorch/loggers/test_wandb.py b/tests/tests_pytorch/loggers/test_wandb.py
index 7b20423380cb1..35c1917983dcf 100644
--- a/tests/tests_pytorch/loggers/test_wandb.py
+++ b/tests/tests_pytorch/loggers/test_wandb.py
@@ -426,6 +426,44 @@ def test_wandb_log_model(wandb_mock, tmp_path):
     )
     wandb_mock.init().log_artifact.assert_called_with(wandb_mock.Artifact(), aliases=["latest", "best"])

+    # Test wandb artifact with two checkpoint_callbacks
+    wandb_mock.init().log_artifact.reset_mock()
+    wandb_mock.init.reset_mock()
+    wandb_mock.Artifact.reset_mock()
+    logger = WandbLogger(save_dir=tmp_path, log_model=True)
+    logger.experiment.id = "1"
+    logger.experiment.name = "run_name"
+    trainer = Trainer(
+        default_root_dir=tmp_path,
+        logger=logger,
+        max_epochs=3,
+        limit_train_batches=3,
+        limit_val_batches=3,
+        callbacks=[
+            ModelCheckpoint(monitor="epoch", save_top_k=2),
+            ModelCheckpoint(monitor="step", save_top_k=2),
+        ],
+    )
+    trainer.fit(model)
+    for name, val, version in [("epoch", 0, 2), ("step", 3, 3)]:
+        wandb_mock.Artifact.assert_any_call(
+            name="model-1",
+            type="model",
+            metadata={
+                "score": val,
+                "original_filename": f"epoch=0-step=3-v{version}.ckpt",
+                "ModelCheckpoint": {
+                    "monitor": name,
+                    "mode": "min",
+                    "save_last": None,
+                    "save_top_k": 2,
+                    "save_weights_only": False,
+                    "_every_n_train_steps": 0,
+                },
+            },
+        )
+    wandb_mock.init().log_artifact.assert_any_call(wandb_mock.Artifact(), aliases=["latest"])
+

 def test_wandb_log_model_with_score(wandb_mock, tmp_path):
     """Test to prevent regression on #15543, ensuring the score is logged as a Python number, not a scalar tensor."""
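
For context, a minimal sketch of the user-facing setup this patch fixes, outside the mocked test harness: with log_model=True and two ModelCheckpoint callbacks, the old single _checkpoint_callback slot kept only the last callback passed to after_save_checkpoint, so finalize() uploaded artifacts for just one callback. The monitored metrics (val_loss, val_acc) and the LitModel name below are illustrative placeholders, not part of the patch.

# Sketch (not part of the patch): two ModelCheckpoint callbacks with
# log_model=True. After this change, WandbLogger tracks each callback
# in _checkpoint_callbacks (keyed by id()), so finalize() scans and
# uploads checkpoints from both instead of only the last one seen.
import lightning.pytorch as pl
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.loggers import WandbLogger

logger = WandbLogger(log_model=True)
trainer = pl.Trainer(
    logger=logger,
    max_epochs=3,
    callbacks=[
        # Hypothetical monitors; any two metrics the model logs would do.
        ModelCheckpoint(monitor="val_loss", save_top_k=2),
        ModelCheckpoint(monitor="val_acc", mode="max", save_top_k=2),
    ],
)
# trainer.fit(LitModel())  # LitModel is a user-defined LightningModule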