Skip to content

Commit ab37ab3

Browse files
committed
Filter out XGBOOST_{REGRESSION,CLASSIFICATION} in get_enabled_algorithm_names according to predictionType
1 parent 864781b commit ab37ab3

File tree

1 file changed

+14
-0
lines changed

1 file changed

+14
-0
lines changed

dataikuapi/dss/ml.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1355,6 +1355,20 @@ def __init__(self, client, project_key, analysis_id, mltask_id, mltask_settings)
13551355
def get_prediction_type(self):
13561356
return self.mltask_settings['predictionType']
13571357

1358+
def get_enabled_algorithm_names(self):
1359+
"""
1360+
:returns: the list of enabled algorithm names as a list of strings
1361+
:rtype: list of string
1362+
"""
1363+
algos = self.__class__.algorithm_remap
1364+
if self.mltask_settings["predictionType"] == "REGRESSION":
1365+
excluded_name = {"XGBOOST_CLASSIFICATION"}
1366+
else:
1367+
excluded_name = {"XGBOOST_REGRESSION"}
1368+
algo_names = [algo_name for algo_name in algos.keys() if (self.mltask_settings["modeling"][algos[algo_name].algorithm_name.lower()]["enabled"]
1369+
and algo_name not in excluded_name)]
1370+
return algo_names
1371+
13581372
def get_algorithm_settings(self, algorithm_name):
13591373
"""
13601374
Gets the training settings for a particular algorithm. This returns a reference to the

0 commit comments

Comments
 (0)