Skip to content

Commit 7e4f38c

Browse files
committed
Improvements to how non-numeric configurations are handled
1 parent e0aaf24 commit 7e4f38c

File tree

1 file changed

+8
-12
lines changed

1 file changed

+8
-12
lines changed

kernel_tuner/strategies/common.py

Lines changed: 8 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -130,10 +130,6 @@ def __init__(
130130
def __call__(self, x, check_restrictions=True):
131131
"""Cost function used by almost all strategies."""
132132
self.runner.last_strategy_time = 1000 * (perf_counter() - self.runner.last_strategy_start_time)
133-
if self.encode_non_numeric and not self.scaling:
134-
x_numeric = self.params_to_encoded(x)
135-
else:
136-
x_numeric = x
137133

138134
# error value to return for numeric optimizers that need a numerical value
139135
logging.debug("_cost_func called")
@@ -145,9 +141,7 @@ def __call__(self, x, check_restrictions=True):
145141
# snap values in x to nearest actual value for each parameter, unscale x if needed
146142
if self.snap:
147143
if self.scaling:
148-
params = unscale_and_snap_to_nearest(x_numeric, self.searchspace.tune_params, self.tuning_options.eps)
149-
if self.encode_non_numeric and not self.scaling:
150-
params = self.encoded_to_params(params)
144+
params = unscale_and_snap_to_nearest(x, self.searchspace.tune_params, self.tuning_options.eps)
151145
else:
152146
params = snap_to_nearest_config(x, self.searchspace.tune_params)
153147
else:
@@ -165,10 +159,8 @@ def __call__(self, x, check_restrictions=True):
165159

166160
if "constraint_aware" in self.tuning_options.strategy_options and self.tuning_options.strategy_options["constraint_aware"]:
167161
# attempt to repair
168-
new_params = unscale_and_snap_to_nearest_valid(x_numeric, params, self.searchspace, self.tuning_options.eps)
162+
new_params = unscale_and_snap_to_nearest_valid(x, params, self.searchspace, self.tuning_options.eps)
169163
if new_params:
170-
if self.encode_non_numeric:
171-
new_params = self.encoded_to_params(new_params)
172164
params = new_params
173165
legal = True
174166
x_int = ",".join([str(i) for i in params])
@@ -254,8 +246,12 @@ def get_bounds_x0_eps(self):
254246
def get_bounds(self):
255247
"""Create a bounds array from the tunable parameters."""
256248
bounds = []
257-
for values in self.encoded_params_values if self.encode_non_numeric else self.searchspace.params_values:
258-
bounds.append((min(values), max(values)))
249+
for values in self.searchspace.params_values:
250+
try:
251+
bounds.append((min(values), max(values)))
252+
except TypeError:
253+
# if values are not numbers, use the first and last value as bounds
254+
bounds.append((values[0], values[-1]))
259255
return bounds
260256

261257
def encoded_to_params(self, config):

0 commit comments

Comments
 (0)