Skip to content

Commit dc49140

Browse files
Update mlflow.py
1 parent 6675932 commit dc49140

File tree

1 file changed

+17
-1
lines changed

1 file changed

+17
-1
lines changed

src/lightning/pytorch/loggers/mlflow.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,12 @@ def any_lightning_module_function_or_hook(self):
109109
ModuleNotFoundError:
110110
If required MLFlow package is not installed on the device.
111111
112+
Note:
113+
As of vX.XX, MLFlowLogger will skip logging any metric (same name and step)
114+
more than once per run, to prevent database unique constraint violations on
115+
some MLflow backends (such as PostgreSQL). Only the first value for each (metric, step)
116+
pair will be logged per run. This improves robustness for all users.
117+
112118
"""
113119

114120
LOGGER_JOIN_CHAR = "-"
@@ -126,6 +132,7 @@ def __init__(
126132
run_id: Optional[str] = None,
127133
synchronous: Optional[bool] = None,
128134
):
135+
129136
if not _MLFLOW_AVAILABLE:
130137
raise ModuleNotFoundError(str(_MLFLOW_AVAILABLE))
131138
if synchronous is not None and not _MLFLOW_SYNCHRONOUS_AVAILABLE:
@@ -151,6 +158,7 @@ def __init__(
151158
from mlflow.tracking import MlflowClient
152159

153160
self._mlflow_client = MlflowClient(tracking_uri)
161+
self._logged_metrics = set() # Track (key, step)
154162

155163
@property
156164
@rank_zero_experiment
@@ -201,6 +209,7 @@ def experiment(self) -> "MlflowClient":
201209
resolve_tags = _get_resolve_tags()
202210
run = self._mlflow_client.create_run(experiment_id=self._experiment_id, tags=resolve_tags(self.tags))
203211
self._run_id = run.info.run_id
212+
self._logged_metrics.clear()
204213
self._initialized = True
205214
return self._mlflow_client
206215

@@ -257,7 +266,7 @@ def log_metrics(self, metrics: Mapping[str, float], step: Optional[int] = None)
257266
if isinstance(v, str):
258267
log.warning(f"Discarding metric with string value {k}={v}.")
259268
continue
260-
269+
261270
new_k = re.sub("[^a-zA-Z0-9_/. -]+", "", k)
262271
if k != new_k:
263272
rank_zero_warn(
@@ -266,8 +275,15 @@ def log_metrics(self, metrics: Mapping[str, float], step: Optional[int] = None)
266275
category=RuntimeWarning,
267276
)
268277
k = new_k
278+
279+
metric_id = (k, step or 0)
280+
if metric_id in self._logged_metrics:
281+
continue
282+
self._logged_metrics.add(metric_id)
283+
269284
metrics_list.append(Metric(key=k, value=v, timestamp=timestamp_ms, step=step or 0))
270285

286+
271287
self.experiment.log_batch(run_id=self.run_id, metrics=metrics_list, **self._log_batch_kwargs)
272288

273289
@override

0 commit comments

Comments
 (0)