diff --git a/src/lightning/pytorch/loggers/mlflow.py b/src/lightning/pytorch/loggers/mlflow.py index 02396d8021633..2c12c942f701e 100644 --- a/src/lightning/pytorch/loggers/mlflow.py +++ b/src/lightning/pytorch/loggers/mlflow.py @@ -363,7 +363,7 @@ def _scan_and_log_checkpoints(self, checkpoint_callback: ModelCheckpoint) -> Non aliases = ["latest", "best"] if p == checkpoint_callback.best_model_path else ["latest"] # Artifact path on mlflow - artifact_path = Path(self._checkpoint_path_prefix) / Path(p).stem + artifact_path = Path(self._checkpoint_path_prefix, Path(p).stem).as_posix() # Log the checkpoint self.experiment.log_artifact(self._run_id, p, artifact_path) diff --git a/tests/tests_pytorch/loggers/test_mlflow.py b/tests/tests_pytorch/loggers/test_mlflow.py index 8118349ea6721..0800ce16de332 100644 --- a/tests/tests_pytorch/loggers/test_mlflow.py +++ b/tests/tests_pytorch/loggers/test_mlflow.py @@ -95,6 +95,7 @@ def test_mlflow_run_name_setting(tmp_path): if not _MLFLOW_AVAILABLE: pytest.skip("test for explicit file creation requires mlflow dependency to be installed.") + import mlflow from mlflow.utils.mlflow_tags import MLFLOW_RUN_NAME resolve_tags = _get_resolve_tags() @@ -121,6 +122,8 @@ def test_mlflow_run_name_setting(tmp_path): default_tags = resolve_tags(None) client.create_run.assert_called_with(experiment_id="exp-id", tags=default_tags) + mlflow.set_tracking_uri(None) + @mock.patch("lightning.pytorch.loggers.mlflow._get_resolve_tags", Mock()) def test_mlflow_run_id_setting(mlflow_mock, tmp_path): @@ -177,8 +180,10 @@ def test_mlflow_logger_dirs_creation(tmp_path): if not _MLFLOW_AVAILABLE: pytest.skip("test for explicit file creation requires mlflow dependency to be installed.") + import mlflow + assert not os.listdir(tmp_path) - logger = MLFlowLogger("test", save_dir=str(tmp_path)) + logger = MLFlowLogger("test", save_dir=str(tmp_path), log_model=True) assert logger.save_dir == str(tmp_path) assert set(os.listdir(tmp_path)) == {".trash"} run_id = logger.run_id @@ -208,7 +213,12 @@ def on_train_epoch_end(self, *args, **kwargs): assert "epoch" in os.listdir(tmp_path / exp_id / run_id / "metrics") assert set(os.listdir(tmp_path / exp_id / run_id / "params")) == model.hparams.keys() assert trainer.checkpoint_callback.dirpath == str(tmp_path / exp_id / run_id / "checkpoints") - assert os.listdir(trainer.checkpoint_callback.dirpath) == [f"epoch=0-step={limit_batches}.ckpt"] + ckpt_stem = f"epoch=0-step={limit_batches}" + assert os.listdir(trainer.checkpoint_callback.dirpath) == [f"{ckpt_stem}.ckpt"] + artifacts = mlflow.artifacts.list_artifacts(run_id=logger.run_id) + assert [file_info.path for file_info in artifacts] == [ckpt_stem] + + mlflow.set_tracking_uri(None) @mock.patch("lightning.pytorch.loggers.mlflow._get_resolve_tags", Mock())