Skip to content

Commit 1b0bf68

Browse files
committed
fix(mlflow): Enable multiple checkpoint callbacks for MLflow logging
- Changed _checkpoint_callback to _checkpoint_callbacks list to support multiple ModelCheckpoint instances - Fixed logic in after_save_checkpoint() to properly append all callbacks - Updated finalize() to log checkpoints from all registered callbacks - Added test for multiple checkpoint callbacks with different monitored metrics Fixes the issue where only one checkpoint callback could be tracked when using log_model=True
1 parent 79ffe50 commit 1b0bf68

File tree

2 files changed

+84
-5
lines changed

2 files changed

+84
-5
lines changed

src/lightning/pytorch/loggers/mlflow.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -142,7 +142,7 @@ def __init__(
142142
self.tags = tags
143143
self._log_model = log_model
144144
self._logged_model_time: dict[str, float] = {}
145-
self._checkpoint_callback: Optional[ModelCheckpoint] = None
145+
self._checkpoint_callbacks: list[ModelCheckpoint] = []
146146
self._prefix = prefix
147147
self._artifact_location = artifact_location
148148
self._log_batch_kwargs = {} if synchronous is None else {"synchronous": synchronous}
@@ -283,8 +283,9 @@ def finalize(self, status: str = "success") -> None:
283283
status = "FINISHED"
284284

285285
# log checkpoints as artifacts
286-
if self._checkpoint_callback:
287-
self._scan_and_log_checkpoints(self._checkpoint_callback)
286+
if self._checkpoint_callbacks:
287+
for callback in self._checkpoint_callbacks:
288+
self._scan_and_log_checkpoints(callback)
288289

289290
if self.experiment.get_run(self.run_id):
290291
self.experiment.set_terminated(self.run_id, status)
@@ -330,8 +331,11 @@ def after_save_checkpoint(self, checkpoint_callback: ModelCheckpoint) -> None:
330331
# log checkpoints as artifacts
331332
if self._log_model == "all" or self._log_model is True and checkpoint_callback.save_top_k == -1:
332333
self._scan_and_log_checkpoints(checkpoint_callback)
333-
elif self._log_model is True:
334-
self._checkpoint_callback = checkpoint_callback
334+
elif (
335+
self._log_model is True
336+
and checkpoint_callback not in self._checkpoint_callbacks
337+
):
338+
self._checkpoint_callbacks.append(checkpoint_callback)
335339

336340
def _scan_and_log_checkpoints(self, checkpoint_callback: ModelCheckpoint) -> None:
337341
# get checkpoints to be saved with associated score

tests/tests_pytorch/loggers/test_mlflow.py

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,18 +12,21 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414
import os
15+
from typing import Any
1516
from unittest import mock
1617
from unittest.mock import MagicMock, Mock
1718

1819
import pytest
1920

2021
from lightning.pytorch import Trainer
22+
from lightning.pytorch.callbacks import ModelCheckpoint
2123
from lightning.pytorch.demos.boring_classes import BoringModel
2224
from lightning.pytorch.loggers.mlflow import (
2325
_MLFLOW_AVAILABLE,
2426
MLFlowLogger,
2527
_get_resolve_tags,
2628
)
29+
from lightning.pytorch.utilities.types import STEP_OUTPUT
2730

2831

2932
def mock_mlflow_run_creation(logger, experiment_name=None, experiment_id=None, run_id=None):
@@ -427,3 +430,75 @@ def test_set_tracking_uri(mlflow_mock):
427430
mlflow_mock.set_tracking_uri.assert_not_called()
428431
_ = logger.experiment
429432
mlflow_mock.set_tracking_uri.assert_called_with("the_tracking_uri")
433+
434+
435+
@mock.patch("lightning.pytorch.loggers.mlflow._get_resolve_tags", Mock())
436+
def test_mlflow_multiple_checkpoints_top_k(mlflow_mock, tmp_path):
437+
"""Test that multiple ModelCheckpoint callbacks with top_k parameters work correctly with MLFlowLogger.
438+
439+
This test verifies that when using multiple ModelCheckpoint callbacks with save_top_k, both callbacks function
440+
correctly and save the expected number of checkpoints when using MLFlowLogger with log_model=True.
441+
442+
"""
443+
444+
class CustomBoringModel(BoringModel):
445+
def training_step(self, batch: Any, batch_idx: int) -> STEP_OUTPUT:
446+
loss = self.step(batch)
447+
self.log("train_loss", loss)
448+
return {"loss": loss}
449+
450+
def validation_step(self, batch: Any, batch_idx: int) -> STEP_OUTPUT:
451+
loss = self.step(batch)
452+
self.log("val_loss", loss)
453+
return {"loss": loss}
454+
455+
client = mlflow_mock.tracking.MlflowClient
456+
457+
model = CustomBoringModel()
458+
logger = MLFlowLogger("test", save_dir=str(tmp_path), log_model=True)
459+
logger = mock_mlflow_run_creation(logger, experiment_id="test-id")
460+
461+
# Create two ModelCheckpoint callbacks monitoring different metrics
462+
train_ckpt = ModelCheckpoint(
463+
dirpath=str(tmp_path / "train_checkpoints"),
464+
monitor="train_loss",
465+
filename="best_train_model-{epoch:02d}-{train_loss:.2f}",
466+
save_top_k=2,
467+
mode="min",
468+
)
469+
val_ckpt = ModelCheckpoint(
470+
dirpath=str(tmp_path / "val_checkpoints"),
471+
monitor="val_loss",
472+
filename="best_val_model-{epoch:02d}-{val_loss:.2f}",
473+
save_top_k=2,
474+
mode="min",
475+
)
476+
477+
# Create trainer with both callbacks
478+
trainer = Trainer(
479+
default_root_dir=tmp_path,
480+
logger=logger,
481+
callbacks=[train_ckpt, val_ckpt],
482+
max_epochs=5,
483+
limit_train_batches=3,
484+
limit_val_batches=3,
485+
)
486+
trainer.fit(model)
487+
488+
# Verify both callbacks saved their checkpoints
489+
assert len(train_ckpt.best_k_models) > 0, "Train checkpoint callback did not save any models"
490+
assert len(val_ckpt.best_k_models) > 0, "Validation checkpoint callback did not save any models"
491+
492+
# Get all artifact paths that were logged
493+
logged_artifacts = [call_args[0][1] for call_args in client.return_value.log_artifact.call_args_list]
494+
495+
# Verify MLFlow logged artifacts from both callbacks
496+
train_artifacts = [path for path in logged_artifacts if "train_checkpoints" in path]
497+
val_artifacts = [path for path in logged_artifacts if "val_checkpoints" in path]
498+
499+
assert len(train_artifacts) > 0, "MLFlow did not log any train checkpoint artifacts"
500+
assert len(val_artifacts) > 0, "MLFlow did not log any validation checkpoint artifacts"
501+
502+
# Verify the number of logged artifacts matches the save_top_k for each callback
503+
assert len(train_artifacts) == train_ckpt.save_top_k, "Number of logged train artifacts doesn't match save_top_k"
504+
assert len(val_artifacts) == val_ckpt.save_top_k, "Number of logged val artifacts doesn't match save_top_k"

0 commit comments

Comments
 (0)