Skip to content

Commit 9b315f1

Browse files
committed
test(mlflow): Add test verifying that multiple ModelCheckpoint callbacks are picked up
1 parent 072a5cf commit 9b315f1

File tree

1 file changed

+79
-0
lines changed

1 file changed

+79
-0
lines changed

tests/tests_pytorch/loggers/test_mlflow.py

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,12 +12,15 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414
import os
15+
from typing import Any
1516
from unittest import mock
1617
from unittest.mock import MagicMock, Mock
1718

1819
import pytest
1920

2021
from lightning.pytorch import Trainer
22+
from lightning.pytorch.callbacks import ModelCheckpoint
23+
from lightning.pytorch.utilities.types import STEP_OUTPUT
2124
from lightning.pytorch.demos.boring_classes import BoringModel
2225
from lightning.pytorch.loggers.mlflow import (
2326
_MLFLOW_AVAILABLE,
@@ -427,3 +430,79 @@ def test_set_tracking_uri(mlflow_mock):
427430
mlflow_mock.set_tracking_uri.assert_not_called()
428431
_ = logger.experiment
429432
mlflow_mock.set_tracking_uri.assert_called_with("the_tracking_uri")
433+
434+
435+
@mock.patch("lightning.pytorch.loggers.mlflow._get_resolve_tags", Mock())
def test_mlflow_multiple_checkpoints_top_k(mlflow_mock, tmp_path):
    """Test that multiple ModelCheckpoint callbacks with top_k parameters work correctly with MLFlowLogger.

    This test verifies that when using multiple ModelCheckpoint callbacks with save_top_k,
    both callbacks function correctly and save the expected number of checkpoints when using
    MLFlowLogger with log_model=True.
    """

    class CustomBoringModel(BoringModel):
        # Logs distinct metrics so each ModelCheckpoint has its own monitor key.
        def training_step(self, batch: Any, batch_idx: int) -> STEP_OUTPUT:
            loss = self.step(batch)
            self.log("train_loss", loss)
            return {"loss": loss}

        def validation_step(self, batch: Any, batch_idx: int) -> STEP_OUTPUT:
            loss = self.step(batch)
            self.log("val_loss", loss)
            return {"loss": loss}

    client = mlflow_mock.tracking.MlflowClient

    model = CustomBoringModel()
    logger = MLFlowLogger("test", save_dir=str(tmp_path), log_model=True)
    logger = mock_mlflow_run_creation(logger, experiment_id="test-id")

    # Create two ModelCheckpoint callbacks monitoring different metrics,
    # each writing to its own directory so logged artifact paths are distinguishable.
    train_ckpt = ModelCheckpoint(
        dirpath=str(tmp_path / "train_checkpoints"),
        monitor="train_loss",
        filename="best_train_model-{epoch:02d}-{train_loss:.2f}",
        save_top_k=2,
        mode="min",
    )
    val_ckpt = ModelCheckpoint(
        dirpath=str(tmp_path / "val_checkpoints"),
        monitor="val_loss",
        filename="best_val_model-{epoch:02d}-{val_loss:.2f}",
        save_top_k=2,
        mode="min",
    )

    # Create trainer with both callbacks
    trainer = Trainer(
        default_root_dir=tmp_path,
        logger=logger,
        callbacks=[train_ckpt, val_ckpt],
        max_epochs=5,
        limit_train_batches=3,
        limit_val_batches=3,
    )
    trainer.fit(model)

    # Verify both callbacks saved their checkpoints
    assert len(train_ckpt.best_k_models) > 0, "Train checkpoint callback did not save any models"
    assert len(val_ckpt.best_k_models) > 0, "Validation checkpoint callback did not save any models"

    # Get all artifact paths that were logged; each log_artifact call is
    # (run_id, local_path, ...), so the path is positional arg index 1.
    logged_artifacts = [call_args[0][1] for call_args in client.return_value.log_artifact.call_args_list]

    # Verify MLFlow logged artifacts from both callbacks
    train_artifacts = [path for path in logged_artifacts if "train_checkpoints" in path]
    val_artifacts = [path for path in logged_artifacts if "val_checkpoints" in path]

    assert len(train_artifacts) > 0, "MLFlow did not log any train checkpoint artifacts"
    assert len(val_artifacts) > 0, "MLFlow did not log any validation checkpoint artifacts"

    # Verify the number of logged artifacts matches the save_top_k for each callback
    assert len(train_artifacts) == train_ckpt.save_top_k, "Number of logged train artifacts doesn't match save_top_k"
    assert len(val_artifacts) == val_ckpt.save_top_k, "Number of logged val artifacts doesn't match save_top_k"

0 commit comments

Comments
 (0)