Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion src/lightning/pytorch/loggers/wandb.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,6 +278,7 @@ def any_lightning_module_function_or_hook(self):
prefix: A string to put at the beginning of metric keys.
experiment: WandB experiment object. Automatically set when creating a run.
checkpoint_name: Name of the model checkpoint artifact being logged.
add_file_policy: If "mutable", copies file to tempdirectory before upload.
\**kwargs: Arguments passed to :func:`wandb.init` like `entity`, `group`, `tags`, etc.

Raises:
Expand All @@ -304,6 +305,7 @@ def __init__(
experiment: Union["Run", "RunDisabled", None] = None,
prefix: str = "",
checkpoint_name: Optional[str] = None,
add_file_policy: Literal["mutable", "immutable"] = "mutable",
**kwargs: Any,
) -> None:
if not _WANDB_AVAILABLE:
Expand All @@ -323,6 +325,7 @@ def __init__(
self._experiment = experiment
self._logged_model_time: dict[str, float] = {}
self._checkpoint_callbacks: dict[int, ModelCheckpoint] = {}
self.add_file_policy = add_file_policy

# paths are processed as strings
if save_dir is not None:
Expand Down Expand Up @@ -676,7 +679,7 @@ def _scan_and_log_checkpoints(self, checkpoint_callback: ModelCheckpoint) -> Non
if not self._checkpoint_name:
self._checkpoint_name = f"model-{self.experiment.id}"
artifact = wandb.Artifact(name=self._checkpoint_name, type="model", metadata=metadata)
artifact.add_file(p, name="model.ckpt")
artifact.add_file(p, name="model.ckpt", policy=self.add_file_policy)
aliases = ["latest", "best"] if p == checkpoint_callback.best_model_path else ["latest"]
self.experiment.log_artifact(artifact, aliases=aliases)
# remember logged models - timestamp needed in case filename didn't change (lastkckpt or custom name)
Expand Down
Loading