@@ -87,7 +87,7 @@ def any_lightning_module_function_or_hook(self):
87
87
self.logger.experiment.whatever_ml_flow_supports(...)
88
88
89
89
Args:
90
- experiment_name: The name of the experiment
90
+ experiment_name: The name of the experiment.
91
91
run_name: Name of the new run. The `run_name` is internally stored as a ``mlflow.runName`` tag.
92
92
If the ``mlflow.runName`` tag has already been set in `tags`, the value is overridden by the `run_name`.
93
93
tracking_uri: Address of local or remote tracking server.
@@ -100,6 +100,7 @@ def any_lightning_module_function_or_hook(self):
100
100
prefix: A string to put at the beginning of metric keys.
101
101
artifact_location: The location to store run artifacts. If not provided, the server picks an appropriate
102
102
default.
103
+ run_id: The run identifier of the experiment. If not provided, a new run is started.
103
104
104
105
Raises:
105
106
ModuleNotFoundError:
@@ -117,6 +118,7 @@ def __init__(
117
118
save_dir : Optional [str ] = "./mlruns" ,
118
119
prefix : str = "" ,
119
120
artifact_location : Optional [str ] = None ,
121
+ run_id : Optional [str ] = None ,
120
122
):
121
123
if mlflow is None :
122
124
raise ModuleNotFoundError (
@@ -130,11 +132,13 @@ def __init__(
130
132
self ._experiment_id = None
131
133
self ._tracking_uri = tracking_uri
132
134
self ._run_name = run_name
133
- self ._run_id = None
135
+ self ._run_id = run_id
134
136
self .tags = tags
135
137
self ._prefix = prefix
136
138
self ._artifact_location = artifact_location
137
139
140
+ self ._initialized = False
141
+
138
142
self ._mlflow_client = MlflowClient (tracking_uri )
139
143
140
144
@property
@@ -149,6 +153,16 @@ def experiment(self) -> MlflowClient:
149
153
self.logger.experiment.some_mlflow_function()
150
154
151
155
"""
156
+
157
+ if self ._initialized :
158
+ return self ._mlflow_client
159
+
160
+ if self ._run_id is not None :
161
+ run = self ._mlflow_client .get_run (self ._run_id )
162
+ self ._experiment_id = run .info .experiment_id
163
+ self ._initialized = True
164
+ return self ._mlflow_client
165
+
152
166
if self ._experiment_id is None :
153
167
expt = self ._mlflow_client .get_experiment_by_name (self ._experiment_name )
154
168
if expt is not None :
@@ -169,6 +183,7 @@ def experiment(self) -> MlflowClient:
169
183
self .tags [MLFLOW_RUN_NAME ] = self ._run_name
170
184
run = self ._mlflow_client .create_run (experiment_id = self ._experiment_id , tags = resolve_tags (self .tags ))
171
185
self ._run_id = run .info .run_id
186
+ self ._initialized = True
172
187
return self ._mlflow_client
173
188
174
189
@property
0 commit comments