|
12 | 12 | # See the License for the specific language governing permissions and |
13 | 13 | # limitations under the License. |
14 | 14 | import os |
| 15 | +from typing import Any |
15 | 16 | from unittest import mock |
16 | 17 | from unittest.mock import MagicMock, Mock |
17 | 18 |
|
18 | 19 | import pytest |
19 | 20 |
|
20 | 21 | from lightning.pytorch import Trainer |
| 22 | +from lightning.pytorch.callbacks import ModelCheckpoint |
21 | 23 | from lightning.pytorch.demos.boring_classes import BoringModel |
22 | 24 | from lightning.pytorch.loggers.mlflow import ( |
23 | 25 | _MLFLOW_AVAILABLE, |
24 | 26 | MLFlowLogger, |
25 | 27 | _get_resolve_tags, |
26 | 28 | ) |
| 29 | +from lightning.pytorch.utilities.types import STEP_OUTPUT |
27 | 30 |
|
28 | 31 |
|
29 | 32 | def mock_mlflow_run_creation(logger, experiment_name=None, experiment_id=None, run_id=None): |
@@ -427,3 +430,75 @@ def test_set_tracking_uri(mlflow_mock): |
427 | 430 | mlflow_mock.set_tracking_uri.assert_not_called() |
428 | 431 | _ = logger.experiment |
429 | 432 | mlflow_mock.set_tracking_uri.assert_called_with("the_tracking_uri") |
| 433 | + |
| 434 | + |
@mock.patch("lightning.pytorch.loggers.mlflow._get_resolve_tags", Mock())
def test_mlflow_multiple_checkpoints_top_k(mlflow_mock, tmp_path):
    """Ensure two ``ModelCheckpoint`` callbacks with ``save_top_k`` cooperate with ``MLFlowLogger``.

    Runs a short fit with one checkpoint callback monitoring ``train_loss`` and another monitoring
    ``val_loss``, then verifies that each callback retained models and that, with ``log_model=True``,
    MLFlow received exactly ``save_top_k`` checkpoint artifacts per callback.
    """

    class CustomBoringModel(BoringModel):
        # Log a training metric so the train-loss checkpoint callback has something to monitor.
        def training_step(self, batch: Any, batch_idx: int) -> STEP_OUTPUT:
            loss = self.step(batch)
            self.log("train_loss", loss)
            return {"loss": loss}

        # Log a validation metric so the val-loss checkpoint callback has something to monitor.
        def validation_step(self, batch: Any, batch_idx: int) -> STEP_OUTPUT:
            loss = self.step(batch)
            self.log("val_loss", loss)
            return {"loss": loss}

    mlflow_client = mlflow_mock.tracking.MlflowClient

    model = CustomBoringModel()
    logger = MLFlowLogger("test", save_dir=str(tmp_path), log_model=True)
    logger = mock_mlflow_run_creation(logger, experiment_id="test-id")

    # Two checkpoint callbacks, each monitoring a different metric and keeping its best two models.
    train_checkpoint = ModelCheckpoint(
        dirpath=str(tmp_path / "train_checkpoints"),
        monitor="train_loss",
        filename="best_train_model-{epoch:02d}-{train_loss:.2f}",
        save_top_k=2,
        mode="min",
    )
    val_checkpoint = ModelCheckpoint(
        dirpath=str(tmp_path / "val_checkpoints"),
        monitor="val_loss",
        filename="best_val_model-{epoch:02d}-{val_loss:.2f}",
        save_top_k=2,
        mode="min",
    )

    # Short training run with both callbacks attached to the MLFlow-backed trainer.
    trainer = Trainer(
        default_root_dir=tmp_path,
        logger=logger,
        callbacks=[train_checkpoint, val_checkpoint],
        max_epochs=5,
        limit_train_batches=3,
        limit_val_batches=3,
    )
    trainer.fit(model)

    # Each callback must have retained at least one model.
    assert train_checkpoint.best_k_models, "Train checkpoint callback did not save any models"
    assert val_checkpoint.best_k_models, "Validation checkpoint callback did not save any models"

    # Every artifact path handed to the (mocked) MLflow client; the path is the second positional arg.
    logged_artifacts = [c.args[1] for c in mlflow_client.return_value.log_artifact.call_args_list]

    # Partition the logged paths by the callback that produced them.
    train_artifacts = [path for path in logged_artifacts if "train_checkpoints" in path]
    val_artifacts = [path for path in logged_artifacts if "val_checkpoints" in path]

    assert train_artifacts, "MLFlow did not log any train checkpoint artifacts"
    assert val_artifacts, "MLFlow did not log any validation checkpoint artifacts"

    # The artifact count per callback must match its save_top_k budget.
    assert len(train_artifacts) == train_checkpoint.save_top_k, (
        "Number of logged train artifacts doesn't match save_top_k"
    )
    assert len(val_artifacts) == val_checkpoint.save_top_k, "Number of logged val artifacts doesn't match save_top_k"
0 commit comments