Skip to content

Commit 04baf7d

Browse files
committed
Merge branch 'custom_strategies' into hyperparametertuning_custom_strategies
2 parents 1a4c439 + 8afc3d6 commit 04baf7d

File tree

4 files changed

+244
-22
lines changed

4 files changed

+244
-22
lines changed
Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
#!/usr/bin/env python
2+
"""This is the minimal example from the README"""
3+
4+
import numpy
5+
import kernel_tuner
6+
from kernel_tuner import tune_kernel
7+
from kernel_tuner.file_utils import store_output_file, store_metadata_file
8+
9+
def tune():
10+
11+
kernel_string = """
12+
__global__ void vector_add(float *c, float *a, float *b, int n) {
13+
int i = blockIdx.x * block_size_x + threadIdx.x;
14+
if (i<n) {
15+
c[i] = a[i] + b[i];
16+
}
17+
}
18+
"""
19+
20+
size = 10000000
21+
22+
a = numpy.random.randn(size).astype(numpy.float32)
23+
b = numpy.random.randn(size).astype(numpy.float32)
24+
c = numpy.zeros_like(b)
25+
n = numpy.int32(size)
26+
27+
args = [c, a, b, n]
28+
29+
tune_params = dict()
30+
tune_params["block_size_x"] = [128+64*i for i in range(15)]
31+
32+
results, env = tune_kernel("vector_add", kernel_string, size, args, tune_params, strategy=kernel_tuner.strategies.minimize, verbose=True)
33+
34+
# Store the tuning results in an output file
35+
store_output_file("vector_add.json", results, tune_params)
36+
37+
# Store the metadata of this run
38+
store_metadata_file("vector_add-metadata.json")
39+
40+
return results
41+
42+
43+
if __name__ == "__main__":
44+
tune()

kernel_tuner/interface.py

