@@ -428,31 +428,46 @@ def test_set_tracking_uri(mlflow_mock):
428428 _ = logger .experiment
429429 mlflow_mock .set_tracking_uri .assert_called_with ("the_tracking_uri" )
430430
431+
431432def test_mlflowlogger_metric_deduplication (monkeypatch ):
432433 import types
434+
433435 from lightning .pytorch .loggers .mlflow import MLFlowLogger
434436
435437 # Dummy MLflow client to record log_batch calls
436438 logged_metrics = []
439+
437440 class DummyMlflowClient :
438441 def log_batch (self , run_id , metrics , ** kwargs ):
439442 logged_metrics .extend (metrics )
440- def set_tracking_uri (self , uri ): pass
441- def create_run (self , experiment_id , tags ):
442- class Run : info = types .SimpleNamespace (run_id = "dummy_run_id" )
443+
444+ def set_tracking_uri (self , uri ):
445+ pass
446+
447+ def create_run (self , experiment_id , tags ):
448+ class Run :
449+ info = types .SimpleNamespace (run_id = "dummy_run_id" )
450+
443451 return Run ()
452+
444453 def get_run (self , run_id ):
445- class Run : info = types .SimpleNamespace (experiment_id = "dummy_experiment_id" )
454+ class Run :
455+ info = types .SimpleNamespace (experiment_id = "dummy_experiment_id" )
456+
446457 return Run ()
447- def get_experiment_by_name (self , name ): return None
448- def create_experiment (self , name , artifact_location = None ): return "dummy_experiment_id"
458+
459+ def get_experiment_by_name (self , name ):
460+ return None
461+
462+ def create_experiment (self , name , artifact_location = None ):
463+ return "dummy_experiment_id"
449464
450465 # Patch the MLFlowLogger to use DummyMlflowClient
451466 monkeypatch .setattr ("mlflow.tracking.MlflowClient" , lambda * a , ** k : DummyMlflowClient ())
452467
453468 logger = MLFlowLogger (experiment_name = "test_exp" )
454- logger .log_metrics ({' foo' : 1.0 }, step = 5 )
455- logger .log_metrics ({' foo' : 1.0 }, step = 5 ) # duplicate
469+ logger .log_metrics ({" foo" : 1.0 }, step = 5 )
470+ logger .log_metrics ({" foo" : 1.0 }, step = 5 ) # duplicate
456471
457472 # Only the first metric should be logged
458473 assert len (logged_metrics ) == 1
0 commit comments