Skip to content
Closed
Changes from all commits
Commits
Show all changes
19 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 29 additions & 1 deletion src/lightning/pytorch/loggers/wandb.py
Original file line number Diff line number Diff line change
Expand Up @@ -637,6 +637,28 @@ def finalize(self, status: str) -> None:
if self._checkpoint_callback and self._experiment is not None:
self._scan_and_log_checkpoints(self._checkpoint_callback)

def on_log_checkpoint_artifact(
self,
artifact: "wandb.Artifact",
checkpoint_timestamp: float,
path: _PATH,
score: float,
tag: str,
) -> None:
"""Override this to provide custom artifact logging behavior.

By default, adds `path` as a file called "model.ckpt" to the `artifact`.

Args:
artifact: The wandb artifact for this checkpoint
checkpoint_timestamp: The timestamp the checkpoint was last created or modified.
path: The absolute path or URI to the checkpoint.
score: The score associated with the checkpoint.
tag: The tag associated with the checkpoint.

"""
artifact.add_file(path, name="model.ckpt")

def _scan_and_log_checkpoints(self, checkpoint_callback: ModelCheckpoint) -> None:
import wandb

Expand Down Expand Up @@ -665,7 +687,13 @@ def _scan_and_log_checkpoints(self, checkpoint_callback: ModelCheckpoint) -> Non
if not self._checkpoint_name:
self._checkpoint_name = f"model-{self.experiment.id}"
artifact = wandb.Artifact(name=self._checkpoint_name, type="model", metadata=metadata)
artifact.add_file(p, name="model.ckpt")
self.on_log_checkpoint_artifact(
artifact=artifact,
checkpoint_timestamp=t,
path=p,
score=s,
tag=tag,
)
aliases = ["latest", "best"] if p == checkpoint_callback.best_model_path else ["latest"]
self.experiment.log_artifact(artifact, aliases=aliases)
# remember logged models - timestamp needed in case filename didn't change (lastkckpt or custom name)
Expand Down