Skip to content

Commit bad07f1

Browse files
authored
Merge PR #119 numerical hyperparam syntactic sugar
from enhancement/numerical-hyperparams
2 parents bc2828e + 30fe5ab commit bad07f1

File tree

1 file changed

+95
-58
lines changed

1 file changed

+95
-58
lines changed

dataikuapi/dss/ml.py

Lines changed: 95 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -793,7 +793,7 @@ def set_explicit_values(self, values):
793793
- the definition mode of the current numerical hyperparameter to "EXPLICIT"
794794
795795
:param values: the explicit list of numerical values considered for this hyperparameter in the search
796-
:type values: list of float | int
796+
:type values: list of float | list of int
797797
"""
798798
self.values = values
799799
self.definition_mode = "EXPLICIT"
@@ -810,7 +810,7 @@ def values(self):
810810
def values(self, values):
811811
"""
812812
:param values: the explicit list of numerical values considered for this hyperparameter in the search
813-
:type values: list of float | int
813+
:type values: list of float | list of int
814814
"""
815815
error_message = "Invalid values input type for hyperparameter " \
816816
"\"{}\": ".format(self.name) + \
@@ -861,15 +861,76 @@ def _set_range(self, min=None, max=None, nb_values=None):
861861
if nb_values is not None:
862862
self._algo_settings[self.name]["range"]["nbValues"] = nb_values
863863

864+
class RangeSettings(object):
865+
"""
866+
[Internal] Range of a numerical hyperparameter (points to the algorithm settings)
867+
Should not be used directly by end users of the API
868+
"""
869+
870+
def __init__(self, numerical_hyperparameter_settings):
871+
self._numerical_hyperparameter_settings = numerical_hyperparameter_settings
872+
self._range_dict = self._numerical_hyperparameter_settings._algo_settings[numerical_hyperparameter_settings.name]["range"]
873+
874+
def __repr__(self):
875+
return "RangeSettings(min={}, max={}, nb_values={})".format(self.min, self.max, self.nb_values)
876+
877+
@property
878+
def min(self):
879+
"""
880+
:return: the lower bound of the range for this hyperparameter
881+
:rtype: float | int
882+
"""
883+
return self._range_dict["min"]
884+
885+
@min.setter
886+
def min(self, value):
887+
"""
888+
:param value: the lower bound of the range for this hyperparameter
889+
:type value: float | int
890+
"""
891+
self._numerical_hyperparameter_settings._set_range(min=value)
892+
893+
@property
894+
def max(self):
895+
"""
896+
:return: the upper bound of the range for this hyperparameter
897+
:rtype: float | int
898+
"""
899+
return self._range_dict["max"]
900+
901+
@max.setter
902+
def max(self, value):
903+
"""
904+
:param value: the upper bound of the range for this hyperparameter
905+
:type value: float | int
906+
"""
907+
self._numerical_hyperparameter_settings._set_range(max=value)
908+
909+
@property
910+
def nb_values(self):
911+
"""
912+
:return: for grid-search ("GRID" strategy) only, the number of values between min and max to consider
913+
:rtype: int
914+
"""
915+
return self._range_dict["nbValues"]
916+
917+
@nb_values.setter
918+
def nb_values(self, value):
919+
"""
920+
:param value: for grid-search ("GRID" strategy) only, the number of values between min and max to consider
921+
:type value: int
922+
"""
923+
self._numerical_hyperparameter_settings._set_range(nb_values=value)
924+
864925
def set_range(self, min=None, max=None, nb_values=None):
865926
"""
866927
Sets both:
867-
- the Range parameters to search over for the current numerical hyperparameter
928+
- the range parameters to search over for the current numerical hyperparameter
868929
- the definition mode of the current numerical hyperparameter to "RANGE"
869930
870-
:param min: the lower bound of the Range for this hyperparameter
931+
:param min: the lower bound of the range for this hyperparameter
871932
:type min: float | int
872-
:param max: the upper bound of the Range for this hyperparameter
933+
:param max: the upper bound of the range for this hyperparameter
873934
:type max: float | int
874935
:param nb_values: for grid-search ("GRID" strategy) only, the number of values between min and max to consider
875936
:type nb_values: int
@@ -879,66 +940,29 @@ def set_range(self, min=None, max=None, nb_values=None):
879940

