Skip to content

Commit 072a5cf

Browse files
committed
fix(mlflow): Enabling multiple callbacks for checkpoint reporting
1 parent ea59e40 commit 072a5cf

File tree

1 file changed

+6
-4
lines changed

1 file changed

+6
-4
lines changed

src/lightning/pytorch/loggers/mlflow.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -142,7 +142,7 @@ def __init__(
142142
self.tags = tags
143143
self._log_model = log_model
144144
self._logged_model_time: dict[str, float] = {}
145-
self._checkpoint_callback: Optional[ModelCheckpoint] = None
145+
self._checkpoint_callbacks: list[ModelCheckpoint] = []
146146
self._prefix = prefix
147147
self._artifact_location = artifact_location
148148
self._log_batch_kwargs = {} if synchronous is None else {"synchronous": synchronous}
@@ -283,8 +283,9 @@ def finalize(self, status: str = "success") -> None:
283283
status = "FINISHED"
284284

285285
# log checkpoints as artifacts
286-
if self._checkpoint_callback:
287-
self._scan_and_log_checkpoints(self._checkpoint_callback)
286+
if self._checkpoint_callbacks:
287+
for callback in self._checkpoint_callbacks:
288+
self._scan_and_log_checkpoints(callback)
288289

289290
if self.experiment.get_run(self.run_id):
290291
self.experiment.set_terminated(self.run_id, status)
@@ -331,7 +332,8 @@ def after_save_checkpoint(self, checkpoint_callback: ModelCheckpoint) -> None:
331332
if self._log_model == "all" or self._log_model is True and checkpoint_callback.save_top_k == -1:
332333
self._scan_and_log_checkpoints(checkpoint_callback)
333334
elif self._log_model is True:
334-
self._checkpoint_callback = checkpoint_callback
335+
if checkpoint_callback not in self._checkpoint_callbacks:
336+
self._checkpoint_callbacks.append(checkpoint_callback)
335337

336338
def _scan_and_log_checkpoints(self, checkpoint_callback: ModelCheckpoint) -> None:
337339
# get checkpoints to be saved with associated score

0 commit comments

Comments
 (0)