From c707bb8b9d0e516061807597bba4228e79975118 Mon Sep 17 00:00:00 2001 From: Bharat Ramanathan Date: Mon, 7 Apr 2025 12:48:16 +0530 Subject: [PATCH] feat: add support for cloud files and non-local checkpoints --- src/lightning/pytorch/loggers/wandb.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/src/lightning/pytorch/loggers/wandb.py b/src/lightning/pytorch/loggers/wandb.py index 5bef27192b127..9018d00dbe45c 100644 --- a/src/lightning/pytorch/loggers/wandb.py +++ b/src/lightning/pytorch/loggers/wandb.py @@ -27,6 +27,7 @@ from torch import Tensor from typing_extensions import override +from lightning.fabric.utilities.cloud_io import _is_local_file_protocol from lightning.fabric.utilities.logger import ( _add_prefix, _convert_json_serializable, @@ -258,7 +259,9 @@ def any_lightning_module_function_or_hook(self): Args: name: Display name for the run. - save_dir: Path where data is saved. + save_dir: Path where data is saved. Can be: + - A local directory path + - A remote storage path (S3, GCS, Azure Blob Storage) version: Sets the version, mainly used to resume a previous run. offline: Run offline (data can be streamed later to wandb servers). dir: Same as save_dir. @@ -348,6 +351,7 @@ def __init__( self._name = self._wandb_init.get("name") self._id = self._wandb_init.get("id") self._checkpoint_name = checkpoint_name + self._is_local = _is_local_file_protocol(save_dir) def __getstate__(self) -> dict[str, Any]: import wandb @@ -676,7 +680,12 @@ def _scan_and_log_checkpoints(self, checkpoint_callback: ModelCheckpoint) -> Non if not self._checkpoint_name: self._checkpoint_name = f"model-{self.experiment.id}" artifact = wandb.Artifact(name=self._checkpoint_name, type="model", metadata=metadata) - artifact.add_file(p, name="model.ckpt") + + if not self._is_local: + artifact.add_reference(p, name="model.ckpt") + else: + artifact.add_file(p, name="model.ckpt") + aliases = ["latest", "best"] if p == checkpoint_callback.best_model_path else ["latest"] self.experiment.log_artifact(artifact, aliases=aliases) # remember logged models - timestamp needed in case filename didn't change (lastkckpt or custom name)