Skip to content

Commit e0aaf24

Browse files
committed
Improvements to how non-numeric configurations are handled
1 parent 7b5cd29 commit e0aaf24

File tree

1 file changed

+33
-21
lines changed

1 file changed

+33
-21
lines changed

kernel_tuner/strategies/common.py

Lines changed: 33 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,11 @@ def get_options(strategy_options, options, unsupported=None):
6262
return [strategy_options.get(opt, default) for opt, (_, default) in options.items()]
6363

6464

65+
def is_number(value) -> bool:
66+
"""Check if a value is a real number (false on booleans and complex numbers)."""
67+
return isinstance(value, numbers.Real) and not isinstance(value, bool)
68+
69+
6570
class CostFunc:
6671
"""Class encapsulating the CostFunc method."""
6772

@@ -73,7 +78,7 @@ def __init__(
7378
*,
7479
scaling=False,
7580
snap=True,
76-
encode_non_numeric=False,
81+
encode_non_numeric=None,
7782
return_invalid=False,
7883
return_raw=None,
7984
):
@@ -85,35 +90,36 @@ def __init__(
8590
runner: the runner to use.
8691
scaling: whether to internally scale parameter values. Defaults to False.
8792
snap: whether to snap given configurations to their closests equivalent in the space. Defaults to True.
88-
encode_non_numeric: whether to externally encode non-numeric parameter values. Defaults to False.
93+
encode_non_numeric: whether to encode non-numeric parameter values. Defaults to None, meaning it is applied when necessary.
8994
return_invalid: whether to return the util.ErrorConfig of an invalid configuration. Defaults to False.
9095
return_raw: returns (result, results[raw]). Key inferred from objective if set to True. Defaults to None.
9196
"""
92-
self.runner = runner
93-
self.snap = snap
94-
self.scaling = scaling
95-
self.encode_non_numeric = encode_non_numeric
96-
self.return_invalid = return_invalid
97-
self.return_raw = return_raw
98-
if return_raw is True:
99-
self.return_raw = f"{tuning_options['objective']}s"
10097
self.searchspace = searchspace
10198
self.tuning_options = tuning_options
10299
if isinstance(self.tuning_options, dict):
103100
self.tuning_options["max_fevals"] = min(
104101
tuning_options["max_fevals"] if "max_fevals" in tuning_options else np.inf, searchspace.size
105102
)
103+
self.runner = runner
104+
self.scaling = scaling
105+
self.snap = snap
106+
self.encode_non_numeric = encode_non_numeric if encode_non_numeric is not None else not all([all(is_number(v) for v in param_values) for param_values in self.searchspace.params_values])
107+
self.return_invalid = return_invalid
108+
self.return_raw = return_raw
109+
if return_raw is True:
110+
self.return_raw = f"{tuning_options['objective']}s"
106111
self.results = []
107112
self.budget_spent_fraction = 0.0
108113

109114
# if enabled, encode non-numeric parameter values as a numeric value
115+
# NOTE careful, this shouldn't conflict with Searchspace tensorspace
110116
if self.encode_non_numeric:
111117
self._map_param_to_encoded = {}
112118
self._map_encoded_to_param = {}
113119
self.encoded_params_values = []
114120
for i, param_values in enumerate(self.searchspace.params_values):
115121
encoded_values = param_values
116-
if not all(isinstance(v, numbers.Real) for v in param_values):
122+
if not all(is_number(v) for v in param_values):
117123
encoded_values = np.arange(
118124
len(param_values)
119125
) # NOTE when changing this, adjust the rounding in encoded_to_params
@@ -124,8 +130,10 @@ def __init__(
124130
def __call__(self, x, check_restrictions=True):
125131
"""Cost function used by almost all strategies."""
126132
self.runner.last_strategy_time = 1000 * (perf_counter() - self.runner.last_strategy_start_time)
127-
if self.encode_non_numeric:
128-
x = self.encoded_to_params(x)
133+
if self.encode_non_numeric and not self.scaling:
134+
x_numeric = self.params_to_encoded(x)
135+
else:
136+
x_numeric = x
129137

130138
# error value to return for numeric optimizers that need a numerical value
131139
logging.debug("_cost_func called")
@@ -137,7 +145,9 @@ def __call__(self, x, check_restrictions=True):
137145
# snap values in x to nearest actual value for each parameter, unscale x if needed
138146
if self.snap:
139147
if self.scaling:
140-
params = unscale_and_snap_to_nearest(x, self.searchspace.tune_params, self.tuning_options.eps)
148+
params = unscale_and_snap_to_nearest(x_numeric, self.searchspace.tune_params, self.tuning_options.eps)
149+
if self.encode_non_numeric and not self.scaling:
150+
params = self.encoded_to_params(params)
141151
else:
142152
params = snap_to_nearest_config(x, self.searchspace.tune_params)
143153
else:
@@ -155,8 +165,10 @@ def __call__(self, x, check_restrictions=True):
155165

156166
if "constraint_aware" in self.tuning_options.strategy_options and self.tuning_options.strategy_options["constraint_aware"]:
157167
# attempt to repair
158-
new_params = unscale_and_snap_to_nearest_valid(x, params, self.searchspace, self.tuning_options.eps)
168+
new_params = unscale_and_snap_to_nearest_valid(x_numeric, params, self.searchspace, self.tuning_options.eps)
159169
if new_params:
170+
if self.encode_non_numeric:
171+
new_params = self.encoded_to_params(new_params)
160172
params = new_params
161173
legal = True
162174
x_int = ",".join([str(i) for i in params])
@@ -209,6 +221,7 @@ def get_bounds_x0_eps(self):
209221

210222
if "x0" in self.tuning_options.strategy_options:
211223
x0 = self.tuning_options.strategy_options.x0
224+
assert isinstance(x0, (tuple, list)) and len(x0) == len(values), f"Invalid x0: {x0}, expected number of parameters of `tune_params` to match ({len(values)})"
212225
else:
213226
x0 = None
214227

@@ -242,11 +255,7 @@ def get_bounds(self):
242255
"""Create a bounds array from the tunable parameters."""
243256
bounds = []
244257
for values in self.encoded_params_values if self.encode_non_numeric else self.searchspace.params_values:
245-
try:
246-
bounds.append((min(values), max(values)))
247-
except TypeError:
248-
# if values are not numeric, use the first and last value as bounds
249-
bounds.append((values[0], values[-1]))
258+
bounds.append((min(values), max(values)))
250259
return bounds
251260

252261
def encoded_to_params(self, config):
@@ -277,7 +286,10 @@ def params_to_encoded(self, config):
277286
raise ValueError("'encode_non_numeric' must be set to true to use this function.")
278287
encoded = []
279288
for i, v in enumerate(config):
280-
encoded.append(self._map_param_to_encoded[i][v] if i in self._map_param_to_encoded else v)
289+
try:
290+
encoded.append(self._map_param_to_encoded[i][v] if i in self._map_param_to_encoded else v)
291+
except KeyError:
292+
raise KeyError(f"{config} parameter value {v} not found in {self._map_param_to_encoded} for parameter {i}.")
281293
assert len(encoded) == len(config)
282294
return encoded
283295

0 commit comments

Comments
 (0)