Skip to content

Commit b81b5a0

Browse files
committed
test: MLFlowLogger(log_model=True)
1 parent 85c01c6 commit b81b5a0

File tree

1 file changed

+12
-2
lines changed

1 file changed

+12
-2
lines changed

tests/tests_pytorch/loggers/test_mlflow.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,7 @@ def test_mlflow_run_name_setting(tmp_path):
9595
if not _MLFLOW_AVAILABLE:
9696
pytest.skip("test for explicit file creation requires mlflow dependency to be installed.")
9797

98+
import mlflow
9899
from mlflow.utils.mlflow_tags import MLFLOW_RUN_NAME
99100

100101
resolve_tags = _get_resolve_tags()
@@ -121,6 +122,8 @@ def test_mlflow_run_name_setting(tmp_path):
121122
default_tags = resolve_tags(None)
122123
client.create_run.assert_called_with(experiment_id="exp-id", tags=default_tags)
123124

125+
mlflow.set_tracking_uri(None)
126+
124127

125128
@mock.patch("lightning.pytorch.loggers.mlflow._get_resolve_tags", Mock())
126129
def test_mlflow_run_id_setting(mlflow_mock, tmp_path):
@@ -177,8 +180,10 @@ def test_mlflow_logger_dirs_creation(tmp_path):
177180
if not _MLFLOW_AVAILABLE:
178181
pytest.skip("test for explicit file creation requires mlflow dependency to be installed.")
179182

183+
import mlflow
184+
180185
assert not os.listdir(tmp_path)
181-
logger = MLFlowLogger("test", save_dir=str(tmp_path))
186+
logger = MLFlowLogger("test", save_dir=str(tmp_path), log_model=True)
182187
assert logger.save_dir == str(tmp_path)
183188
assert set(os.listdir(tmp_path)) == {".trash"}
184189
run_id = logger.run_id
@@ -208,7 +213,12 @@ def on_train_epoch_end(self, *args, **kwargs):
208213
assert "epoch" in os.listdir(tmp_path / exp_id / run_id / "metrics")
209214
assert set(os.listdir(tmp_path / exp_id / run_id / "params")) == model.hparams.keys()
210215
assert trainer.checkpoint_callback.dirpath == str(tmp_path / exp_id / run_id / "checkpoints")
211-
assert os.listdir(trainer.checkpoint_callback.dirpath) == [f"epoch=0-step={limit_batches}.ckpt"]
216+
ckpt_stem = f"epoch=0-step={limit_batches}"
217+
assert os.listdir(trainer.checkpoint_callback.dirpath) == [f"{ckpt_stem}.ckpt"]
218+
artifacts = mlflow.artifacts.list_artifacts(run_id=logger.run_id)
219+
assert [file_info.path for file_info in artifacts] == [ckpt_stem]
220+
221+
mlflow.set_tracking_uri(None)
212222

213223

214224
@mock.patch("lightning.pytorch.loggers.mlflow._get_resolve_tags", Mock())

0 commit comments

Comments
 (0)