@@ -278,7 +278,6 @@ def any_lightning_module_function_or_hook(self):
278278 prefix: A string to put at the beginning of metric keys.
279279 experiment: WandB experiment object. Automatically set when creating a run.
280280 checkpoint_name: Name of the model checkpoint artifact being logged.
281- add_file_policy: If "mutable", copies file to tempdirectory before upload.
282281 \**kwargs: Arguments passed to :func:`wandb.init` like `entity`, `group`, `tags`, etc.
283282
284283 Raises:
@@ -305,7 +304,6 @@ def __init__(
305304 experiment : Union ["Run" , "RunDisabled" , None ] = None ,
306305 prefix : str = "" ,
307306 checkpoint_name : Optional [str ] = None ,
308- add_file_policy : Literal ["mutable" , "immutable" ] = "mutable" ,
309307 ** kwargs : Any ,
310308 ) -> None :
311309 if not _WANDB_AVAILABLE :
@@ -324,8 +322,7 @@ def __init__(
324322 self ._prefix = prefix
325323 self ._experiment = experiment
326324 self ._logged_model_time : dict [str , float ] = {}
327- self ._checkpoint_callbacks : dict [int , ModelCheckpoint ] = {}
328- self .add_file_policy = add_file_policy
325+ self ._checkpoint_callback : Optional [ModelCheckpoint ] = None
329326
330327 # paths are processed as strings
331328 if save_dir is not None :
@@ -594,7 +591,7 @@ def after_save_checkpoint(self, checkpoint_callback: ModelCheckpoint) -> None:
594591 if self ._log_model == "all" or self ._log_model is True and checkpoint_callback .save_top_k == - 1 :
595592 self ._scan_and_log_checkpoints (checkpoint_callback )
596593 elif self ._log_model is True :
597- self ._checkpoint_callbacks [ id ( checkpoint_callback )] = checkpoint_callback
594+ self ._checkpoint_callback = checkpoint_callback
598595
599596 @staticmethod
600597 @rank_zero_only
@@ -647,9 +644,8 @@ def finalize(self, status: str) -> None:
647644 # Currently, checkpoints only get logged on success
648645 return
649646 # log checkpoints as artifacts
650- if self ._experiment is not None :
651- for checkpoint_callback in self ._checkpoint_callbacks .values ():
652- self ._scan_and_log_checkpoints (checkpoint_callback )
647+ if self ._checkpoint_callback and self ._experiment is not None :
648+ self ._scan_and_log_checkpoints (self ._checkpoint_callback )
653649
654650 def _scan_and_log_checkpoints (self , checkpoint_callback : ModelCheckpoint ) -> None :
655651 import wandb
@@ -679,7 +675,7 @@ def _scan_and_log_checkpoints(self, checkpoint_callback: ModelCheckpoint) -> Non
679675 if not self ._checkpoint_name :
680676 self ._checkpoint_name = f"model-{ self .experiment .id } "
681677 artifact = wandb .Artifact (name = self ._checkpoint_name , type = "model" , metadata = metadata )
682- artifact .add_file (p , name = "model.ckpt" , policy = self . add_file_policy )
678+ artifact .add_file (p , name = "model.ckpt" )
683679 aliases = ["latest" , "best" ] if p == checkpoint_callback .best_model_path else ["latest" ]
684680 self .experiment .log_artifact (artifact , aliases = aliases )
685681 # remember logged models - timestamp needed in case filename didn't change (lastkckpt or custom name)
0 commit comments