Skip to content

Commit c707bb8

Browse files
committed
feat: add support for cloud files and non-local checkpoints
1 parent ca13f77 commit c707bb8

File tree

1 file changed

+11
-2
lines changed

1 file changed

+11
-2
lines changed

src/lightning/pytorch/loggers/wandb.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
from torch import Tensor
2828
from typing_extensions import override
2929

30+
from lightning.fabric.utilities.cloud_io import _is_local_file_protocol
3031
from lightning.fabric.utilities.logger import (
3132
_add_prefix,
3233
_convert_json_serializable,
@@ -258,7 +259,9 @@ def any_lightning_module_function_or_hook(self):
258259
259260
Args:
260261
name: Display name for the run.
261-
save_dir: Path where data is saved.
262+
save_dir: Path where data is saved. Can be:
263+
- A local directory path
264+
- A remote storage path (S3, GCS, Azure Blob Storage)
262265
version: Sets the version, mainly used to resume a previous run.
263266
offline: Run offline (data can be streamed later to wandb servers).
264267
dir: Same as save_dir.
@@ -348,6 +351,7 @@ def __init__(
348351
self._name = self._wandb_init.get("name")
349352
self._id = self._wandb_init.get("id")
350353
self._checkpoint_name = checkpoint_name
354+
self._is_local = _is_local_file_protocol(save_dir)
351355

352356
def __getstate__(self) -> dict[str, Any]:
353357
import wandb
@@ -676,7 +680,12 @@ def _scan_and_log_checkpoints(self, checkpoint_callback: ModelCheckpoint) -> Non
676680
if not self._checkpoint_name:
677681
self._checkpoint_name = f"model-{self.experiment.id}"
678682
artifact = wandb.Artifact(name=self._checkpoint_name, type="model", metadata=metadata)
679-
artifact.add_file(p, name="model.ckpt")
683+
684+
if not self._is_local:
685+
artifact.add_reference(p, name="model.ckpt")
686+
else:
687+
artifact.add_file(p, name="model.ckpt")
688+
680689
aliases = ["latest", "best"] if p == checkpoint_callback.best_model_path else ["latest"]
681690
self.experiment.log_artifact(artifact, aliases=aliases)
682691
# remember logged models - timestamp needed in case filename didn't change (lastkckpt or custom name)

0 commit comments

Comments
 (0)