@@ -427,35 +427,3 @@ def test_set_tracking_uri(mlflow_mock):
427
427
mlflow_mock .set_tracking_uri .assert_not_called ()
428
428
_ = logger .experiment
429
429
mlflow_mock .set_tracking_uri .assert_called_with ("the_tracking_uri" )
430
-
431
- def test_mlflowlogger_metric_deduplication (monkeypatch ):
432
- import types
433
- from lightning .pytorch .loggers .mlflow import MLFlowLogger
434
-
435
- # Dummy MLflow client to record log_batch calls
436
- logged_metrics = []
437
- class DummyMlflowClient :
438
- def log_batch (self , run_id , metrics , ** kwargs ):
439
- logged_metrics .extend (metrics )
440
- def set_tracking_uri (self , uri ): pass
441
- def create_run (self , experiment_id , tags ):
442
- class Run : info = types .SimpleNamespace (run_id = "dummy_run_id" )
443
- return Run ()
444
- def get_run (self , run_id ):
445
- class Run : info = types .SimpleNamespace (experiment_id = "dummy_experiment_id" )
446
- return Run ()
447
- def get_experiment_by_name (self , name ): return None
448
- def create_experiment (self , name , artifact_location = None ): return "dummy_experiment_id"
449
-
450
- # Patch the MLFlowLogger to use DummyMlflowClient
451
- monkeypatch .setattr ("mlflow.tracking.MlflowClient" , lambda * a , ** k : DummyMlflowClient ())
452
-
453
- logger = MLFlowLogger (experiment_name = "test_exp" )
454
- logger .log_metrics ({'foo' : 1.0 }, step = 5 )
455
- logger .log_metrics ({'foo' : 1.0 }, step = 5 ) # duplicate
456
-
457
- # Only the first metric should be logged
458
- assert len (logged_metrics ) == 1
459
- assert logged_metrics [0 ].key == "foo"
460
- assert logged_metrics [0 ].value == 1.0
461
- assert logged_metrics [0 ].step == 5
0 commit comments