Skip to content

Commit e618a33

Browse files
Kr4isawaelchlipre-commit-ci[bot]
authored
Allow log to an existing run ID in MLflow with MLFlowLogger (#12290)
Co-authored-by: bruno.cabado <[email protected]> Co-authored-by: Adrian Wälchli <[email protected]> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 5be6720 commit e618a33

File tree

3 files changed

+42
-2
lines changed

3 files changed

+42
-2
lines changed

CHANGELOG.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
99

1010
### Added
1111

12+
- Allow logging to an existing run ID in MLflow with `MLFlowLogger` ([#12290](https://github.com/PyTorchLightning/pytorch-lightning/pull/12290))
13+
14+
1215
- Enable gradient accumulation using Horovod's `backward_passes_per_step` ([#11911](https://github.com/PyTorchLightning/pytorch-lightning/pull/11911))
1316

1417

pytorch_lightning/loggers/mlflow.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ def any_lightning_module_function_or_hook(self):
8787
self.logger.experiment.whatever_ml_flow_supports(...)
8888
8989
Args:
90-
experiment_name: The name of the experiment
90+
experiment_name: The name of the experiment.
9191
run_name: Name of the new run. The `run_name` is internally stored as a ``mlflow.runName`` tag.
9292
If the ``mlflow.runName`` tag has already been set in `tags`, the value is overridden by the `run_name`.
9393
tracking_uri: Address of local or remote tracking server.
@@ -100,6 +100,7 @@ def any_lightning_module_function_or_hook(self):
100100
prefix: A string to put at the beginning of metric keys.
101101
artifact_location: The location to store run artifacts. If not provided, the server picks an appropriate
102102
default.
103+
run_id: The run identifier of the experiment. If not provided, a new run is started.
103104
104105
Raises:
105106
ModuleNotFoundError:
@@ -117,6 +118,7 @@ def __init__(
117118
save_dir: Optional[str] = "./mlruns",
118119
prefix: str = "",
119120
artifact_location: Optional[str] = None,
121+
run_id: Optional[str] = None,
120122
):
121123
if mlflow is None:
122124
raise ModuleNotFoundError(
@@ -130,11 +132,13 @@ def __init__(
130132
self._experiment_id = None
131133
self._tracking_uri = tracking_uri
132134
self._run_name = run_name
133-
self._run_id = None
135+
self._run_id = run_id
134136
self.tags = tags
135137
self._prefix = prefix
136138
self._artifact_location = artifact_location
137139

140+
self._initialized = False
141+
138142
self._mlflow_client = MlflowClient(tracking_uri)
139143

140144
@property
@@ -149,6 +153,16 @@ def experiment(self) -> MlflowClient:
149153
self.logger.experiment.some_mlflow_function()
150154
151155
"""
156+
157+
if self._initialized:
158+
return self._mlflow_client
159+
160+
if self._run_id is not None:
161+
run = self._mlflow_client.get_run(self._run_id)
162+
self._experiment_id = run.info.experiment_id
163+
self._initialized = True
164+
return self._mlflow_client
165+
152166
if self._experiment_id is None:
153167
expt = self._mlflow_client.get_experiment_by_name(self._experiment_name)
154168
if expt is not None:
@@ -169,6 +183,7 @@ def experiment(self) -> MlflowClient:
169183
self.tags[MLFLOW_RUN_NAME] = self._run_name
170184
run = self._mlflow_client.create_run(experiment_id=self._experiment_id, tags=resolve_tags(self.tags))
171185
self._run_id = run.info.run_id
186+
self._initialized = True
172187
return self._mlflow_client
173188

174189
@property

tests/loggers/test_mlflow.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ def test_mlflow_logger_exists(client, mlflow, tmpdir):
4040

4141
run1 = MagicMock()
4242
run1.info.run_id = "run-id-1"
43+
run1.info.experiment_id = "exp-id-1"
4344

4445
run2 = MagicMock()
4546
run2.info.run_id = "run-id-2"
@@ -113,6 +114,27 @@ def test_mlflow_run_name_setting(client, mlflow, tmpdir):
113114
client.return_value.create_run.assert_called_with(experiment_id="exp-id", tags=default_tags)
114115

115116

117+
@mock.patch("pytorch_lightning.loggers.mlflow.mlflow")
118+
@mock.patch("pytorch_lightning.loggers.mlflow.MlflowClient")
119+
def test_mlflow_run_id_setting(client, mlflow, tmpdir):
120+
"""Test that the run_id argument uses the provided run_id."""
121+
122+
run = MagicMock()
123+
run.info.run_id = "run-id"
124+
run.info.experiment_id = "experiment-id"
125+
126+
# simulate existing run
127+
client.return_value.get_run = MagicMock(return_value=run)
128+
129+
# run_id exists uses the existing run
130+
logger = MLFlowLogger("test", run_id=run.info.run_id, save_dir=tmpdir)
131+
_ = logger.experiment
132+
client.return_value.get_run.assert_called_with(run.info.run_id)
133+
assert logger.experiment_id == run.info.experiment_id
134+
assert logger.run_id == run.info.run_id
135+
client.reset_mock(return_value=True)
136+
137+
116138
@mock.patch("pytorch_lightning.loggers.mlflow.mlflow")
117139
@mock.patch("pytorch_lightning.loggers.mlflow.MlflowClient")
118140
def test_mlflow_log_dir(client, mlflow, tmpdir):

0 commit comments

Comments
 (0)