Skip to content

Commit 0773eb4

Browse files
test_mlflow.py
1 parent d4754ad commit 0773eb4

File tree

1 file changed

+0
-32
lines changed

1 file changed

+0
-32
lines changed

tests/tests_pytorch/loggers/test_mlflow.py

Lines changed: 0 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -427,35 +427,3 @@ def test_set_tracking_uri(mlflow_mock):
427427
mlflow_mock.set_tracking_uri.assert_not_called()
428428
_ = logger.experiment
429429
mlflow_mock.set_tracking_uri.assert_called_with("the_tracking_uri")
430-
431-
def test_mlflowlogger_metric_deduplication(monkeypatch):
432-
import types
433-
from lightning.pytorch.loggers.mlflow import MLFlowLogger
434-
435-
# Dummy MLflow client to record log_batch calls
436-
logged_metrics = []
437-
class DummyMlflowClient:
438-
def log_batch(self, run_id, metrics, **kwargs):
439-
logged_metrics.extend(metrics)
440-
def set_tracking_uri(self, uri): pass
441-
def create_run(self, experiment_id, tags):
442-
class Run: info = types.SimpleNamespace(run_id="dummy_run_id")
443-
return Run()
444-
def get_run(self, run_id):
445-
class Run: info = types.SimpleNamespace(experiment_id="dummy_experiment_id")
446-
return Run()
447-
def get_experiment_by_name(self, name): return None
448-
def create_experiment(self, name, artifact_location=None): return "dummy_experiment_id"
449-
450-
# Patch the MLFlowLogger to use DummyMlflowClient
451-
monkeypatch.setattr("mlflow.tracking.MlflowClient", lambda *a, **k: DummyMlflowClient())
452-
453-
logger = MLFlowLogger(experiment_name="test_exp")
454-
logger.log_metrics({'foo': 1.0}, step=5)
455-
logger.log_metrics({'foo': 1.0}, step=5) # duplicate
456-
457-
# Only the first metric should be logged
458-
assert len(logged_metrics) == 1
459-
assert logged_metrics[0].key == "foo"
460-
assert logged_metrics[0].value == 1.0
461-
assert logged_metrics[0].step == 5

0 commit comments

Comments
 (0)