@@ -427,3 +427,35 @@ def test_set_tracking_uri(mlflow_mock):
427
427
mlflow_mock .set_tracking_uri .assert_not_called ()
428
428
_ = logger .experiment
429
429
mlflow_mock .set_tracking_uri .assert_called_with ("the_tracking_uri" )
430
+
431
+ def test_mlflowlogger_metric_deduplication (monkeypatch ):
432
+ import types
433
+ from lightning .pytorch .loggers .mlflow import MLFlowLogger
434
+
435
+ # Dummy MLflow client to record log_batch calls
436
+ logged_metrics = []
437
+ class DummyMlflowClient :
438
+ def log_batch (self , run_id , metrics , ** kwargs ):
439
+ logged_metrics .extend (metrics )
440
+ def set_tracking_uri (self , uri ): pass
441
+ def create_run (self , experiment_id , tags ):
442
+ class Run : info = types .SimpleNamespace (run_id = "dummy_run_id" )
443
+ return Run ()
444
+ def get_run (self , run_id ):
445
+ class Run : info = types .SimpleNamespace (experiment_id = "dummy_experiment_id" )
446
+ return Run ()
447
+ def get_experiment_by_name (self , name ): return None
448
+ def create_experiment (self , name , artifact_location = None ): return "dummy_experiment_id"
449
+
450
+ # Patch the MLFlowLogger to use DummyMlflowClient
451
+ monkeypatch .setattr ("mlflow.tracking.MlflowClient" , lambda * a , ** k : DummyMlflowClient ())
452
+
453
+ logger = MLFlowLogger (experiment_name = "test_exp" )
454
+ logger .log_metrics ({'foo' : 1.0 }, step = 5 )
455
+ logger .log_metrics ({'foo' : 1.0 }, step = 5 ) # duplicate
456
+
457
+ # Only the first metric should be logged
458
+ assert len (logged_metrics ) == 1
459
+ assert logged_metrics [0 ].key == "foo"
460
+ assert logged_metrics [0 ].value == 1.0
461
+ assert logged_metrics [0 ].step == 5
0 commit comments