Skip to content

Commit fe79be1

Browse files
cgebbeBorda
andauthored
feat: allow immutable file upload for wandb logger (#20193)
Co-authored-by: Christian Gebbe <> Co-authored-by: Jirka Borovec <[email protected]>
1 parent af2727a commit fe79be1

File tree

1 file changed

+4
-1
lines changed

1 file changed

+4
-1
lines changed

src/lightning/pytorch/loggers/wandb.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -278,6 +278,7 @@ def any_lightning_module_function_or_hook(self):
278278
prefix: A string to put at the beginning of metric keys.
279279
experiment: WandB experiment object. Automatically set when creating a run.
280280
checkpoint_name: Name of the model checkpoint artifact being logged.
281+
add_file_policy: If "mutable", copies file to tempdirectory before upload.
281282
\**kwargs: Arguments passed to :func:`wandb.init` like `entity`, `group`, `tags`, etc.
282283
283284
Raises:
@@ -304,6 +305,7 @@ def __init__(
304305
experiment: Union["Run", "RunDisabled", None] = None,
305306
prefix: str = "",
306307
checkpoint_name: Optional[str] = None,
308+
add_file_policy: Literal["mutable", "immutable"] = "mutable",
307309
**kwargs: Any,
308310
) -> None:
309311
if not _WANDB_AVAILABLE:
@@ -323,6 +325,7 @@ def __init__(
323325
self._experiment = experiment
324326
self._logged_model_time: dict[str, float] = {}
325327
self._checkpoint_callbacks: dict[int, ModelCheckpoint] = {}
328+
self.add_file_policy = add_file_policy
326329

327330
# paths are processed as strings
328331
if save_dir is not None:
@@ -676,7 +679,7 @@ def _scan_and_log_checkpoints(self, checkpoint_callback: ModelCheckpoint) -> Non
676679
if not self._checkpoint_name:
677680
self._checkpoint_name = f"model-{self.experiment.id}"
678681
artifact = wandb.Artifact(name=self._checkpoint_name, type="model", metadata=metadata)
679-
artifact.add_file(p, name="model.ckpt")
682+
artifact.add_file(p, name="model.ckpt", policy=self.add_file_policy)
680683
aliases = ["latest", "best"] if p == checkpoint_callback.best_model_path else ["latest"]
681684
self.experiment.log_artifact(artifact, aliases=aliases)
682685
# remember logged models - timestamp needed in case filename didn't change (lastkckpt or custom name)

0 commit comments

Comments
 (0)