@@ -109,6 +109,12 @@ def any_lightning_module_function_or_hook(self):
109
109
ModuleNotFoundError:
110
110
If required MLFlow package is not installed on the device.
111
111
112
+ Note:
113
+ As of vX.XX, MLFlowLogger will skip logging any metric (same name and step)
114
+ more than once per run, to prevent database unique constraint violations on
115
+ some MLflow backends (such as PostgreSQL). Only the first value for each (metric, step)
116
+ pair will be logged per run. This improves robustness for all users.
117
+
112
118
"""
113
119
114
120
LOGGER_JOIN_CHAR = "-"
@@ -126,6 +132,7 @@ def __init__(
126
132
run_id : Optional [str ] = None ,
127
133
synchronous : Optional [bool ] = None ,
128
134
):
135
+
129
136
if not _MLFLOW_AVAILABLE :
130
137
raise ModuleNotFoundError (str (_MLFLOW_AVAILABLE ))
131
138
if synchronous is not None and not _MLFLOW_SYNCHRONOUS_AVAILABLE :
@@ -151,6 +158,7 @@ def __init__(
151
158
from mlflow .tracking import MlflowClient
152
159
153
160
self ._mlflow_client = MlflowClient (tracking_uri )
161
+ self ._logged_metrics = set () # Track (key, step)
154
162
155
163
@property
156
164
@rank_zero_experiment
@@ -201,6 +209,7 @@ def experiment(self) -> "MlflowClient":
201
209
resolve_tags = _get_resolve_tags ()
202
210
run = self ._mlflow_client .create_run (experiment_id = self ._experiment_id , tags = resolve_tags (self .tags ))
203
211
self ._run_id = run .info .run_id
212
+ self ._logged_metrics .clear ()
204
213
self ._initialized = True
205
214
return self ._mlflow_client
206
215
@@ -257,7 +266,7 @@ def log_metrics(self, metrics: Mapping[str, float], step: Optional[int] = None)
257
266
if isinstance (v , str ):
258
267
log .warning (f"Discarding metric with string value { k } ={ v } ." )
259
268
continue
260
-
269
+
261
270
new_k = re .sub ("[^a-zA-Z0-9_/. -]+" , "" , k )
262
271
if k != new_k :
263
272
rank_zero_warn (
@@ -266,8 +275,15 @@ def log_metrics(self, metrics: Mapping[str, float], step: Optional[int] = None)
266
275
category = RuntimeWarning ,
267
276
)
268
277
k = new_k
278
+
279
+ metric_id = (k , step or 0 )
280
+ if metric_id in self ._logged_metrics :
281
+ continue
282
+ self ._logged_metrics .add (metric_id )
283
+
269
284
metrics_list .append (Metric (key = k , value = v , timestamp = timestamp_ms , step = step or 0 ))
270
285
286
+
271
287
self .experiment .log_batch (run_id = self .run_id , metrics = metrics_list , ** self ._log_batch_kwargs )
272
288
273
289
@override
0 commit comments