Skip to content

Commit 6f48952

Browse files
Update with doctests
1 parent 8e91a5c commit 6f48952

File tree

1 file changed

+194
-60
lines changed

1 file changed

+194
-60
lines changed

genetic_algorithm/knapsack.py

Lines changed: 194 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -1,49 +1,76 @@
1-
"""Did you know that Genetic Algorithms can be used to quickly approximate combinatorial optimization problems such as knapsack?"""
1+
"""Did you know that Genetic Algorithms can be used to quickly approximate
2+
combinatorial optimization problems such as knapsack?
3+
4+
Run doctests:
5+
python -m doctest -v ga_knapsack.py
6+
"""
27

38
import random
49
from dataclasses import dataclass
510

11+
# Keep module-level RNG deterministic for examples that rely on random,
12+
# but individual doctests re-seed locally as needed.
613
random.seed(42)
714

815
# =========================== Problem setup: Knapsack ===========================
916

10-
KNAPSACK_N_ITEMS = 42 # Number of items in the knapsack problem
11-
KNAPSACK_VALUE_RANGE = (10, 100) # Range of item values
12-
KNAPSACK_WEIGHT_RANGE = (5, 50) # Range of item weights
13-
KNAPSACK_CAPACITY_RATIO = 0.5 # Capacity as a fraction of total weight
14-
17+
KNAPSACK_N_ITEMS: int = 42 # Number of items in the knapsack problem
18+
KNAPSACK_VALUE_RANGE: tuple[int, int] = (10, 100) # Range of item values
19+
KNAPSACK_WEIGHT_RANGE: tuple[int, int] = (5, 50) # Range of item weights
20+
KNAPSACK_CAPACITY_RATIO: float = 0.5 # Capacity as a fraction of total weight
1521

1622
@dataclass
1723
class Item:
1824
value: int
1925
weight: int
2026

21-
2227
def generate_knapsack_instance(
2328
n_items: int,
2429
value_range: tuple[int, int],
2530
weight_range: tuple[int, int],
26-
capacity_ratio=float,
31+
capacity_ratio: float
2732
) -> tuple[list[Item], int]:
28-
"""Generates a random knapsack problem instance."""
33+
"""
34+
Generates a random knapsack problem instance.
35+
36+
Returns a tuple: (items, capacity), where items is a list of Item(value, weight)
37+
and capacity is an int computed as floor(capacity_ratio * total_weight).
38+
39+
Examples
40+
--------
41+
Use a tiny, deterministic instance to validate shape and capacity range:
42+
43+
>>> random.seed(0)
44+
>>> items, cap = generate_knapsack_instance(
45+
... n_items=3,
46+
... value_range=(5, 5),
47+
... weight_range=(10, 10),
48+
... capacity_ratio=0.5
49+
... )
50+
>>> len(items), cap
51+
(3, 15)
52+
>>> all(isinstance(it, Item) for it in items)
53+
True
54+
>>> [it.value for it in items], [it.weight for it in items]
55+
([5, 5, 5], [10, 10, 10])
56+
"""
2957
items = []
3058
for _ in range(n_items):
3159
value = random.randint(*value_range)
3260
weight = random.randint(*weight_range)
3361
items.append(Item(value=value, weight=weight))
34-
# We set capacity as a fraction of total weight
62+
# Capacity as a fraction of total weight
3563
capacity = int(sum(it.weight for it in items) * capacity_ratio)
3664
return items, capacity
3765

38-
66+
# Example instance (guarded by __main__ below for printing)
3967
items, capacity = generate_knapsack_instance(
4068
n_items=KNAPSACK_N_ITEMS,
4169
value_range=KNAPSACK_VALUE_RANGE,
4270
weight_range=KNAPSACK_WEIGHT_RANGE,
43-
capacity_ratio=KNAPSACK_CAPACITY_RATIO,
71+
capacity_ratio=KNAPSACK_CAPACITY_RATIO
4472
)
4573

46-
4774
# ============================== GA Representation ==============================
4875

