Skip to content

Commit 54010b4

Browse files
introduced repair technique in genetic algorithm
1 parent ea29129 commit 54010b4

File tree

4 files changed

+125
-58
lines changed

4 files changed

+125
-58
lines changed

kernel_tuner/strategies/genetic_algorithm.py

Lines changed: 95 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
_options = dict(
1212
popsize=("population size", 20),
1313
maxiter=("maximum number of generations", 100),
14+
constraint_aware=("constraint-aware optimization (True/False)", False),
1415
method=("crossover method to use, choose any from single_point, two_point, uniform, disruptive_uniform", "uniform"),
1516
mutation_chance=("chance to mutate is 1 in mutation_chance", 10),
1617
)
@@ -19,13 +20,15 @@
1920
def tune(searchspace: Searchspace, runner, tuning_options):
2021

2122
options = tuning_options.strategy_options
22-
pop_size, generations, method, mutation_chance = common.get_options(options, _options)
23+
pop_size, generations, constraint_aware, method, mutation_chance = common.get_options(options, _options)
2324
crossover = supported_methods[method]
2425

26+
GA = GeneticAlgorithm(pop_size, searchspace, constraint_aware, method, mutation_chance)
27+
2528
best_score = 1e20
2629
cost_func = CostFunc(searchspace, tuning_options, runner)
2730

28-
population = list(list(p) for p in searchspace.get_random_sample(pop_size))
31+
population = GA.generate_population()
2932

3033
for generation in range(generations):
3134

@@ -51,18 +54,19 @@ def tune(searchspace: Searchspace, runner, tuning_options):
5154
if tuning_options.verbose:
5255
print("Generation %d, best_score %f" % (generation, best_score))
5356

57+
# build new population for next generation
5458
population = []
5559

5660
# crossover and mutate
5761
while len(population) < pop_size:
58-
dna1, dna2 = weighted_choice(weighted_population, 2)
62+
dna1, dna2 = GA.weighted_choice(weighted_population, 2)
5963

60-
children = crossover(dna1, dna2)
64+
children = GA.crossover(dna1, dna2)
6165

6266
for child in children:
63-
child = mutate(child, mutation_chance, searchspace)
67+
child = GA.mutate(child)
6468

65-
if child not in population and searchspace.is_param_config_valid(tuple(child)):
69+
if child not in population:
6670
population.append(child)
6771

6872
if len(population) >= pop_size:
@@ -75,57 +79,94 @@ def tune(searchspace: Searchspace, runner, tuning_options):
7579

7680
tune.__doc__ = common.get_strategy_docstring("Genetic Algorithm", _options)
7781

