Skip to content

Commit 3e60674

Browse files
Merge pull request #117 from dataiku/enhancement/dss90-better-suport-custom-plugin-algos
Better suport custom & plugin algos
2 parents 2bd617b + 3d75fcf commit 3e60674

File tree

4 files changed

+118
-26
lines changed

4 files changed

+118
-26
lines changed

dataikuapi/dss/analysis.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -188,8 +188,9 @@ def create_prediction_ml_task(self,
188188
return mltask
189189

190190
def create_clustering_ml_task(self,
191-
ml_backend_type = "PY_MEMORY",
192-
guess_policy = "KMEANS"):
191+
ml_backend_type="PY_MEMORY",
192+
guess_policy="KMEANS",
193+
wait_guess_complete=True):
193194

194195

195196
"""Creates a new clustering task in a new visual analysis lab
@@ -205,6 +206,10 @@ def create_clustering_ml_task(self,
205206
206207
:param string ml_backend_type: ML backend to use, one of PY_MEMORY, MLLIB or H2O
207208
:param string guess_policy: Policy to use for setting the default parameters. Valid values are: KMEANS and ANOMALY_DETECTION
209+
:param boolean wait_guess_complete: if False, the returned ML task will be in 'guessing' state, i.e. analyzing the input dataset to determine feature handling and algorithms.
210+
You should wait for the guessing to be completed by calling
211+
``wait_guess_complete`` on the returned object before doing anything
212+
else (in particular calling ``train`` or ``get_settings``)
208213
"""
209214

