Skip to content

Commit 1cb462f

Browse files
pgagarinovrohitgr7tchaton
authored andcommitted
Fixed a crash bug in MLFlow logger (#4716)
* warnings.warn doesn't accept tuples, which causes "TypeError: expected string or bytes-like object" when the execution flow gets to this warning. Fixed that. * Try adding a mock test * Try adding a mock test Co-authored-by: rohitgr7 <[email protected]> Co-authored-by: chaton <[email protected]> (cherry picked from commit 70361eb)
1 parent 684d152 commit 1cb462f

File tree

2 files changed

+22
-5
lines changed

2 files changed

+22
-5
lines changed

pytorch_lightning/loggers/mlflow.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
-------------
1818
"""
1919
import re
20-
import warnings
2120
from argparse import Namespace
2221
from time import time
2322
from typing import Any, Dict, Optional, Union
@@ -32,7 +31,7 @@
3231

3332
from pytorch_lightning import _logger as log
3433
from pytorch_lightning.loggers.base import LightningLoggerBase, rank_zero_experiment
35-
from pytorch_lightning.utilities import rank_zero_only
34+
from pytorch_lightning.utilities import rank_zero_only, rank_zero_warn
3635

3736
LOCAL_FILE_URI_PREFIX = "file:"
3837

@@ -158,9 +157,11 @@ def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) ->
158157

159158
new_k = re.sub("[^a-zA-Z0-9_/. -]+", "", k)
160159
if k != new_k:
161-
warnings.warn(("MLFlow only allows '_', '/', '.' and ' ' special characters in metric name.\n",
162-
f"Replacing {k} with {new_k}."))
163-
k = new_k
160+
rank_zero_warn(
161+
"MLFlow only allows '_', '/', '.' and ' ' special characters in metric name."
162+
f" Replacing {k} with {new_k}.", RuntimeWarning
163+
)
164+
k = new_k
164165

165166
self.experiment.log_metric(self.run_id, k, v, timestamp_ms, step)
166167

tests/loggers/test_mlflow.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -150,8 +150,24 @@ def test_mlflow_logger_dirs_creation(tmpdir):
150150
@mock.patch('pytorch_lightning.loggers.mlflow.mlflow')
151151
@mock.patch('pytorch_lightning.loggers.mlflow.MlflowClient')
152152
def test_mlflow_experiment_id_retrieved_once(client, mlflow, tmpdir):
153+
"""
154+
Test that the logger experiment_id retrieved only once.
155+
"""
153156
logger = MLFlowLogger('test', save_dir=tmpdir)
154157
_ = logger.experiment
155158
_ = logger.experiment
156159
_ = logger.experiment
157160
assert logger.experiment.get_experiment_by_name.call_count == 1
161+
162+
163+
@mock.patch('pytorch_lightning.loggers.mlflow.mlflow')
164+
@mock.patch('pytorch_lightning.loggers.mlflow.MlflowClient')
165+
def test_mlflow_logger_with_unexpected_characters(client, mlflow, tmpdir):
166+
"""
167+
Test that the logger raises warning with special characters not accepted by MLFlow.
168+
"""
169+
logger = MLFlowLogger('test', save_dir=tmpdir)
170+
metrics = {'[some_metric]': 10}
171+
172+
with pytest.warns(RuntimeWarning, match='special characters in metric name'):
173+
logger.log_metrics(metrics)

0 commit comments

Comments
 (0)