Skip to content

Commit d06e304

Browse files
authored
Merge pull request #224 from dataiku/feature/sc-82825-add-the-possibility-to-deploy-a-run-s-model
Add a set_run_classes method to MLflow extension, to store the classe…
2 parents 926eda6 + 969a4e5 commit d06e304

File tree

1 file changed

+67
-0
lines changed

1 file changed

+67
-0
lines changed

dataikuapi/dss/mlflow.py

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import json
2+
13
class DSSMLflowExtension(object):
24
"""
35
A handle to interact with specific endpoints of the DSS MLflow integration.
@@ -131,3 +133,68 @@ def clean_experiment_tracking_db(self):
131133
This call requires an API key with admin rights
132134
"""
133135
self.client._perform_raw("DELETE", "/api/2.0/mlflow/extension/clean-db/%s" % self.project_key)
136+
137+
def set_run_inference_info(self, run_id, model_type, classes=None, code_env_name=None, target=None):
    """
    Sets the type of the model, and optionally other information useful to deploy or evaluate it.

    model_type must be one of:
    - REGRESSION
    - BINARY_CLASSIFICATION
    - MULTICLASS
    - OTHER

    Classes must be specified if and only if the model is a BINARY_CLASSIFICATION or MULTICLASS model.

    This information is leveraged to filter saved models on their prediction type and prefill the classes
    when deploying an MLflow model as a version of a DSS Saved Model using the GUI.

    All arguments are validated locally before any request is sent; nothing is posted when validation fails.

    :param run_id: id of the run for which to set the inference information
    :type run_id: str
    :param model_type: prediction type (see doc)
    :type model_type: str
    :param classes: ordered list of classes (not for all prediction types, see doc)
    :type classes: list(str)
    :param code_env_name: name of an adequate DSS python code environment
    :type code_env_name: str
    :param target: name of the target
    :type target: str
    :raises ValueError: if run_id is missing, model_type is not one of the supported values, classes are
        missing/unexpected/malformed for the given model_type, or code_env_name/target is not a string
    """
    classification_types = {"BINARY_CLASSIFICATION", "MULTICLASS"}

    if model_type not in {"REGRESSION", "BINARY_CLASSIFICATION", "MULTICLASS", "OTHER"}:
        raise ValueError('Invalid prediction type: {}'.format(model_type))

    if classes and model_type not in classification_types:
        raise ValueError('Classes can be specified only for BINARY_CLASSIFICATION or MULTICLASS prediction types')
    if model_type in classification_types:
        if not classes:
            raise ValueError('Classes must be specified for {} prediction type'.format(model_type))
        if not isinstance(classes, list):
            raise ValueError('Wrong type for classes: {}'.format(type(classes)))
        for cur_class in classes:
            if cur_class is None:
                raise ValueError('class can not be None')
            if not isinstance(cur_class, str):
                raise ValueError('Wrong type for class {}: {}'.format(cur_class, type(cur_class)))

    # Compare against None rather than relying on truthiness: otherwise falsy
    # non-string values (0, False, [], ...) would skip the type check and then be
    # silently dropped from the request instead of being reported to the caller.
    if code_env_name is not None and not isinstance(code_env_name, str):
        raise ValueError('code_env_name must be a string')
    if target is not None and not isinstance(target, str):
        raise ValueError('target must be a string')

    # The run id is the only mandatory identifier of the request; fail fast
    # instead of posting a request that cannot designate any run.
    if not run_id:
        raise ValueError('run_id must be specified')

    params = {
        "run_id": run_id,
        "prediction_type": model_type
    }

    # Optional fields are only sent when they carry a value; the class list is
    # serialized to a JSON string before being placed in the payload.
    if classes:
        params["classes"] = json.dumps(classes)
    if code_env_name:
        params["code_env_name"] = code_env_name
    if target:
        params["target"] = target

    self.client._perform_http(
        "POST", "/api/2.0/mlflow/extension/set-run-inference-info",
        headers={"x-dku-mlflow-project-key": self.project_key},
        body=params
    )

0 commit comments

Comments
 (0)