210215
obj = {
@@ -214,7 +219,11 @@ def create_clustering_ml_task(self,
214219
}
215220

216221
ref = self.client._perform_json("POST", "/projects/%s/lab/%s/models/" % (self.project_key, self.analysis_id), body=obj)
217-
return DSSMLTask(self.client, self.project_key, self.analysis_id, ref["mlTaskId"])
222+
mltask = DSSMLTask(self.client, self.project_key, self.analysis_id, ref["mlTaskId"])
223+
224+
if wait_guess_complete:
225+
mltask.wait_guess_complete()
226+
return mltask
218227

219228
def list_ml_tasks(self):
220229
"""

dataikuapi/dss/dataset.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -385,8 +385,9 @@ def create_prediction_ml_task(self, target_variable,
385385
guess_policy = guess_policy, prediction_type = prediction_type, wait_guess_complete = wait_guess_complete)
386386

387387
def create_clustering_ml_task(self, input_dataset,
388-
ml_backend_type = "PY_MEMORY",
389-
guess_policy = "KMEANS"):
388+
ml_backend_type="PY_MEMORY",
389+
guess_policy="KMEANS",
390+
wait_guess_complete=True):
390391
"""Creates a new clustering task in a new visual analysis lab
391392
for a dataset.
392393
@@ -400,9 +401,13 @@ def create_clustering_ml_task(self, input_dataset,
400401
401402
:param string ml_backend_type: ML backend to use, one of PY_MEMORY, MLLIB or H2O
402403
:param string guess_policy: Policy to use for setting the default parameters. Valid values are: KMEANS and ANOMALY_DETECTION
404+
:param boolean wait_guess_complete: if False, the returned ML task will be in 'guessing' state, i.e. analyzing the input dataset to determine feature handling and algorithms.
405+
You should wait for the guessing to be completed by calling
406+
``wait_guess_complete`` on the returned object before doing anything
407+
else (in particular calling ``train`` or ``get_settings``)
403408
"""
404-
return self.project.create_clustering_ml_task(self.dataset_name,
405-
ml_backend_type = ml_backend_type, guess_policy = guess_policy)
409+
return self.project.create_clustering_ml_task(self.dataset_name, ml_backend_type=ml_backend_type, guess_policy=guess_policy,
410+
wait_guess_complete=wait_guess_complete)
406411

407412
def create_analysis(self):
408413
"""

dataikuapi/dss/ml.py

Lines changed: 85 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -238,6 +238,16 @@ def use_feature(self, feature_name):
238238
def get_algorithm_settings(self, algorithm_name):
239239
raise NotImplementedError()
240240

241+
def _get_custom_algorithm_settings(self, algorithm_name):
242+
# returns the first algorithm with this name
243+
for algo in self.mltask_settings["modeling"]["custom_mllib"]:
244+
if algorithm_name == algo["name"]:
245+
return algo
246+
for algo in self.mltask_settings["modeling"]["custom_python"]:
247+
if algorithm_name == algo["name"]:
248+
return algo
249+
raise ValueError("Unknown algorithm: {}".format(algorithm_name))
250+
241251
def get_diagnostics_settings(self):
242252
"""
243253
Gets the diagnostics settings for a mltask. This returns a reference to the
@@ -307,31 +317,38 @@ def disable_all_algorithms(self):
307317
custom_mllib["enabled"] = False
308318
for custom_python in self.mltask_settings["modeling"]["custom_python"]:
309319
custom_python["enabled"] = False
310-
for plugin in self.mltask_settings["modeling"]["plugin_python"].values():
320+
for plugin in self.mltask_settings["modeling"].get("plugin_python", {}).values():
311321
plugin["enabled"] = False
312322

313323
def get_all_possible_algorithm_names(self):
314324
"""
315325
Returns the list of possible algorithm names, i.e. the list of valid
316326
identifiers for :meth:`set_algorithm_enabled` and :meth:`get_algorithm_settings`
317327
318-
This does not include Custom Python models, Custom MLLib models, plugin models.
319328
This includes all possible algorithms, regardless of the prediction kind (regression/classification)
320329
or engine, so some algorithms may be irrelevant
321330
322331
:returns: the list of algorithm names as a list of strings
323332
:rtype: list of string
324333
"""
325-
return list(self.__class__.algorithm_remap.keys())
334+
return list(self.__class__.algorithm_remap.keys()) + self._get_custom_algorithm_names()
335+
336+
def _get_custom_algorithm_names(self):
337+
"""
338+
Returns the list of names of defined custom models (Python & MLlib backends)
339+
340+
:returns: the list of custom models names
341+
:rtype: list of string
342+
"""
343+
return ([algo["name"] for algo in self.mltask_settings["modeling"]["custom_mllib"]]
344+
+ [algo["name"] for algo in self.mltask_settings["modeling"]["custom_python"]])
326345

327346
def get_enabled_algorithm_names(self):
328347
"""
329348
:returns: the list of enabled algorithm names as a list of strings
330349
:rtype: list of string
331350
"""
332-
algos = self.__class__.algorithm_remap
333-
algo_names = [algo_name for algo_name in algos.keys() if self.mltask_settings["modeling"][algos[algo_name].algorithm_name.lower()]["enabled"]]
334-
return algo_names
351+
return [algo_name for algo_name in self.get_all_possible_algorithm_names() if self.get_algorithm_settings(algo_name).get("enabled", False)]
335352

336353
def get_enabled_algorithm_settings(self):
337354
"""
@@ -356,6 +373,32 @@ def set_metric(self, metric=None, custom_metric=None, custom_metric_greater_is_b
356373
self.mltask_settings["modeling"]["metrics"]["customEvaluationMetricGIB"] = custom_metric_greater_is_better
357374
self.mltask_settings["modeling"]["metrics"]["customEvaluationMetricNeedsProba"] = custom_metric_use_probas
358375

376+
def add_custom_python_model(self, name="Custom Python Model", code=""):
377+
"""
378+
Adds a new custom python model
379+
380+
:param str name: name of the custom model
381+
:param str code: code of the custom model
382+
"""
383+
self.mltask_settings["modeling"]["custom_python"].append({
384+
"name": name,
385+
"code": code,
386+
"enabled": True
387+
})
388+
389+
def add_custom_mllib_model(self, name="Custom MLlib Model", code=""):
390+
"""
391+
Adds a new custom MLlib model
392+
393+
:param str name: name of the custom model
394+
:param str code: code of the custom model
395+
"""
396+
self.mltask_settings["modeling"]["custom_mllib"].append({
397+
"name": name,
398+
"initializationCode": code,
399+
"enabled": True
400+
})
401+
359402
def save(self):
360403
"""Saves back these settings to the ML Task"""
361404

@@ -1310,7 +1353,6 @@ def __init__(self, raw_settings, hyperparameter_search_params):
13101353

13111354
self.cache_node_ids = self._register_simple_parameter("cache_node_ids")
13121355
self.checkpoint_interval = self._register_single_value_hyperparameter("checkpoint_interval", accepted_types=[int])
1313-
self.impurity = self._register_single_category_hyperparameter("impurity", accepted_values=["gini", "entropy", "variance"]) # TODO: distinguish between regression and classif
13141356
self.max_bins = self._register_single_value_hyperparameter("max_bins", accepted_types=[int])
13151357
self.max_memory_mb = self._register_simple_parameter("max_memory_mb")
13161358
self.min_info_gain = self._register_single_value_hyperparameter("min_info_gain", accepted_types=[int, float])
@@ -1395,20 +1437,41 @@ def __init__(self, client, project_key, analysis_id, mltask_id, mltask_settings)
13951437
def get_prediction_type(self):
13961438
return self.mltask_settings['predictionType']
13971439

1440+
def get_all_possible_algorithm_names(self):
1441+
"""
1442+
Returns the list of possible algorithm names, i.e. the list of valid
1443+
identifiers for :meth:`set_algorithm_enabled` and :meth:`get_algorithm_settings`
1444+
1445+
This includes all possible algorithms, regardless of the prediction kind (regression/classification)
1446+
or engine, so some algorithms may be irrelevant
1447+
1448+
:returns: the list of algorithm names as a list of strings
1449+
:rtype: list of string
1450+
"""
1451+
return super(DSSPredictionMLTaskSettings, self).get_all_possible_algorithm_names() + self._get_plugin_algorithm_names()
1452+
1453+
def _get_plugin_algorithm_names(self):
1454+
return list(self.mltask_settings["modeling"]["plugin_python"].keys())
1455+
1456+
def _get_plugin_algorithm_settings(self, algorithm_name):
1457+
if algorithm_name in self.mltask_settings["modeling"]["plugin_python"]:
1458+
return self.mltask_settings["modeling"]["plugin_python"][algorithm_name]
1459+
raise ValueError("Unknown algorithm: {}".format(algorithm_name))
1460+
13981461
def get_enabled_algorithm_names(self):
13991462
"""
14001463
:returns: the list of enabled algorithm names as a list of strings
14011464
:rtype: list of string
14021465
"""
1403-
algos = self.__class__.algorithm_remap
1466+
algo_names = super(DSSPredictionMLTaskSettings, self).get_enabled_algorithm_names()
1467+
14041468
# Hide either "XGBOOST_CLASSIFICATION" or "XGBOOST_REGRESSION" which point to the same key "xgboost"
14051469
if self.mltask_settings["predictionType"] == "REGRESSION":
1406-
excluded_name = {"XGBOOST_CLASSIFICATION"}
1470+
excluded_names = {"XGBOOST_CLASSIFICATION"}
14071471
else:
1408-
excluded_name = {"XGBOOST_REGRESSION"}
1409-
algo_names = [algo_name for algo_name in algos.keys() if (self.mltask_settings["modeling"][algos[algo_name].algorithm_name.lower()]["enabled"]
1410-
and algo_name not in excluded_name)]
1411-
return algo_names
1472+
excluded_names = {"XGBOOST_REGRESSION"}
1473+
1474+
return [algo_name for algo_name in algo_names if algo_name not in excluded_names]
14121475

14131476
def get_algorithm_settings(self, algorithm_name):
14141477
"""
@@ -1442,6 +1505,10 @@ def get_algorithm_settings(self, algorithm_name):
14421505
# Subsequent calls get the same object
14431506
self.mltask_settings["modeling"][algorithm_name.lower()] = algorithm_settings
14441507
return self.mltask_settings["modeling"][algorithm_name.lower()]
1508+
elif algorithm_name in self._get_custom_algorithm_names():
1509+
return self._get_custom_algorithm_settings(algorithm_name)
1510+
elif algorithm_name in self._get_plugin_algorithm_names():
1511+
return self._get_plugin_algorithm_settings(algorithm_name)
14451512
else:
14461513
raise ValueError("Unknown algorithm: {}".format(algorithm_name))
14471514

@@ -1590,8 +1657,11 @@ def get_algorithm_settings(self, algorithm_name):
15901657
"""
15911658
if algorithm_name in self.__class__.algorithm_remap:
15921659
algorithm_name = self.__class__.algorithm_remap[algorithm_name]
1593-
1594-
return self.mltask_settings["modeling"][algorithm_name.lower()]
1660+
return self.mltask_settings["modeling"][algorithm_name.lower()]
1661+
elif algorithm_name in self._get_custom_algorithm_names():
1662+
return self._get_custom_algorithm_settings(algorithm_name)
1663+
else:
1664+
raise ValueError("Unknown algorithm: {}".format(algorithm_name))
15951665

15961666

15971667
class DSSTrainedModelDetails(object):

dataikuapi/dss/project.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -532,9 +532,9 @@ def create_prediction_ml_task(self, input_dataset, target_variable,
532532
return ret
533533

534534
def create_clustering_ml_task(self, input_dataset,
535-
ml_backend_type = "PY_MEMORY",
536-
guess_policy = "KMEANS"):
537-
535+
ml_backend_type = "PY_MEMORY",
536+
guess_policy = "KMEANS",
537+
wait_guess_complete=True):
538538

539539
"""Creates a new clustering task in a new visual analysis lab
540540
for a dataset.
@@ -549,6 +549,10 @@ def create_clustering_ml_task(self, input_dataset,
549549
550550
:param string ml_backend_type: ML backend to use, one of PY_MEMORY, MLLIB or H2O
551551
:param string guess_policy: Policy to use for setting the default parameters. Valid values are: KMEANS and ANOMALY_DETECTION
552+
:param boolean wait_guess_complete: if False, the returned ML task will be in 'guessing' state, i.e. analyzing the input dataset to determine feature handling and algorithms.
553+
You should wait for the guessing to be completed by calling
554+
``wait_guess_complete`` on the returned object before doing anything
555+
else (in particular calling ``train`` or ``get_settings``)
552556
"""
553557

554558
obj = {
@@ -559,7 +563,11 @@ def create_clustering_ml_task(self, input_dataset,
559563
}
560564

561565
ref = self.client._perform_json("POST", "/projects/%s/models/lab/" % self.project_key, body=obj)
562-
return DSSMLTask(self.client, self.project_key, ref["analysisId"], ref["mlTaskId"])
566+
mltask = DSSMLTask(self.client, self.project_key, ref["analysisId"], ref["mlTaskId"])
567+
568+
if wait_guess_complete:
569+
mltask.wait_guess_complete()
570+
return mltask
563571

564572
def list_ml_tasks(self):
565573
"""

0 commit comments

Comments
 (0)