Lines changed: 8 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -632,29 +632,15 @@ def tune_kernel(
632632
if strategy in strategy_map:
633633
strategy = strategy_map[strategy]
634634
else:
635-
raise ValueError(f"Unkown strategy {strategy}, must be one of: {', '.join(list(strategy_map.keys()))}")
636-
637-
# make strategy_options into an Options object
638-
if tuning_options.strategy_options:
639-
if not isinstance(strategy_options, Options):
640-
tuning_options.strategy_options = Options(strategy_options)
641-
642-
# select strategy based on user options
643-
if "fraction" in tuning_options.strategy_options and not tuning_options.strategy == "random_sample":
644-
raise ValueError(
645-
'It is not possible to use fraction in combination with strategies other than "random_sample". '
646-
'Please set strategy="random_sample", when using "fraction" in strategy_options'
647-
)
648-
649-
# check if method is supported by the selected strategy
650-
if "method" in tuning_options.strategy_options:
651-
method = tuning_options.strategy_options.method
652-
if method not in strategy.supported_methods:
653-
raise ValueError("Method %s is not supported for strategy %s" % (method, tuning_options.strategy))
635+
# check for user-defined strategy
636+
if hasattr(strategy, "tune") and callable(strategy.tune):
637+
# user-defined strategy
638+
pass
639+
else:
640+
raise ValueError(f"Unkown strategy {strategy}, must be one of: {', '.join(list(strategy_map.keys()))}")
654641

655-
# if no strategy_options dict has been passed, create empty dictionary
656-
else:
657-
tuning_options.strategy_options = Options({})
642+
# ensure strategy_options is an Options object
643+
tuning_options.strategy_options = Options(strategy_options or {})
658644

659645
# if no strategy selected
660646
else:

kernel_tuner/strategies/wrapper.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
"""Wrapper intended for user-defined custom optimization methods"""
2+
3+
from kernel_tuner import util
4+
from kernel_tuner.searchspace import Searchspace
5+
from kernel_tuner.strategies.common import CostFunc
6+
7+
8+
class OptAlgWrapper:
9+
"""Wrapper class for user-defined optimization algorithms"""
10+
11+
def __init__(self, optimizer, scaling=True):
12+
self.optimizer = optimizer
13+
self.scaling = scaling
14+
15+
16+
def tune(self, searchspace: Searchspace, runner, tuning_options):
17+
cost_func = CostFunc(searchspace, tuning_options, runner, scaling=self.scaling)
18+
19+
if self.scaling:
20+
# Initialize costfunc for scaling
21+
cost_func.get_bounds_x0_eps()
22+
23+
try:
24+
self.optimizer(cost_func)
25+
except util.StopCriterionReached as e:
26+
if tuning_options.verbose:
27+
print(e)
28+
29+
return cost_func.results

test/test_custom_optimizer.py

Lines changed: 163 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,163 @@
1+
2+
### The following was generating using the LLaMEA prompt and OpenAI o1
3+
4+
import numpy as np
5+
6+
class HybridDELocalRefinement:
7+
"""
8+
A two-phase differential evolution with local refinement, intended for BBOB-type
9+
black box optimization problems in [-5,5]^dim.
10+
11+
One-line idea: A two-phase hybrid DE with local refinement that balances global
12+
exploration and local exploitation under a strict function evaluation budget.
13+
"""
14+
15+
def __init__(self, budget, dim):
16+
"""
17+
Initialize the optimizer with:
18+
- budget: total number of function evaluations allowed.
19+
- dim: dimensionality of the search space.
20+
"""
21+
self.budget = budget
22+
self.dim = dim
23+
# You can adjust these hyperparameters based on experimentation/tuning:
24+
self.population_size = min(50, 10 * dim) # Caps for extremely large dim
25+
self.F = 0.8 # Differential weight
26+
self.CR = 0.9 # Crossover probability
27+
self.local_search_freq = 10 # Local refinement frequency in generations
28+
29+
def __call__(self, func):
30+
"""
31+
Optimize the black box function `func` in [-5,5]^dim, using
32+
at most self.budget function evaluations.
33+
34+
Returns:
35+
best_params: np.ndarray representing the best parameters found
36+
best_value: float representing the best objective value found
37+
"""
38+
# Check if we have a non-positive budget
39+
if self.budget <= 0:
40+
raise ValueError("Budget must be a positive integer.")
41+
42+
# 1. Initialize population
43+
lower_bound, upper_bound = -5.0, 5.0
44+
pop = np.random.uniform(lower_bound, upper_bound, (self.population_size, self.dim))
45+
46+
# Evaluate initial population
47+
evaluations = 0
48+
fitness = np.empty(self.population_size)
49+
for i in range(self.population_size):
50+
fitness[i] = func(pop[i])
51+
evaluations += 1
52+
if evaluations >= self.budget:
53+
break
54+
55+
# Track best solution
56+
best_idx = np.argmin(fitness)
57+
best_params = pop[best_idx].copy()
58+
best_value = fitness[best_idx]
59+
60+
# 2. Main evolutionary loop
61+
gen = 0
62+
while evaluations < self.budget:
63+
gen += 1
64+
for i in range(self.population_size):
65+
# DE mutation: pick three distinct indices
66+
idxs = np.random.choice(self.population_size, 3, replace=False)
67+
a, b, c = pop[idxs]
68+
mutant = a + self.F * (b - c)
69+
70+
# Crossover
71+
trial = np.copy(pop[i])
72+
crossover_points = np.random.rand(self.dim) < self.CR
73+
trial[crossover_points] = mutant[crossover_points]
74+
75+
# Enforce bounds
76+
trial = np.clip(trial, lower_bound, upper_bound)
77+
78+
# Evaluate trial
79+
trial_fitness = func(trial)
80+
evaluations += 1
81+
if evaluations >= self.budget:
82+
# If out of budget, wrap up
83+
if trial_fitness < fitness[i]:
84+
pop[i] = trial
85+
fitness[i] = trial_fitness
86+
# Update global best
87+
if trial_fitness < best_value:
88+
best_value = trial_fitness
89+
best_params = trial.copy()
90+
break
91+
92+
# Selection
93+
if trial_fitness < fitness[i]:
94+
pop[i] = trial
95+
fitness[i] = trial_fitness
96+
# Update global best
97+
if trial_fitness < best_value:
98+
best_value = trial_fitness
99+
best_params = trial.copy()
100+
101+
# Periodically refine best solution with a small local neighborhood search
102+
if gen % self.local_search_freq == 0 and evaluations < self.budget:
103+
best_params, best_value, evaluations = self._local_refinement(
104+
func, best_params, best_value, evaluations, lower_bound, upper_bound
105+
)
106+
107+
if evaluations >= self.budget:
108+
break
109+
110+
return best_params, best_value
111+
112+
def _local_refinement(self, func, best_params, best_value, evaluations, lb, ub):
113+
"""
114+
Local refinement around the best solution found so far.
115+
Uses a quick 'perturb-and-accept' approach in a shrinking neighborhood.
116+
"""
117+
# Neighborhood size shrinks as the budget is consumed
118+
frac_budget_used = evaluations / self.budget
119+
step_size = 0.2 * (1.0 - frac_budget_used)
120+
121+
for _ in range(5): # 5 refinements each time
122+
if evaluations >= self.budget:
123+
break
124+
candidate = best_params + np.random.uniform(-step_size, step_size, self.dim)
125+
candidate = np.clip(candidate, lb, ub)
126+
cand_value = func(candidate)
127+
evaluations += 1
128+
if cand_value < best_value:
129+
best_value = cand_value
130+
best_params = candidate.copy()
131+
132+
return best_params, best_value, evaluations
133+
134+
135+
136+
137+
### Testing the Optimization Algorithm Wrapper in Kernel Tuner
138+
import os
139+
from kernel_tuner import tune_kernel
140+
from kernel_tuner.strategies.wrapper import OptAlgWrapper
141+
cache_filename = os.path.dirname(
142+
143+
os.path.realpath(__file__)) + "/test_cache_file.json"
144+
145+
from .test_runners import env
146+
147+
148+
def test_OptAlgWrapper(env):
149+
kernel_name, kernel_string, size, args, tune_params = env
150+
151+
# Instantiate LLaMAE optimization algorithm
152+
budget = int(15)
153+
dim = len(tune_params)
154+
optimizer = HybridDELocalRefinement(budget, dim)
155+
156+
# Wrap the algorithm class in the OptAlgWrapper
157+
# for use in Kernel Tuner
158+
strategy = OptAlgWrapper(optimizer)
159+
160+
# Call the tuner
161+
tune_kernel(kernel_name, kernel_string, size, args, tune_params,
162+
strategy=strategy, cache=cache_filename,
163+
simulation_mode=True, verbose=True)

0 commit comments

Comments
 (0)