@@ -95,6 +95,7 @@ def test_mlflow_run_name_setting(tmp_path):
95
95
if not _MLFLOW_AVAILABLE :
96
96
pytest .skip ("test for explicit file creation requires mlflow dependency to be installed." )
97
97
98
+ import mlflow
98
99
from mlflow .utils .mlflow_tags import MLFLOW_RUN_NAME
99
100
100
101
resolve_tags = _get_resolve_tags ()
@@ -121,6 +122,8 @@ def test_mlflow_run_name_setting(tmp_path):
121
122
default_tags = resolve_tags (None )
122
123
client .create_run .assert_called_with (experiment_id = "exp-id" , tags = default_tags )
123
124
125
+ mlflow .set_tracking_uri (None )
126
+
124
127
125
128
@mock .patch ("lightning.pytorch.loggers.mlflow._get_resolve_tags" , Mock ())
126
129
def test_mlflow_run_id_setting (mlflow_mock , tmp_path ):
@@ -177,8 +180,10 @@ def test_mlflow_logger_dirs_creation(tmp_path):
177
180
if not _MLFLOW_AVAILABLE :
178
181
pytest .skip ("test for explicit file creation requires mlflow dependency to be installed." )
179
182
183
+ import mlflow
184
+
180
185
assert not os .listdir (tmp_path )
181
- logger = MLFlowLogger ("test" , save_dir = str (tmp_path ))
186
+ logger = MLFlowLogger ("test" , save_dir = str (tmp_path ), log_model = True )
182
187
assert logger .save_dir == str (tmp_path )
183
188
assert set (os .listdir (tmp_path )) == {".trash" }
184
189
run_id = logger .run_id
@@ -208,7 +213,12 @@ def on_train_epoch_end(self, *args, **kwargs):
208
213
assert "epoch" in os .listdir (tmp_path / exp_id / run_id / "metrics" )
209
214
assert set (os .listdir (tmp_path / exp_id / run_id / "params" )) == model .hparams .keys ()
210
215
assert trainer .checkpoint_callback .dirpath == str (tmp_path / exp_id / run_id / "checkpoints" )
211
- assert os .listdir (trainer .checkpoint_callback .dirpath ) == [f"epoch=0-step={ limit_batches } .ckpt" ]
216
+ ckpt_stem = f"epoch=0-step={ limit_batches } "
217
+ assert os .listdir (trainer .checkpoint_callback .dirpath ) == [f"{ ckpt_stem } .ckpt" ]
218
+ artifacts = mlflow .artifacts .list_artifacts (run_id = logger .run_id )
219
+ assert [file_info .path for file_info in artifacts ] == [ckpt_stem ]
220
+
221
+ mlflow .set_tracking_uri (None )
212
222
213
223
214
224
@mock .patch ("lightning.pytorch.loggers.mlflow._get_resolve_tags" , Mock ())
0 commit comments