4976
# HYPERPARAMETERS (For tuning the GA)
@@ -59,64 +86,175 @@ def generate_knapsack_instance(
5986

6087
Genome = list[int] # An index list where 1 means item is included, 0 means excluded
6188

62-
6389
def evaluate(genome: Genome, items: list[Item], capacity: int) -> tuple[int, int]:
64-
"""Evaluation function - calculates the fitness of each candidate based on total value and weight."""
90+
"""
91+
Calculates fitness (value) and weight of a candidate solution. If overweight,
92+
the returned value is penalized; weight is the actual summed weight.
93+
94+
Returns (value, weight).
95+
96+
Examples
97+
--------
98+
Feasible genome (no penalty):
99+
100+
>>> it = [Item(10, 4), Item(7, 3), Item(5, 2)]
101+
>>> genome = [1, 0, 1] # take items 0 and 2
102+
>>> evaluate(genome, it, capacity=7)
103+
(15, 6)
104+
105+
Overweight genome (penalty applies):
106+
Total value = 10+7+5 = 22, total weight = 9, capacity = 7, overflow = 2
107+
Penalized value = max(0, 22 - 2 * OVERWEIGHT_PENALTY_FACTOR) = 2
108+
109+
>>> genome = [1, 1, 1]
110+
>>> evaluate(genome, it, capacity=7)
111+
(2, 9)
112+
"""
65113
total_value = 0
66114
total_weight = 0
67115
for gene, item in zip(genome, items):
68116
if gene:
69117
total_value += item.value
70118
total_weight += item.weight
71119
if total_weight > capacity:
72-
# Penalize overweight solutions: return small value scaled by overflow
73-
overflow = total_weight - capacity
120+
overflow = (total_weight - capacity)
74121
total_value = max(0, total_value - overflow * OVERWEIGHT_PENALTY_FACTOR)
75122
return total_value, total_weight
76123

77124

78125
def random_genome(n: int) -> Genome:
79-
"""Generates a random genome of length n."""
80-
return [random.randint(0, 1) for _ in range(n)]
126+
"""
127+
Generates a random genome (list of 0/1) of length n.
81128
129+
Examples
130+
--------
131+
Check length and content are 0/1 bits:
132+
133+
>>> random.seed(123)
134+
>>> g = random_genome(5)
135+
>>> len(g), set(g).issubset({0, 1})
136+
(5, True)
137+
"""
138+
return [random.randint(0, 1) for _ in range(n)]
82139

83140
def selection(population: list[Genome], fitnesses: list[int], k: int) -> Genome:
84-
"""Performs tournament selection to choose genomes from the population.
141+
"""
142+
Performs tournament selection to choose a genome from the population.
143+
85144
Note that other selection strategies exist such as roulette wheel, rank-based, etc.
145+
146+
Examples
147+
--------
148+
Deterministic tournament with fixed seed (k=2):
149+
150+
>>> random.seed(1)
151+
>>> pop = [[0,0,0], [1,0,0], [1,1,0], [1,1,1]]
152+
>>> fits = [0, 5, 9, 7]
153+
>>> parent = selection(pop, fits, k=2)
154+
>>> parent in pop
155+
True
86156
"""
87157
contenders = random.sample(list(zip(population, fitnesses)), k)
88158
get_fitness = lambda x: x[1]
89159
return max(contenders, key=get_fitness)[0][:]
90160

91161

92162
def crossover(a: Genome, b: Genome, p_crossover: float) -> tuple[Genome, Genome]:
93-
"""Performs single-point crossover between two genomes.
94-
Note that other crossover strategies exist such as two-point crossover, uniform crossover, etc."""
163+
"""
164+
Performs single-point crossover between two genomes.
165+
If crossover does not occur (random > p_crossover) or genomes are too short,
166+
returns copies of the parents.
167+
168+
Note: other crossover strategies exist (two-point, uniform, etc.).
169+
170+
Examples
171+
--------
172+
Force crossover with p=1.0 and fixed RNG; verify lengths and bit content:
173+
174+
>>> random.seed(2)
175+
>>> a, b = [0,0,0,0], [1,1,1,1]
176+
>>> c1, c2 = crossover(a, b, p_crossover=1.0)
177+
>>> len(c1) == len(a) == len(c2) == len(b)
178+
True
179+
>>> set(c1).issubset({0,1}) and set(c2).issubset({0,1})
180+
True
181+
182+
No crossover if p=0.0:
183+
184+
>>> c1, c2 = crossover([0,0,0], [1,1,1], p_crossover=0.0)
185+
>>> c1, c2
186+
([0, 0, 0], [1, 1, 1])
187+
"""
95188
min_length = min(len(a), len(b))
96189
if random.random() > p_crossover or min_length < 2:
97190
return a[:], b[:]
98191
cutoff_point = random.randint(1, min_length - 1)
99192
return a[:cutoff_point] + b[cutoff_point:], b[:cutoff_point] + a[cutoff_point:]
100193

194+
def mutation(g: Genome, p_mutation: float) -> Genome:
195+
"""
196+
Performs bit-flip mutation on a genome. Each bit flips with probability p_mutation.
197+
198+
Note: other mutation strategies exist (swap, scramble, etc.).
199+
200+
Examples
201+
--------
202+
With probability 1.0, every bit flips:
101203
102-
def mutation(g: Genome, p_mutation: int) -> Genome:
103-
"""Performs bit-flip mutation on a genome.
104-
Note that other mutation strategies exist such as swap mutation, scramble mutation, etc.
204+
>>> mutation([0, 1, 1, 0], p_mutation=1.0)
205+
[1, 0, 0, 1]
206+
207+
With probability 0.0, nothing changes:
208+
209+
>>> mutation([0, 1, 1, 0], p_mutation=0.0)
210+
[0, 1, 1, 0]
105211
"""
106212
return [(1 - gene) if random.random() < p_mutation else gene for gene in g]
107213

108214

109215
def run_ga(
110216
items: list[Item],
111217
capacity: int,
112-
pop_size=POPULATION_SIZE,
113-
generations=GENERATIONS,
114-
p_crossover=CROSSOVER_PROBABILITY,
115-
p_mutation=MUTATION_PROBABILITY,
116-
tournament_k=TOURNAMENT_K,
117-
elitism=ELITISM,
218+
pop_size: int = POPULATION_SIZE,
219+
generations: int = GENERATIONS,
220+
p_crossover: float = CROSSOVER_PROBABILITY,
221+
p_mutation: float = MUTATION_PROBABILITY,
222+
tournament_k: int = TOURNAMENT_K,
223+
elitism: int = ELITISM,
118224
):
119-
"""Runs the genetic algorithm to solve the knapsack problem."""
225+
"""
226+
Runs the genetic algorithm to (approximately) solve the knapsack problem.
227+
228+
Returns a dict with keys:
229+
- 'best_genome' (Genome)
230+
- 'best_value' (int)
231+
- 'best_weight' (int)
232+
- 'capacity' (int)
233+
- 'best_history' (list[int])
234+
- 'avg_history' (list[float])
235+
236+
Examples
237+
--------
238+
Use a tiny instance and few generations to validate structure and lengths:
239+
240+
>>> random.seed(1234)
241+
>>> tiny_items = [Item(5,2), Item(6,3), Item(2,1), Item(7,4)]
242+
>>> cap = 5
243+
>>> out = run_ga(
244+
... tiny_items, cap,
245+
... pop_size=10, generations=5,
246+
... p_crossover=0.9, p_mutation=0.05,
247+
... tournament_k=2, elitism=1
248+
... )
249+
>>> sorted(out.keys())
250+
['avg_history', 'best_genome', 'best_history', 'best_value', 'best_weight', 'capacity']
251+
>>> len(out['best_history']) == 5 and len(out['avg_history']) == 5
252+
True
253+
>>> isinstance(out['best_genome'], list) and isinstance(out['best_value'], int)
254+
True
255+
>>> out['capacity'] == cap
256+
True
257+
"""
120258
n = len(items)
121259
population = [random_genome(n) for _ in range(pop_size)]
122260
best_history = [] # track best fitness per generation
@@ -138,10 +276,8 @@ def run_ga(
138276

139277
# Elitism
140278
get_fitness = lambda i: fitnesses[i]
141-
elite_indices = sorted(range(pop_size), key=get_fitness, reverse=True)[
142-
:elitism
143-
] # Sort the population by fitness and get the top `elitism` indices
144-
elites = [population[i][:] for i in elite_indices] # Make nepo babies
279+
elite_indices = sorted(range(pop_size), key=get_fitness, reverse=True)[:elitism]
280+
elites = [population[i][:] for i in elite_indices]
145281

146282
# New generation
147283
new_pop = elites[:]
@@ -165,27 +301,25 @@ def run_ga(
165301
"avg_history": avg_history,
166302
}
167303

168-
169-
result = run_ga(items, capacity)
170-
171-
best_items = [items[i] for i, bit in enumerate(result["best_genome"]) if bit == 1]
172-
173-
print(f"Knapsack capacity: {result['capacity']}")
174-
print(
175-
f"Best solution: value = {result['best_value']}, weight = {result['best_weight']}"
176-
)
177-
178-
# print("Items included in the best solution:", best_items)
179-
180-
# import matplotlib.pyplot as plt
181-
182-
# # Plot fitness curves
183-
# plt.figure()
184-
# plt.plot(result["best_history"], label="Best fitness")
185-
# plt.plot(result["avg_history"], label="Average fitness")
186-
# plt.title("GA on Knapsack: Fitness over Generations")
187-
# plt.xlabel("Generation")
188-
# plt.ylabel("Fitness")
189-
# plt.legend()
190-
# plt.tight_layout()
191-
# plt.show()
304+
# ================================ Script entry =================================
305+
306+
if __name__ == "__main__":
307+
result = run_ga(items, capacity)
308+
best_items = [items[i] for i, bit in enumerate(result["best_genome"]) if bit == 1]
309+
310+
print(f"Knapsack capacity: {result['capacity']}")
311+
print(f"Best solution: value = {result['best_value']}, weight = {result['best_weight']}")
312+
# Uncomment to inspect chosen items:
313+
# print("Items included in the best solution:", best_items)
314+
315+
# Optional: plot fitness curves
316+
# import matplotlib.pyplot as plt
317+
# plt.figure()
318+
# plt.plot(result["best_history"], label="Best fitness")
319+
# plt.plot(result["avg_history"], label="Average fitness")
320+
# plt.title("GA on Knapsack: Fitness over Generations")
321+
# plt.xlabel("Generation")
322+
# plt.ylabel("Fitness")
323+
# plt.legend()
324+
# plt.tight_layout()
325+
# plt.show()

0 commit comments

Comments
 (0)