@@ -142,7 +142,7 @@ def __init__(
142142 self .tags = tags
143143 self ._log_model = log_model
144144 self ._logged_model_time : dict [str , float ] = {}
145- self ._checkpoint_callback : Optional [ModelCheckpoint ] = None
145+ self ._checkpoint_callbacks : list [ModelCheckpoint ] = []
146146 self ._prefix = prefix
147147 self ._artifact_location = artifact_location
148148 self ._log_batch_kwargs = {} if synchronous is None else {"synchronous" : synchronous }
@@ -283,8 +283,9 @@ def finalize(self, status: str = "success") -> None:
283283 status = "FINISHED"
284284
285285 # log checkpoints as artifacts
286- if self ._checkpoint_callback :
287- self ._scan_and_log_checkpoints (self ._checkpoint_callback )
286+ if self ._checkpoint_callbacks :
287+ for callback in self ._checkpoint_callbacks :
288+ self ._scan_and_log_checkpoints (callback )
288289
289290 if self .experiment .get_run (self .run_id ):
290291 self .experiment .set_terminated (self .run_id , status )
@@ -331,7 +332,8 @@ def after_save_checkpoint(self, checkpoint_callback: ModelCheckpoint) -> None:
331332 if self ._log_model == "all" or self ._log_model is True and checkpoint_callback .save_top_k == - 1 :
332333 self ._scan_and_log_checkpoints (checkpoint_callback )
333334 elif self ._log_model is True :
334- self ._checkpoint_callback = checkpoint_callback
335+ if checkpoint_callback not in self ._checkpoint_callbacks :
336+ self ._checkpoint_callbacks .append (checkpoint_callback )
335337
336338 def _scan_and_log_checkpoints (self , checkpoint_callback : ModelCheckpoint ) -> None :
337339 # get checkpoints to be saved with associated score
0 commit comments