Skip to content

Commit db146fe

Browse files
committed
Migrate execution parameters out of the HyperparameterSettings hierarchy
1 parent f7c89ce commit db146fe

File tree

1 file changed

+24
-11
lines changed

1 file changed

+24
-11
lines changed

dataikuapi/dss/ml.py

Lines changed: 24 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -957,8 +957,14 @@ def __setattr__(self, attr_name, value):
957957
target.set_value(value)
958958
elif isinstance(target, CategoricalHyperparameterSettings):
959959
target.set_values(value)
960-
else:
960+
elif isinstance(target, NumericalHyperparameterSettings):
961961
raise Exception("Invalid assignment of a NumericalHyperparameterSettings object")
962+
else:
963+
# simple parameter
964+
assert isinstance(value, type(target)), "Invalid type {} for parameter {}: expected {}".format(type(value), attr_name, type(target))
965+
super(PredictionAlgorithmSettings, self).__setattr__(attr_name, value)
966+
self[attr_name] = value
967+
self._hyperparameters_registry[attr_name] = value
962968
else:
963969
# other cases (properties setter, new attribute...)
964970
super(PredictionAlgorithmSettings, self).__setattr__(attr_name, value)
@@ -987,11 +993,18 @@ def _register_single_value_hyperparameter(self, json_key, accepted_types=None, a
987993
self._hyperparameters_registry[json_key] = SingleValueHyperparameterSettings(json_key, self, accepted_types=accepted_types)
988994
return self._hyperparameters_registry[json_key]
989995

996+
def _register_simple_parameter(self, json_key):
997+
self._hyperparameters_registry[json_key] = self[json_key]
998+
return self._hyperparameters_registry[json_key]
999+
9901000
def _repr_html_(self):
9911001
res = "<pre>" + self.__class__.__name__ + "\n"
9921002
res += "\"enabled\": {}".format(self.enabled) + "\n"
9931003
for name, hyperparam_settings in self._hyperparameters_registry.items():
994-
res += "\"{}\": {}".format(name, hyperparam_settings._pretty_repr()) + "\n"
1004+
if isinstance(hyperparam_settings, HyperparameterSettings):
1005+
res += "\"{}\": {}".format(name, hyperparam_settings._pretty_repr()) + "\n"
1006+
else:
1007+
res += "\"{}\": {}".format(name, hyperparam_settings) + "\n"
9951008
res += "</pre>"
9961009
return res + "<details><pre>{}</pre></details>".format(self.__repr__())
9971010

@@ -1039,7 +1052,7 @@ def __init__(self, raw_settings, hyperparameter_search_params):
10391052
self.max_tree_depth = self._register_numerical_hyperparameter("max_tree_depth")
10401053
self.max_feature_prop = self._register_numerical_hyperparameter("max_feature_prop")
10411054
self.max_features = self._register_numerical_hyperparameter("max_features")
1042-
self.n_jobs = self._register_single_value_hyperparameter("n_jobs", accepted_types=[int])
1055+
self.n_jobs = self._register_simple_parameter("n_jobs")
10431056
self.selection_mode = self._register_single_category_hyperparameter("selection_mode", accepted_values=["auto", "sqrt", "log2", "number", "prop"])
10441057

10451058

@@ -1060,14 +1073,14 @@ def __init__(self, raw_settings, hyperparameter_search_params):
10601073
self.booster = self._register_categorical_hyperparameter("booster")
10611074
self.objective = self._register_categorical_hyperparameter("objective")
10621075
self.n_estimators = self._register_single_value_hyperparameter("n_estimators", accepted_types=[int])
1063-
self.nthread = self._register_single_value_hyperparameter("nthread", accepted_types=[int])
1076+
self.nthread = self._register_simple_parameter("nthread")
10641077
self.scale_pos_weight = self._register_single_value_hyperparameter("scale_pos_weight", accepted_types=[int, float])
10651078
self.base_score = self._register_single_value_hyperparameter("base_score", accepted_types=[int, float])
10661079
self.impute_missing = self._register_single_value_hyperparameter("impute_missing", accepted_types=[bool])
10671080
self.missing = self._register_single_value_hyperparameter("missing", accepted_types=[int, float])
10681081
self.cpu_tree_method = self._register_single_category_hyperparameter("cpu_tree_method", accepted_values=["auto", "exact", "approx", "hist"])
10691082
self.gpu_tree_method = self._register_single_category_hyperparameter("gpu_tree_method", accepted_values=["gpu_exact", "gpu_hist"])
1070-
self.enable_cuda = self._register_single_value_hyperparameter("enable_cuda", accepted_types=[bool])
1083+
self.enable_cuda = self._register_simple_parameter("enable_cuda")
10711084
self.seed = self._register_single_value_hyperparameter("seed", accepted_types=[int])
10721085
self.enable_early_stopping = self._register_single_value_hyperparameter("enable_early_stopping", accepted_types=[bool])
10731086
self.early_stopping_rounds = self._register_single_value_hyperparameter("early_stopping_rounds", accepted_types=[int])
@@ -1127,7 +1140,7 @@ class OLSSettings(PredictionAlgorithmSettings):
11271140

11281141
def __init__(self, raw_settings, hyperparameter_search_params):
11291142
super(OLSSettings, self).__init__(raw_settings, hyperparameter_search_params)
1130-
self.n_jobs = self._register_single_value_hyperparameter("n_jobs", accepted_types=[int])
1143+
self.n_jobs = self._register_simple_parameter("n_jobs")
11311144

11321145

11331146
class LARSSettings(PredictionAlgorithmSettings):
@@ -1148,7 +1161,7 @@ def __init__(self, raw_settings, hyperparameter_search_params):
11481161
self.l1_ratio = self._register_single_value_hyperparameter("l1_ratio", accepted_types=[int, float])
11491162
self.max_iter = self._register_single_value_hyperparameter("max_iter", accepted_types=[int])
11501163
self.tol = self._register_single_value_hyperparameter("tol", accepted_types=[int, float])
1151-
self.n_jobs = self._register_single_value_hyperparameter("n_jobs", accepted_types=[int])
1164+
self.n_jobs = self._register_simple_parameter("n_jobs")
11521165

11531166

11541167
class KNNSettings(PredictionAlgorithmSettings):
@@ -1231,10 +1244,10 @@ class MLLibDecisionTreeSettings(PredictionAlgorithmSettings):
12311244
def __init__(self, raw_settings, hyperparameter_search_params):
12321245
super(MLLibDecisionTreeSettings, self).__init__(raw_settings, hyperparameter_search_params)
12331246
self.max_depth = self._register_numerical_hyperparameter("max_depth")
1234-
self.cache_node_ids = self._register_single_value_hyperparameter("cache_node_ids", accepted_types=[bool])
1247+
self.cache_node_ids = self._register_simple_parameter("cache_node_ids")
12351248
self.checkpoint_interval = self._register_single_value_hyperparameter("checkpoint_interval", accepted_types=[int])
12361249
self.max_bins = self._register_single_value_hyperparameter("max_bins", accepted_types=[int])
1237-
self.max_memory_mb = self._register_single_value_hyperparameter("max_memory_mb", accepted_types=[int])
1250+
self.max_memory_mb = self._register_simple_parameter("max_memory_mb")
12381251
self.min_info_gain = self._register_single_value_hyperparameter("min_info_gain", accepted_types=[int, float])
12391252
self.min_instance_per_node = self._register_single_value_hyperparameter("min_instance_per_node", accepted_types=[int])
12401253

@@ -1247,11 +1260,11 @@ def __init__(self, raw_settings, hyperparameter_search_params):
12471260
self.max_depth = self._register_numerical_hyperparameter("max_depth")
12481261
self.num_trees = self._register_numerical_hyperparameter("num_trees")
12491262

1250-
self.cache_node_ids = self._register_single_value_hyperparameter("cache_node_ids", accepted_types=[bool])
1263+
self.cache_node_ids = self._register_simple_parameter("cache_node_ids")
12511264
self.checkpoint_interval = self._register_single_value_hyperparameter("checkpoint_interval", accepted_types=[int])
12521265
self.impurity = self._register_single_category_hyperparameter("impurity", accepted_values=["gini", "entropy", "variance"]) # TODO: distinguish between regression and classif
12531266
self.max_bins = self._register_single_value_hyperparameter("max_bins", accepted_types=[int])
1254-
self.max_memory_mb = self._register_single_value_hyperparameter("max_memory_mb", accepted_types=[int])
1267+
self.max_memory_mb = self._register_simple_parameter("max_memory_mb")
12551268
self.min_info_gain = self._register_single_value_hyperparameter("min_info_gain", accepted_types=[int, float])
12561269
self.min_instance_per_node = self._register_single_value_hyperparameter("min_instance_per_node", accepted_types=[int])
12571270
self.seed = self._register_single_value_hyperparameter("seed", accepted_types=[int])

0 commit comments

Comments
 (0)