diff --git a/src/lightning/pytorch/loggers/wandb.py b/src/lightning/pytorch/loggers/wandb.py index 5bef27192b127..37ca362fa40c1 100644 --- a/src/lightning/pytorch/loggers/wandb.py +++ b/src/lightning/pytorch/loggers/wandb.py @@ -278,6 +278,7 @@ def any_lightning_module_function_or_hook(self): prefix: A string to put at the beginning of metric keys. experiment: WandB experiment object. Automatically set when creating a run. checkpoint_name: Name of the model checkpoint artifact being logged. + add_file_policy: If "mutable", copies file to tempdirectory before upload. \**kwargs: Arguments passed to :func:`wandb.init` like `entity`, `group`, `tags`, etc. Raises: @@ -304,6 +305,7 @@ def __init__( experiment: Union["Run", "RunDisabled", None] = None, prefix: str = "", checkpoint_name: Optional[str] = None, + add_file_policy: Literal["mutable", "immutable"] = "mutable", **kwargs: Any, ) -> None: if not _WANDB_AVAILABLE: @@ -323,6 +325,7 @@ def __init__( self._experiment = experiment self._logged_model_time: dict[str, float] = {} self._checkpoint_callbacks: dict[int, ModelCheckpoint] = {} + self.add_file_policy = add_file_policy # paths are processed as strings if save_dir is not None: @@ -676,7 +679,7 @@ def _scan_and_log_checkpoints(self, checkpoint_callback: ModelCheckpoint) -> Non if not self._checkpoint_name: self._checkpoint_name = f"model-{self.experiment.id}" artifact = wandb.Artifact(name=self._checkpoint_name, type="model", metadata=metadata) - artifact.add_file(p, name="model.ckpt") + artifact.add_file(p, name="model.ckpt", policy=self.add_file_policy) aliases = ["latest", "best"] if p == checkpoint_callback.best_model_path else ["latest"] self.experiment.log_artifact(artifact, aliases=aliases) # remember logged models - timestamp needed in case filename didn't change (lastkckpt or custom name)