src/lightning/pytorch/loggers/wandb.py (9 changes: 5 additions & 4 deletions)

@@ -321,7 +321,7 @@ def __init__(
         self._prefix = prefix
         self._experiment = experiment
         self._logged_model_time: Dict[str, float] = {}
-        self._checkpoint_callback: Optional[ModelCheckpoint] = None
+        self._checkpoint_callbacks: Dict[int, ModelCheckpoint] = {}

         # paths are processed as strings
         if save_dir is not None:
@@ -587,7 +587,7 @@ def after_save_checkpoint(self, checkpoint_callback: ModelCheckpoint) -> None:
         if self._log_model == "all" or self._log_model is True and checkpoint_callback.save_top_k == -1:
             self._scan_and_log_checkpoints(checkpoint_callback)
         elif self._log_model is True:
-            self._checkpoint_callback = checkpoint_callback
+            self._checkpoint_callbacks[id(checkpoint_callback)] = checkpoint_callback

     @staticmethod
     @rank_zero_only
@@ -640,8 +640,9 @@ def finalize(self, status: str) -> None:
             # Currently, checkpoints only get logged on success
             return
         # log checkpoints as artifacts
-        if self._checkpoint_callback and self._experiment is not None:
-            self._scan_and_log_checkpoints(self._checkpoint_callback)
+        if self._experiment is not None:
+            for checkpoint_callback in self._checkpoint_callbacks.values():
+                self._scan_and_log_checkpoints(checkpoint_callback)

     def _scan_and_log_checkpoints(self, checkpoint_callback: ModelCheckpoint) -> None:
         import wandb
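The net effect of these three hunks: `after_save_checkpoint` now records every `ModelCheckpoint` callback it sees, keyed by `id()` so distinct instances never collide, and `finalize` uploads checkpoints from all of them instead of only the last one assigned. A minimal sketch of the scenario this fixes (not part of the diff; assumes lightning 2.x import paths, with `BoringModel` as a stand-in LightningModule):

# Sketch only. Before this change, the logger kept a single
# _checkpoint_callback reference, so with two ModelCheckpoint callbacks the
# second overwrote the first and only its checkpoints were uploaded.
from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.demos.boring_classes import BoringModel
from lightning.pytorch.loggers import WandbLogger

logger = WandbLogger(log_model=True)  # upload checkpoints when training finalizes
trainer = Trainer(
    logger=logger,
    max_epochs=3,
    limit_train_batches=3,
    limit_val_batches=3,
    callbacks=[
        # Both instances are tracked in _checkpoint_callbacks, keyed by id().
        ModelCheckpoint(monitor="epoch", save_top_k=2),
        ModelCheckpoint(monitor="step", save_top_k=2),
    ],
)
trainer.fit(BoringModel())
# finalize() now iterates _checkpoint_callbacks.values(), so checkpoints from
# both callbacks are logged as W&B artifacts, not just the last one registered.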
tests/tests_pytorch/loggers/test_wandb.py (38 changes: 38 additions & 0 deletions)

@@ -374,6 +374,44 @@ def test_wandb_log_model(wandb_mock, tmp_path):
     )
     wandb_mock.init().log_artifact.assert_called_with(wandb_mock.Artifact(), aliases=["latest", "best"])
+
+    # Test wandb artifact with two checkpoint_callbacks
+    wandb_mock.init().log_artifact.reset_mock()
+    wandb_mock.init.reset_mock()
+    wandb_mock.Artifact.reset_mock()
+    logger = WandbLogger(save_dir=tmp_path, log_model=True)
+    logger.experiment.id = "1"
+    logger.experiment.name = "run_name"
+    trainer = Trainer(
+        default_root_dir=tmp_path,
+        logger=logger,
+        max_epochs=3,
+        limit_train_batches=3,
+        limit_val_batches=3,
+        callbacks=[
+            ModelCheckpoint(monitor="epoch", save_top_k=2),
+            ModelCheckpoint(monitor="step", save_top_k=2),
+        ],
+    )
+    trainer.fit(model)
+    for name, val, version in [("epoch", 0, 2), ("step", 3, 3)]:
+        wandb_mock.Artifact.assert_any_call(
+            name="model-1",
+            type="model",
+            metadata={
+                "score": val,
+                "original_filename": f"epoch=0-step=3-v{version}.ckpt",
+                "ModelCheckpoint": {
+                    "monitor": name,
+                    "mode": "min",
+                    "save_last": None,
+                    "save_top_k": 2,
+                    "save_weights_only": False,
+                    "_every_n_train_steps": 0,
+                },
+            },
+        )
+    wandb_mock.init().log_artifact.assert_any_call(wandb_mock.Artifact(), aliases=["latest"])


 def test_wandb_log_model_with_score(wandb_mock, tmp_path):
     """Test to prevent regression on #15543, ensuring the score is logged as a Python number, not a scalar tensor."""