Skip to content

Commit fd4bafa

Browse files
Update test_mlflow.py
1 parent 7d786af commit fd4bafa

File tree

1 file changed

+32
-0
lines changed

1 file changed

+32
-0
lines changed

tests/tests_pytorch/loggers/test_mlflow.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -427,3 +427,35 @@ def test_set_tracking_uri(mlflow_mock):
427427
mlflow_mock.set_tracking_uri.assert_not_called()
428428
_ = logger.experiment
429429
mlflow_mock.set_tracking_uri.assert_called_with("the_tracking_uri")
430+
431+
def test_mlflowlogger_metric_deduplication(monkeypatch):
432+
import types
433+
from lightning.pytorch.loggers.mlflow import MLFlowLogger
434+
435+
# Dummy MLflow client to record log_batch calls
436+
logged_metrics = []
437+
class DummyMlflowClient:
438+
def log_batch(self, run_id, metrics, **kwargs):
439+
logged_metrics.extend(metrics)
440+
def set_tracking_uri(self, uri): pass
441+
def create_run(self, experiment_id, tags):
442+
class Run: info = types.SimpleNamespace(run_id="dummy_run_id")
443+
return Run()
444+
def get_run(self, run_id):
445+
class Run: info = types.SimpleNamespace(experiment_id="dummy_experiment_id")
446+
return Run()
447+
def get_experiment_by_name(self, name): return None
448+
def create_experiment(self, name, artifact_location=None): return "dummy_experiment_id"
449+
450+
# Patch the MLFlowLogger to use DummyMlflowClient
451+
monkeypatch.setattr("mlflow.tracking.MlflowClient", lambda *a, **k: DummyMlflowClient())
452+
453+
logger = MLFlowLogger(experiment_name="test_exp")
454+
logger.log_metrics({'foo': 1.0}, step=5)
455+
logger.log_metrics({'foo': 1.0}, step=5) # duplicate
456+
457+
# Only the first metric should be logged
458+
assert len(logged_metrics) == 1
459+
assert logged_metrics[0].key == "foo"
460+
assert logged_metrics[0].value == 1.0
461+
assert logged_metrics[0].step == 5

0 commit comments

Comments
 (0)