880941
@property
881942
def range(self):
882-
return Range(self)
943+
return NumericalHyperparameterSettings.RangeSettings(self)
883944

884945

885946
class Range(object):
947+
"""
948+
Range of a numerical hyperparameter (min, max, nb_values)
949+
Use this class to define explicitly the parameters of the range of a numerical hyperparameter
950+
"""
951+
952+
def _check_input(self, value):
953+
assert isinstance(value, (int, float)), "Invalid input type for Range: {}".format(type(value))
886954

887-
def __init__(self, numerical_hyperparameter_settings):
888-
self._numerical_hyperparameter_settings = numerical_hyperparameter_settings
889-
self._range_dict = self._numerical_hyperparameter_settings._algo_settings[numerical_hyperparameter_settings.name]["range"]
955+
def __init__(self, min, max, nb_values=None):
956+
self._check_input(min)
957+
self._check_input(max)
958+
assert min <= max, "Invalid Range: min must be lower than max"
959+
self.min = min
960+
self.max = max
961+
self.nb_values = nb_values
890962

891963
def __repr__(self):
892964
return "Range(min={}, max={}, nb_values={})".format(self.min, self.max, self.nb_values)
893965

894-
@property
895-
def min(self):
896-
"""
897-
:return: the lower bound of the Range for this hyperparameter
898-
:rtype: float | int
899-
"""
900-
return self._range_dict["min"]
901-
902-
@min.setter
903-
def min(self, value):
904-
"""
905-
:param value: the lower bound of the Range this hyperparameter
906-
:type value: float | int
907-
"""
908-
self._numerical_hyperparameter_settings._set_range(min=value)
909-
910-
@property
911-
def max(self):
912-
"""
913-
:return: the upper bound of the Range this hyperparameter
914-
:rtype: float | int
915-
"""
916-
return self._range_dict["max"]
917-
918-
@max.setter
919-
def max(self, value):
920-
"""
921-
:param value: the upper bound of the Range for this hyperparameter
922-
:type value: float | int
923-
"""
924-
self._numerical_hyperparameter_settings._set_range(max=value)
925-
926-
@property
927-
def nb_values(self):
928-
"""
929-
:return: for grid-search ("GRID" strategy) only, the number of values between min and max to consider
930-
:rtype: int
931-
"""
932-
return self._range_dict["nbValues"]
933-
934-
@nb_values.setter
935-
def nb_values(self, value):
936-
"""
937-
:param value: for grid-search ("GRID" strategy) only, the number of values between min and max to consider
938-
:type value: int
939-
"""
940-
self._numerical_hyperparameter_settings._set_range(nb_values=value)
941-
942966

943967
class CategoricalHyperparameterSettings(HyperparameterSettings):
944968

@@ -1097,7 +1121,20 @@ def __setattr__(self, attr_name, value):
10971121
elif isinstance(target, CategoricalHyperparameterSettings):
10981122
target.set_values(value)
10991123
elif isinstance(target, NumericalHyperparameterSettings):
1100-
raise Exception("Invalid assignment of a NumericalHyperparameterSettings object")
1124+
if isinstance(value, list):
1125+
# algo.hyperparam = [x, y, z]
1126+
target.set_explicit_values(values=value)
1127+
elif isinstance(value, Range):
1128+
# algo.hyperparam = Range(min=x, max=y, nb_values=z)
1129+
target.set_range(min=value.min, max=value.max, nb_values=value.nb_values)
1130+
elif isinstance(value, NumericalHyperparameterSettings):
1131+
# algo.hyperparam = other_algo.other_hyperparam
1132+
target.set_range(min=value.range.min, max=value.range.max, nb_values=value.range.nb_values)
1133+
target.set_explicit_values(values=value.values.copy())
1134+
target.definition_mode = value.definition_mode
1135+
else:
1136+
raise TypeError(("Invalid type for NumericalHyperparameterSettings {}\n" +
1137+
"Expecting either list, Range or NumericalHyperparameterSettings").format(attr_name))
11011138
else:
11021139
# simple parameter
11031140
assert isinstance(value, type(target)), "Invalid type {} for parameter {}: expected {}".format(type(value), attr_name, type(target))

0 commit comments

Comments
 (0)