
Commit d3e839c

Merge branch 'master' into patch-2
2 parents: 20856b5 + fe79be1

2 files changed (+8, -5 lines)

src/lightning/pytorch/loggers/wandb.py
4 additions & 1 deletion

@@ -278,6 +278,7 @@ def any_lightning_module_function_or_hook(self):
         prefix: A string to put at the beginning of metric keys.
         experiment: WandB experiment object. Automatically set when creating a run.
         checkpoint_name: Name of the model checkpoint artifact being logged.
+        add_file_policy: If "mutable", copies file to tempdirectory before upload.
         \**kwargs: Arguments passed to :func:`wandb.init` like `entity`, `group`, `tags`, etc.

     Raises:
@@ -304,6 +305,7 @@ def __init__(
         experiment: Union["Run", "RunDisabled", None] = None,
         prefix: str = "",
         checkpoint_name: Optional[str] = None,
+        add_file_policy: Literal["mutable", "immutable"] = "mutable",
         **kwargs: Any,
     ) -> None:
         if not _WANDB_AVAILABLE:
@@ -323,6 +325,7 @@ def __init__(
         self._experiment = experiment
         self._logged_model_time: dict[str, float] = {}
         self._checkpoint_callbacks: dict[int, ModelCheckpoint] = {}
+        self.add_file_policy = add_file_policy

         # paths are processed as strings
         if save_dir is not None:
@@ -676,7 +679,7 @@ def _scan_and_log_checkpoints(self, checkpoint_callback: ModelCheckpoint) -> Non
             if not self._checkpoint_name:
                 self._checkpoint_name = f"model-{self.experiment.id}"
             artifact = wandb.Artifact(name=self._checkpoint_name, type="model", metadata=metadata)
-            artifact.add_file(p, name="model.ckpt")
+            artifact.add_file(p, name="model.ckpt", policy=self.add_file_policy)
             aliases = ["latest", "best"] if p == checkpoint_callback.best_model_path else ["latest"]
             self.experiment.log_artifact(artifact, aliases=aliases)
             # remember logged models - timestamp needed in case filename didn't change (lastkckpt or custom name)
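Taken together, this change threads a new `add_file_policy` argument from the `WandbLogger` constructor through to `wandb.Artifact.add_file(..., policy=...)` when checkpoints are uploaded. A minimal usage sketch, not part of the commit: the project name and monitored metric below are made up, and it assumes a lightning build that includes this commit plus a wandb release whose `add_file` accepts `policy`.

# Minimal sketch, assuming a lightning build with this commit and a wandb
# release whose Artifact.add_file() accepts the ``policy`` argument.
from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.loggers import WandbLogger

logger = WandbLogger(
    project="my-project",         # hypothetical project name
    log_model=True,               # upload checkpoints as W&B artifacts
    add_file_policy="immutable",  # new in this commit; the default stays "mutable"
)

# Checkpoints saved by the callback are logged via
# artifact.add_file(path, name="model.ckpt", policy=logger.add_file_policy).
trainer = Trainer(logger=logger, callbacks=[ModelCheckpoint(monitor="val_loss")])

With the default "mutable" policy wandb copies the file to a temporary directory before upload; "immutable" is shown here as the opt-out that this new argument exposes.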

src/lightning/pytorch/trainer/connectors/callback_connector.py
4 additions & 4 deletions

@@ -18,7 +18,7 @@
 from datetime import timedelta
 from typing import Optional, Union

-from lightning_utilities import module_available
+from lightning_utilities.core.imports import RequirementCache

 import lightning.pytorch as pl
 from lightning.fabric.utilities.registry import _load_external_callbacks
@@ -93,7 +93,7 @@ def _configure_checkpoint_callbacks(self, enable_checkpointing: bool) -> None:
                 " but found `ModelCheckpoint` in callbacks list."
             )
         elif enable_checkpointing:
-            if module_available("litmodels") and self.trainer._model_registry:
+            if RequirementCache("litmodels >=0.1.7") and self.trainer._model_registry:
                 trainer_source = inspect.getmodule(self.trainer)
                 if trainer_source is None or not isinstance(trainer_source.__package__, str):
                     raise RuntimeError("Unable to determine the source of the trainer.")
@@ -103,11 +103,11 @@ def _configure_checkpoint_callbacks(self, enable_checkpointing: bool) -> None:
                 else:
                     from litmodels.integrations.checkpoints import LightningModelCheckpoint as LitModelCheckpoint

-                    model_checkpoint = LitModelCheckpoint(model_name=self.trainer._model_registry)
+                    model_checkpoint = LitModelCheckpoint(model_registry=self.trainer._model_registry)
             else:
                 rank_zero_info(
                     "You are using the default ModelCheckpoint callback."
-                    " Install `litmodels` package to use the `LitModelCheckpoint` instead"
+                    " Install `pip install litmodels` package to use the `LitModelCheckpoint` instead"
                     " for seamless uploading to the Lightning model registry."
                 )
                 model_checkpoint = ModelCheckpoint()
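The connector change tightens the optional-dependency check: `module_available` only tests that a package can be imported, while `RequirementCache` also validates the installed version against a specifier, presumably so that `litmodels` releases older than 0.1.7 (which may not accept the renamed `model_registry=` keyword) fall back to the plain `ModelCheckpoint` path. A small illustration, not part of the commit; the printed messages are placeholders.

# Sketch contrasting the two checks used before and after this commit.
from lightning_utilities import module_available
from lightning_utilities.core.imports import RequirementCache

# Old check: truthy as soon as ``litmodels`` can be imported, regardless of version.
if module_available("litmodels"):
    print("litmodels is importable (any version)")

# New check: truthy only if the installed version satisfies the specifier,
# so older litmodels releases skip the LitModelCheckpoint branch entirely.
if RequirementCache("litmodels >=0.1.7"):
    print("litmodels >=0.1.7 installed; LitModelCheckpoint(model_registry=...) branch is taken")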
