Skip to content

Commit 41d52f9

Browse files
committed
Add better support for custom & plugin algos + few fixes
1 parent 2bd617b commit 41d52f9

File tree

1 file changed

+85
-15
lines changed

1 file changed

+85
-15
lines changed

dataikuapi/dss/ml.py

Lines changed: 85 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -238,6 +238,16 @@ def use_feature(self, feature_name):
238238
def get_algorithm_settings(self, algorithm_name):
    """
    Placeholder for the per-algorithm settings accessor

    This version has no implementation and always raises. Other definitions of
    ``get_algorithm_settings`` exist further down in this module (presumably the
    overrides of the concrete settings classes -- to be confirmed).

    :param str algorithm_name: name of the algorithm (unused here)
    :raises NotImplementedError: always
    """
    raise NotImplementedError()
240240

241+
def _get_custom_algorithm_settings(self, algorithm_name):
    """
    Looks up the settings of a custom model (MLlib or Python) by its name

    Custom MLlib models are searched first, then custom Python models. The first
    entry whose ``name`` matches is returned as stored in the settings (not a copy).

    :param str algorithm_name: name of the custom model to look for
    :returns: the settings entry of the first custom model with this name
    :raises ValueError: if no custom model has this name
    """
    modeling = self.mltask_settings["modeling"]
    # Order matters: an MLlib model shadows a Python model with the same name
    for custom_kind in ("custom_mllib", "custom_python"):
        for candidate in modeling[custom_kind]:
            if algorithm_name == candidate["name"]:
                return candidate
    raise ValueError("Unknown algorithm: {}".format(algorithm_name))
