
Commit 7bc2a65

fix(loggers): add best and latest aliases to wandb artifact in WandbLogger (#17121)
Co-authored-by: awaelchli <[email protected]>
1 parent: 60c9f24

File tree

3 files changed: +75 −1 lines changed


src/lightning/pytorch/CHANGELOG.md

Lines changed: 1 addition & 0 deletions

```diff
@@ -29,6 +29,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ### Fixed
 
+- Fixed WandbLogger not showing "best" aliases for model checkpoints when `ModelCheckpoint(save_top_k>0)` is used ([#17121](https://github.com/Lightning-AI/lightning/pull/17121))
 
 
 
```
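For context, here is a minimal sketch of the user-facing setup that exercises the fixed code path. The project name is a placeholder, and `model` stands for any `LightningModule`:

```python
from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.loggers import WandbLogger

# log_model=True uploads each saved checkpoint as a wandb artifact; with this
# fix, the checkpoint matching best_model_path also carries the "best" alias.
logger = WandbLogger(project="my-project", log_model=True)  # "my-project" is a placeholder
checkpoint_cb = ModelCheckpoint(monitor="val_loss", mode="min", save_top_k=2)
trainer = Trainer(logger=logger, callbacks=[checkpoint_cb], max_epochs=3)
# trainer.fit(model)  # `model` is any LightningModule
```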

src/lightning/pytorch/loggers/wandb.py

Lines changed: 2 additions & 1 deletion

```diff
@@ -599,6 +599,7 @@ def _scan_and_log_checkpoints(self, checkpoint_callback: ModelCheckpoint) -> None:
                 self._checkpoint_name = f"model-{self.experiment.id}"
             artifact = wandb.Artifact(name=self._checkpoint_name, type="model", metadata=metadata)
             artifact.add_file(p, name="model.ckpt")
-            self.experiment.log_artifact(artifact, aliases=[tag])
+            aliases = ["latest", "best"] if p == checkpoint_callback.best_model_path else ["latest"]
+            self.experiment.log_artifact(artifact, aliases=aliases)
             # remember logged models - timestamp needed in case filename didn't change (last.ckpt or custom name)
             self._logged_model_time[p] = t
```
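The practical effect is that the best checkpoint can now be fetched by alias through wandb's public API. A hedged sketch, where `entity/project` and the run id in `model-<run id>` are placeholders:

```python
import wandb

api = wandb.Api()
# The logger names the artifact f"model-{run.id}"; the ":best" suffix selects
# the alias added for the checkpoint matching best_model_path.
artifact = api.artifact("entity/project/model-1abcd234:best", type="model")
checkpoint_dir = artifact.download()
# The file inside is "model.ckpt", per artifact.add_file(p, name="model.ckpt") above.
```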

tests/tests_pytorch/loggers/test_wandb.py

Lines changed: 72 additions & 0 deletions

```diff
@@ -277,6 +277,78 @@ def test_wandb_log_model(wandb, monkeypatch, tmpdir):
         },
     )
 
+    # Test wandb artifact with checkpoint_callback top_k logging latest
+    wandb.init().log_artifact.reset_mock()
+    wandb.init.reset_mock()
+    wandb.Artifact.reset_mock()
+    logger = WandbLogger(save_dir=tmpdir, log_model=True)
+    logger.experiment.id = "1"
+    logger.experiment.name = "run_name"
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        logger=logger,
+        max_epochs=3,
+        limit_train_batches=3,
+        limit_val_batches=3,
+        callbacks=[ModelCheckpoint(monitor="step", save_top_k=2)],
+    )
+    trainer.fit(model)
+    wandb.Artifact.assert_called_with(
+        name="model-1",
+        type="model",
+        metadata={
+            "score": 6,
+            "original_filename": "epoch=1-step=6-v5.ckpt",
+            "ModelCheckpoint": {
+                "monitor": "step",
+                "mode": "min",
+                "save_last": None,
+                "save_top_k": 2,
+                "save_weights_only": False,
+                "_every_n_train_steps": 0,
+            },
+        },
+    )
+    wandb.init().log_artifact.assert_called_with(wandb.Artifact(), aliases=["latest"])
+
+    # Test wandb artifact with checkpoint_callback top_k logging latest and best
+    wandb.init().log_artifact.reset_mock()
+    wandb.init.reset_mock()
+    wandb.Artifact.reset_mock()
+    logger = WandbLogger(save_dir=tmpdir, log_model=True)
+    logger.experiment.id = "1"
+    logger.experiment.name = "run_name"
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        logger=logger,
+        max_epochs=3,
+        limit_train_batches=3,
+        limit_val_batches=3,
+        callbacks=[
+            ModelCheckpoint(
+                monitor="step",
+            )
+        ],
+    )
+    trainer.fit(model)
+    wandb.Artifact.assert_called_with(
+        name="model-1",
+        type="model",
+        metadata={
+            "score": 3,
+            "original_filename": "epoch=0-step=3-v1.ckpt",
+            "ModelCheckpoint": {
+                "monitor": "step",
+                "mode": "min",
+                "save_last": None,
+                "save_top_k": 1,
+                "save_weights_only": False,
+                "_every_n_train_steps": 0,
+            },
+        },
+    )
+    wandb.init().log_artifact.assert_called_with(wandb.Artifact(), aliases=["latest", "best"])
+
 
 
 @mock.patch("lightning.pytorch.loggers.wandb.Run", new=mock.Mock)
 @mock.patch("lightning.pytorch.loggers.wandb.wandb")
```
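These assertions rely on `unittest.mock.MagicMock` returning the same child object on every call, so the `wandb.Artifact()` passed to `assert_called_with` is the very instance that `log_artifact` received. A standalone sketch of that behavior:

```python
from unittest import mock

wandb = mock.MagicMock()
artifact = wandb.Artifact(name="model-1", type="model")
wandb.init().log_artifact(artifact, aliases=["latest", "best"])

# wandb.Artifact() returns the same mock regardless of arguments, and
# wandb.init() returns the same run mock each time, so this passes:
wandb.init().log_artifact.assert_called_with(wandb.Artifact(), aliases=["latest", "best"])
```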