78-
79-
def weighted_choice(population, n):
80-
"""Randomly select n unique individuals from a weighted population, fitness determines probability of being selected."""
81-
82-
def random_index_betavariate(pop_size):
83-
# has a higher probability of returning index of item at the head of the list
84-
alpha = 1
85-
beta = 2.5
86-
return int(random.betavariate(alpha, beta) * pop_size)
87-
88-
def random_index_weighted(pop_size):
89-
"""Use weights to increase probability of selection."""
90-
weights = [w for _, w in population]
91-
# invert because lower is better
92-
inverted_weights = [1.0 / w for w in weights]
93-
prefix_sum = np.cumsum(inverted_weights)
94-
total_weight = sum(inverted_weights)
95-
randf = random.random() * total_weight
96-
# return first index of prefix_sum larger than random number
97-
return next(i for i, v in enumerate(prefix_sum) if v > randf)
98-
99-
random_index = random_index_betavariate
100-
101-
indices = [random_index(len(population)) for _ in range(n)]
102-
chosen = []
103-
for ind in indices:
104-
while ind in chosen:
105-
ind = random_index(len(population))
106-
chosen.append(ind)
107-
108-
return [population[ind][0] for ind in chosen]
109-
110-
111-
def mutate(dna, mutation_chance, searchspace: Searchspace, cache=True):
112-
"""Mutate DNA with 1/mutation_chance chance."""
113-
# this is actually a neighbors problem with Hamming distance, choose randomly from returned searchspace list
114-
if int(random.random() * mutation_chance) == 0:
115-
if cache:
116-
neighbors = searchspace.get_neighbors(tuple(dna), neighbor_method="Hamming")
117-
else:
118-
neighbors = searchspace.get_neighbors_no_cache(tuple(dna), neighbor_method="Hamming")
119-
if len(neighbors) > 0:
120-
return list(random.choice(neighbors))
121-
return dna
82+
class GeneticAlgorithm:
83+
84+
def __init__(self, pop_size, searchspace, constraint_aware=False, method="uniform", mutation_chance=10):
85+
self.pop_size = pop_size
86+
self.searchspace = searchspace
87+
self.constraint_aware = constraint_aware
88+
self.crossover_method = supported_methods[method]
89+
self.mutation_chance = mutation_chance
90+
91+
def generate_population(self):
92+
""" Constraint-aware population creation method """
93+
return list(list(p) for p in self.searchspace.get_random_sample(self.pop_size))
94+
95+
def crossover(self, dna1, dna2):
96+
""" Apply selected crossover method, repair dna if constraint-aware """
97+
dna1, dna2 = self.crossover_method(dna1, dna2)
98+
if self.constraint_aware:
99+
return self.repair(dna1), self.repair(dna2)
100+
return dna1, dna2
101+
102+
def weighted_choice(self, population, n):
103+
"""Randomly select n unique individuals from a weighted population, fitness determines probability of being selected."""
104+
105+
def random_index_betavariate(pop_size):
106+
# has a higher probability of returning index of item at the head of the list
107+
alpha = 1
108+
beta = 2.5
109+
return int(random.betavariate(alpha, beta) * pop_size)
110+
111+
def random_index_weighted(pop_size):
112+
"""Use weights to increase probability of selection."""
113+
weights = [w for _, w in population]
114+
# invert because lower is better
115+
inverted_weights = [1.0 / w for w in weights]
116+
prefix_sum = np.cumsum(inverted_weights)
117+
total_weight = sum(inverted_weights)
118+
randf = random.random() * total_weight
119+
# return first index of prefix_sum larger than random number
120+
return next(i for i, v in enumerate(prefix_sum) if v > randf)
121+
122+
random_index = random_index_betavariate
123+
124+
indices = [random_index(len(population)) for _ in range(n)]
125+
chosen = []
126+
for ind in indices:
127+
while ind in chosen:
128+
ind = random_index(len(population))
129+
chosen.append(ind)
130+
131+
return [population[ind][0] for ind in chosen]
132+
133+
134+
def mutate(self, dna, cache=False):
135+
"""Mutate DNA with 1/mutation_chance chance."""
136+
# this is actually a neighbors problem with Hamming distance, choose randomly from returned searchspace list
137+
if int(random.random() * self.mutation_chance) == 0:
138+
if cache:
139+
neighbors = self.searchspace.get_neighbors(tuple(dna), neighbor_method="Hamming")
140+
else:
141+
neighbors = self.searchspace.get_neighbors_no_cache(tuple(dna), neighbor_method="Hamming")
142+
if len(neighbors) > 0:
143+
return list(random.choice(neighbors))
144+
return dna
145+
146+
147+
def repair(self, dna):
148+
""" It is possible that crossover methods yield a configuration that is not valid. """
149+
if not self.searchspace.is_param_config_valid(tuple(dna)):
150+
# dna is not valid, try to repair it
151+
# search for valid configurations neighboring this config
152+
# start from strictly-adjacent to increasingly allowing more neighbors
153+
for neighbor_method in ["strictly-adjacent", "adjacent", "Hamming"]:
154+
neighbors = self.searchspace.get_neighbors_no_cache(tuple(dna), neighbor_method=neighbor_method)
155+
156+
# if we have found valid neighboring configurations, select one at random
157+
if len(neighbors) > 0:
158+
new_dna = list(random.choice(neighbors))
159+
print(f"GA crossover resulted in invalid config {dna=}, repaired dna to {new_dna=}")
160+
return new_dna
161+
162+
return dna
122163

123164

124165
def single_point_crossover(dna1, dna2):
125166
"""Crossover dna1 and dna2 at a random index."""
126167
# check if you can do the crossovers using the neighbor index: check which valid parameter configuration is closest to the crossover, probably best to use "adjacent" as it is least strict?
127168
pos = int(random.random() * (len(dna1)))
128-
return (dna1[:pos] + dna2[pos:], dna2[:pos] + dna1[pos:])
169+
return dna1[:pos] + dna2[pos:], dna2[:pos] + dna1[pos:]
129170

130171