250+
241251
def get_diagnostics_settings(self):
242252
"""
243253
Gets the diagnostics settings for a mltask. This returns a reference to the
@@ -307,31 +317,38 @@ def disable_all_algorithms(self):
307317
custom_mllib["enabled"] = False
308318
for custom_python in self.mltask_settings["modeling"]["custom_python"]:
309319
custom_python["enabled"] = False
310-
for plugin in self.mltask_settings["modeling"]["plugin_python"].values():
320+
for plugin in self.mltask_settings["modeling"].get("plugin_python", {}).values():
311321
plugin["enabled"] = False
312322

313323
def get_all_possible_algorithm_names(self):
314324
"""
315325
Returns the list of possible algorithm names, i.e. the list of valid
316326
identifiers for :meth:`set_algorithm_enabled` and :meth:`get_algorithm_settings`
317327
318-
This does not include Custom Python models, Custom MLLib models, plugin models.
319328
This includes all possible algorithms, regardless of the prediction kind (regression/classification)
320329
or engine, so some algorithms may be irrelevant
321330
322331
:returns: the list of algorithm names as a list of strings
323332
:rtype: list of string
324333
"""
325-
return list(self.__class__.algorithm_remap.keys())
334+
return list(self.__class__.algorithm_remap.keys()) + self._get_custom_algorithm_names()
335+
336+
def _get_custom_algorithm_names(self):
    """
    Lists the names of the custom models defined in the modeling settings

    Names of custom MLlib models come first, followed by names of custom Python models.

    :returns: the names of the custom models
    :rtype: list of string
    """
    modeling = self.mltask_settings["modeling"]
    names = []
    for custom_kind in ("custom_mllib", "custom_python"):
        for custom_model in modeling[custom_kind]:
            names.append(custom_model["name"])
    return names
326345

327346
def get_enabled_algorithm_names(self):
328347
"""
329348
:returns: the list of enabled algorithm names as a list of strings
330349
:rtype: list of string
331350
"""
332-
algos = self.__class__.algorithm_remap
333-
algo_names = [algo_name for algo_name in algos.keys() if self.mltask_settings["modeling"][algos[algo_name].algorithm_name.lower()]["enabled"]]
334-
return algo_names
351+
return [algo_name for algo_name in self.get_all_possible_algorithm_names() if self.get_algorithm_settings(algo_name).get("enabled", False)]
335352

336353
def get_enabled_algorithm_settings(self):
337354
"""
@@ -356,6 +373,32 @@ def set_metric(self, metric=None, custom_metric=None, custom_metric_greater_is_b
356373
self.mltask_settings["modeling"]["metrics"]["customEvaluationMetricGIB"] = custom_metric_greater_is_better
357374
self.mltask_settings["modeling"]["metrics"]["customEvaluationMetricNeedsProba"] = custom_metric_use_probas
358375

376+
def add_custom_python_model(self, name="Custom Python Model", code=""):
    """
    Registers an additional custom Python model in the modeling settings

    The new entry is appended after the existing custom Python models and is
    created enabled.

    :param str name: name of the new custom model
    :param str code: code of the new custom model
    """
    new_model = {
        "name": name,
        "code": code,
        "enabled": True,
    }
    custom_python_models = self.mltask_settings["modeling"]["custom_python"]
    custom_python_models.append(new_model)
388+
389+
def add_custom_mllib_model(self, name="Custom MLlib Model", code=""):
    """
    Registers an additional custom MLlib model in the modeling settings

    The new entry is appended after the existing custom MLlib models and is
    created enabled. The code is stored under the ``initializationCode`` key.

    :param str name: name of the new custom model
    :param str code: code of the new custom model
    """
    new_model = {
        "name": name,
        "initializationCode": code,
        "enabled": True,
    }
    custom_mllib_models = self.mltask_settings["modeling"]["custom_mllib"]
    custom_mllib_models.append(new_model)
401+
359402
def save(self):
360403
"""Saves back these settings to the ML Task"""
361404

@@ -1310,7 +1353,6 @@ def __init__(self, raw_settings, hyperparameter_search_params):
13101353

13111354
self.cache_node_ids = self._register_simple_parameter("cache_node_ids")
13121355
self.checkpoint_interval = self._register_single_value_hyperparameter("checkpoint_interval", accepted_types=[int])
1313-
self.impurity = self._register_single_category_hyperparameter("impurity", accepted_values=["gini", "entropy", "variance"]) # TODO: distinguish between regression and classif
13141356
self.max_bins = self._register_single_value_hyperparameter("max_bins", accepted_types=[int])
13151357
self.max_memory_mb = self._register_simple_parameter("max_memory_mb")
13161358
self.min_info_gain = self._register_single_value_hyperparameter("min_info_gain", accepted_types=[int, float])
@@ -1395,20 +1437,41 @@ def __init__(self, client, project_key, analysis_id, mltask_id, mltask_settings)
13951437
def get_prediction_type(self):
    """
    :returns: the value stored under ``predictionType`` in the ML task settings
    """
    prediction_type = self.mltask_settings['predictionType']
    return prediction_type
13971439

1440+
def get_all_possible_algorithm_names(self):
    """
    Returns the list of possible algorithm names, i.e. the list of valid
    identifiers for :meth:`set_algorithm_enabled` and :meth:`get_algorithm_settings`

    This is the list returned by the parent class, followed by the names of the plugin models.
    This includes all possible algorithms, regardless of the prediction kind (regression/classification)
    or engine, so some algorithms may be irrelevant

    :returns: the list of algorithm names as a list of strings
    :rtype: list of string
    """
    inherited_names = super(DSSPredictionMLTaskSettings, self).get_all_possible_algorithm_names()
    # The plugin names can be a dict keys view, which cannot be concatenated
    # to a list on Python 3 ("list + dict_keys" raises TypeError): force a list
    return inherited_names + list(self._get_plugin_algorithm_names())
1452+
1453+
def _get_plugin_algorithm_names(self):
    """
    Lists the names of the plugin models declared in the modeling settings

    :returns: the keys of the ``plugin_python`` modeling settings, or an empty list
              when the settings have no ``plugin_python`` entry
    :rtype: list of string
    """
    # "plugin_python" is treated as optional (as in disable_all_algorithms).
    # A real list is returned rather than a keys view so that callers can
    # concatenate it with other lists on Python 3.
    return list(self.mltask_settings["modeling"].get("plugin_python", {}))
1455+
1456+
def _get_plugin_algorithm_settings(self, algorithm_name):
    """
    Looks up the settings of a plugin model by its name

    :param str algorithm_name: key of the plugin model in the ``plugin_python`` modeling settings
    :returns: the settings stored for this plugin model (the stored object, not a copy)
    :raises ValueError: if there is no plugin model with this name, including when the
                        settings have no ``plugin_python`` entry at all
    """
    # "plugin_python" is treated as optional (as in disable_all_algorithms), so a
    # missing entry is reported as an unknown algorithm rather than as a KeyError
    plugin_models = self.mltask_settings["modeling"].get("plugin_python", {})
    if algorithm_name in plugin_models:
        return plugin_models[algorithm_name]
    raise ValueError("Unknown algorithm: {}".format(algorithm_name))
1460+
13981461
def get_enabled_algorithm_names(self):
13991462
"""
14001463
:returns: the list of enabled algorithm names as a list of strings
14011464
:rtype: list of string
14021465
"""
1403-
algos = self.__class__.algorithm_remap
1466+
algo_names = super(DSSPredictionMLTaskSettings, self).get_enabled_algorithm_names()
1467+
14041468
# Hide either "XGBOOST_CLASSIFICATION" or "XGBOOST_REGRESSION" which point to the same key "xgboost"
14051469
if self.mltask_settings["predictionType"] == "REGRESSION":
1406-
excluded_name = {"XGBOOST_CLASSIFICATION"}
1470+
excluded_names = {"XGBOOST_CLASSIFICATION"}
14071471
else:
1408-
excluded_name = {"XGBOOST_REGRESSION"}
1409-
algo_names = [algo_name for algo_name in algos.keys() if (self.mltask_settings["modeling"][algos[algo_name].algorithm_name.lower()]["enabled"]
1410-
and algo_name not in excluded_name)]
1411-
return algo_names
1472+
excluded_names = {"XGBOOST_REGRESSION"}
1473+
1474+
return [algo_name for algo_name in algo_names if algo_name not in excluded_names]
14121475

14131476
def get_algorithm_settings(self, algorithm_name):
14141477
"""
@@ -1442,6 +1505,10 @@ def get_algorithm_settings(self, algorithm_name):
14421505
# Subsequent calls get the same object
14431506
self.mltask_settings["modeling"][algorithm_name.lower()] = algorithm_settings
14441507
return self.mltask_settings["modeling"][algorithm_name.lower()]
1508+
elif algorithm_name in self._get_custom_algorithm_names():
1509+
return self._get_custom_algorithm_settings(algorithm_name)
1510+
elif algorithm_name in self._get_plugin_algorithm_names():
1511+
return self._get_plugin_algorithm_settings(algorithm_name)
14451512
else:
14461513
raise ValueError("Unknown algorithm: {}".format(algorithm_name))
14471514

@@ -1590,8 +1657,11 @@ def get_algorithm_settings(self, algorithm_name):
15901657
"""
15911658
if algorithm_name in self.__class__.algorithm_remap:
15921659
algorithm_name = self.__class__.algorithm_remap[algorithm_name]
1593-
1594-
return self.mltask_settings["modeling"][algorithm_name.lower()]
1660+
return self.mltask_settings["modeling"][algorithm_name.lower()]
1661+
elif algorithm_name in self._get_custom_algorithm_names():
1662+
return self._get_custom_algorithm_settings(algorithm_name)
1663+
else:
1664+
raise ValueError("Unknown algorithm: {}".format(algorithm_name))
15951665

15961666

15971667
class DSSTrainedModelDetails(object):

0 commit comments

Comments
 (0)