|
27 | 27 | from torch import Tensor |
28 | 28 | from typing_extensions import override |
29 | 29 |
|
| 30 | +from lightning.fabric.utilities.cloud_io import _is_local_file_protocol |
30 | 31 | from lightning.fabric.utilities.logger import ( |
31 | 32 | _add_prefix, |
32 | 33 | _convert_json_serializable, |
@@ -258,7 +259,9 @@ def any_lightning_module_function_or_hook(self): |
258 | 259 |
|
259 | 260 | Args: |
260 | 261 | name: Display name for the run. |
261 | | - save_dir: Path where data is saved. |
| 262 | + save_dir: Path where data is saved. Can be: |
| 263 | + - A local directory path |
| 264 | + - A remote storage path (S3, GCS, Azure Blob Storage) |
262 | 265 | version: Sets the version, mainly used to resume a previous run. |
263 | 266 | offline: Run offline (data can be streamed later to wandb servers). |
264 | 267 | dir: Same as save_dir. |
@@ -348,6 +351,7 @@ def __init__( |
348 | 351 | self._name = self._wandb_init.get("name") |
349 | 352 | self._id = self._wandb_init.get("id") |
350 | 353 | self._checkpoint_name = checkpoint_name |
| 354 | + self._is_local = _is_local_file_protocol(save_dir) |
351 | 355 |
|
352 | 356 | def __getstate__(self) -> dict[str, Any]: |
353 | 357 | import wandb |
@@ -676,7 +680,12 @@ def _scan_and_log_checkpoints(self, checkpoint_callback: ModelCheckpoint) -> Non |
676 | 680 | if not self._checkpoint_name: |
677 | 681 | self._checkpoint_name = f"model-{self.experiment.id}" |
678 | 682 | artifact = wandb.Artifact(name=self._checkpoint_name, type="model", metadata=metadata) |
679 | | - artifact.add_file(p, name="model.ckpt") |
| 683 | + |
| 684 | + if not self._is_local: |
| 685 | + artifact.add_reference(p, name="model.ckpt") |
| 686 | + else: |
| 687 | + artifact.add_file(p, name="model.ckpt") |
| 688 | + |
680 | 689 | aliases = ["latest", "best"] if p == checkpoint_callback.best_model_path else ["latest"] |
681 | 690 | self.experiment.log_artifact(artifact, aliases=aliases) |
682 | 691 | # remember logged models - timestamp needed in case filename didn't change (lastkckpt or custom name) |
|
0 commit comments