Skip to content
Merged
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion src/lightning/pytorch/loggers/wandb.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,6 +278,7 @@ def any_lightning_module_function_or_hook(self):
prefix: A string to put at the beginning of metric keys.
experiment: WandB experiment object. Automatically set when creating a run.
checkpoint_name: Name of the model checkpoint artifact being logged.
add_file_policy: If "mutable", copies file to tempdirectory before upload.
\**kwargs: Arguments passed to :func:`wandb.init` like `entity`, `group`, `tags`, etc.

Raises:
Expand All @@ -304,6 +305,7 @@ def __init__(
experiment: Union["Run", "RunDisabled", None] = None,
prefix: str = "",
checkpoint_name: Optional[str] = None,
add_file_policy: Literal["mutable", "immutable"] = "mutable",
**kwargs: Any,
) -> None:
if not _WANDB_AVAILABLE:
Expand All @@ -323,6 +325,7 @@ def __init__(
self._experiment = experiment
self._logged_model_time: dict[str, float] = {}
self._checkpoint_callback: Optional[ModelCheckpoint] = None
self.add_file_policy = add_file_policy

# paths are processed as strings
if save_dir is not None:
Expand Down Expand Up @@ -672,7 +675,7 @@ def _scan_and_log_checkpoints(self, checkpoint_callback: ModelCheckpoint) -> Non
if not self._checkpoint_name:
self._checkpoint_name = f"model-{self.experiment.id}"
artifact = wandb.Artifact(name=self._checkpoint_name, type="model", metadata=metadata)
artifact.add_file(p, name="model.ckpt")
artifact.add_file(p, name="model.ckpt", policy=self.add_file_policy)
aliases = ["latest", "best"] if p == checkpoint_callback.best_model_path else ["latest"]
self.experiment.log_artifact(artifact, aliases=aliases)
# remember logged models - timestamp needed in case filename didn't change (lastkckpt or custom name)
Expand Down
Loading