diff --git a/src/lightning/pytorch/loggers/wandb.py b/src/lightning/pytorch/loggers/wandb.py index 4025f2cd18004..b115b4012f8e5 100644 --- a/src/lightning/pytorch/loggers/wandb.py +++ b/src/lightning/pytorch/loggers/wandb.py @@ -637,6 +637,28 @@ def finalize(self, status: str) -> None: if self._checkpoint_callback and self._experiment is not None: self._scan_and_log_checkpoints(self._checkpoint_callback) + def on_log_checkpoint_artifact( + self, + artifact: "wandb.Artifact", + checkpoint_timestamp: float, + path: _PATH, + score: float, + tag: str, + ) -> None: + """Override this to provide custom artifact logging behavior. + + By default, adds `path` as a file called "model.ckpt" to the `artifact`. + + Args: + artifact: The wandb artifact for this checkpoint + checkpoint_timestamp: The timestamp the checkpoint was last created or modified. + path: The absolute path or URI to the checkpoint. + score: The score associated with the checkpoint. + tag: The tag associated with the checkpoint. + + """ + artifact.add_file(path, name="model.ckpt") + def _scan_and_log_checkpoints(self, checkpoint_callback: ModelCheckpoint) -> None: import wandb @@ -665,7 +687,13 @@ def _scan_and_log_checkpoints(self, checkpoint_callback: ModelCheckpoint) -> Non if not self._checkpoint_name: self._checkpoint_name = f"model-{self.experiment.id}" artifact = wandb.Artifact(name=self._checkpoint_name, type="model", metadata=metadata) - artifact.add_file(p, name="model.ckpt") + self.on_log_checkpoint_artifact( + artifact=artifact, + checkpoint_timestamp=t, + path=p, + score=s, + tag=tag, + ) aliases = ["latest", "best"] if p == checkpoint_callback.best_model_path else ["latest"] self.experiment.log_artifact(artifact, aliases=aliases) # remember logged models - timestamp needed in case filename didn't change (lastkckpt or custom name)