131172
def two_point_crossover(dna1, dna2):
@@ -137,7 +178,7 @@ def two_point_crossover(dna1, dna2):
137178
pos1, pos2 = sorted(random.sample(list(range(start, end)), 2))
138179
child1 = dna1[:pos1] + dna2[pos1:pos2] + dna1[pos2:]
139180
child2 = dna2[:pos1] + dna1[pos1:pos2] + dna2[pos2:]
140-
return (child1, child2)
181+
return child1, child2
141182

142183

143184
def uniform_crossover(dna1, dna2):
@@ -168,7 +209,7 @@ def disruptive_uniform_crossover(dna1, dna2):
168209
child1[ind] = dna2[ind]
169210
child2[ind] = dna1[ind]
170211
swaps += 1
171-
return (child1, child2)
212+
return child1, child2
172213

173214

174215
supported_methods = {
@@ -177,3 +218,4 @@ def disruptive_uniform_crossover(dna1, dna2):
177218
"uniform": uniform_crossover,
178219
"disruptive_uniform": disruptive_uniform_crossover,
179220
}
221+

kernel_tuner/strategies/greedy_ils.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
"""A simple greedy iterative local search algorithm for parameter search."""
2+
import random
23
from kernel_tuner import util
34
from kernel_tuner.searchspace import Searchspace
45
from kernel_tuner.strategies import common
56
from kernel_tuner.strategies.common import CostFunc
6-
from kernel_tuner.strategies.genetic_algorithm import mutate
77
from kernel_tuner.strategies.hillclimbers import base_hillclimb
88

99
_options = dict(neighbor=("Method for selecting neighboring nodes, choose from Hamming or adjacent", "Hamming"),
@@ -58,9 +58,14 @@ def tune(searchspace: Searchspace, runner, tuning_options):
5858

5959
tune.__doc__ = common.get_strategy_docstring("Greedy Iterative Local Search (ILS)", _options)
6060

61+
def mutate(indiv, searchspace: Searchspace):
62+
neighbors = searchspace.get_neighbors_no_cache(tuple(indiv), neighbor_method="Hamming")
63+
return list(random.choice(neighbors))
64+
65+
6166
def random_walk(indiv, permutation_size, no_improve, last_improve, searchspace: Searchspace):
6267
if last_improve >= no_improve:
6368
return searchspace.get_random_sample(1)[0]
6469
for _ in range(permutation_size):
65-
indiv = mutate(indiv, 0, searchspace, cache=False)
70+
indiv = mutate(indiv, searchspace)
6671
return indiv

test/strategies/test_genetic_algorithm.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,12 @@ def test_weighted_choice():
1414
pop = searchspace.get_random_sample(pop_size)
1515
weighted_pop = [[p, i] for i, p in enumerate(pop)]
1616

17-
result = ga.weighted_choice(weighted_pop, 1)
17+
GA = ga.GeneticAlgorithm(pop_size, searchspace)
18+
19+
result = GA.weighted_choice(weighted_pop, 1)
1820
assert result[0] in pop
1921

20-
result = ga.weighted_choice(weighted_pop, 2)
22+
result = GA.weighted_choice(weighted_pop, 2)
2123
print(result)
2224
assert result[0] in pop
2325
assert result[1] in pop
@@ -43,7 +45,9 @@ def test_random_population():
4345
def test_mutate():
4446
pop = searchspace.get_random_sample(1)
4547

46-
mutant = ga.mutate(pop[0], 10, searchspace)
48+
GA = ga.GeneticAlgorithm(1, searchspace)
49+
50+
mutant = GA.mutate(pop[0])
4751
assert len(pop[0]) == len(mutant)
4852
assert mutant[0] in tune_params["x"]
4953
assert mutant[1] in tune_params["y"]

test/test_runners.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,22 @@ def test_diff_evo(env):
140140
assert len(result) > 0
141141

142142

143+
def test_constraint_aware_GA(env):
144+
options = dict(method="uniform",
145+
constraint_aware=True,
146+
popsize=5,
147+
maxiter=2,
148+
mutation_chance=10,
149+
max_fevals=10)
150+
result, _ = tune_kernel(*env,
151+
strategy="genetic_algorithm",
152+
strategy_options=options,
153+
verbose=True,
154+
cache=cache_filename,
155+
simulation_mode=True)
156+
assert len(result) > 0
157+
158+
143159
@skip_if_no_pycuda
144160
def test_time_keeping(env):
145161
kernel_name, kernel_string, size, args, tune_params = env

0 commit comments

Comments
 (0)