Skip to content

Commit 46b1a14

Browse files
authored
Merge pull request #34 from Axelrod-Python/add-ability-to-sample-opponents
Add ability to sample a number of opps
2 parents d5de5aa + 27f0aa7 commit 46b1a14

File tree

2 files changed

+118
-3
lines changed

2 files changed

+118
-3
lines changed

src/axelrod_dojo/utils.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -142,13 +142,19 @@ def crossover(self, other):
142142

143143
def score_params(params, objective,
144144
opponents_information,
145-
weights=None):
145+
weights=None, sample_count=None):
146146
"""
147147
Return the overall mean score of a Params instance.
148148
"""
149149
scores_for_all_opponents = []
150150
player = params.player()
151151

152+
if sample_count is not None:
153+
indices = np.random.choice(len(opponents_information), sample_count)
154+
opponents_information = [opponents_information[i] for i in indices]
155+
if weights is not None:
156+
weights = [weights[i] for i in indices]
157+
152158
for strategy, init_kwargs in opponents_information:
153159
player.reset()
154160
opponent = strategy(**init_kwargs)
@@ -164,7 +170,8 @@ def score_params(params, objective,
164170
class Population(object):
165171
"""Population class that implements the evolutionary algorithm."""
166172
def __init__(self, params_class, params_args, size, objective, output_filename,
167-
bottleneck=None, opponents=None, processes=1, weights=None):
173+
bottleneck=None, opponents=None, processes=1, weights=None,
174+
sample_count=None):
168175
self.params_class = params_class
169176
self.bottleneck = bottleneck
170177
if processes == 0:
@@ -187,13 +194,15 @@ def __init__(self, params_class, params_args, size, objective, output_filename,
187194
self.params_args = params_args
188195
self.population = [params_class(*params_args) for _ in range(self.size)]
189196
self.weights = weights
197+
self.sample_count = sample_count
190198

191199
def score_all(self):
192200
starmap_params = zip(
193201
self.population,
194202
repeat(self.objective),
195203
repeat(self.opponents_information),
196-
repeat(self.weights))
204+
repeat(self.weights),
205+
repeat(self.sample_count))
197206
results = self.pool.starmap(score_params, starmap_params)
198207
return results
199208

tests/integration/test_fsm.py

Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,112 @@ def test_score_with_weights(self):
113113

114114
self.assertEqual(best[0].__repr__(), best_params)
115115

116+
def test_score_with_sample_count(self):
117+
name = "score"
118+
turns = 10
119+
noise = 0
120+
repetitions = 5
121+
num_states = 2
122+
mutation_rate = .1
123+
opponents = [s() for s in axl.demo_strategies]
124+
size = 10
125+
126+
objective = dojo.prepare_objective(name=name,
127+
turns=turns,
128+
noise=noise,
129+
repetitions=repetitions)
130+
131+
population = dojo.Population(params_class=dojo.FSMParams,
132+
params_args=(num_states, mutation_rate),
133+
size=size,
134+
objective=objective,
135+
output_filename=self.temporary_file.name,
136+
opponents=opponents,
137+
sample_count=2, # Randomly sample 2 opponents at each step
138+
bottleneck=2,
139+
processes=1)
140+
141+
generations = 4
142+
axl.seed(0)
143+
population.run(generations)
144+
self.assertEqual(population.generation, 4)
145+
146+
# Manually read from tempo file to find best strategy
147+
best_score, best_params = 0, None
148+
with open(self.temporary_file.name, "r") as f:
149+
reader = csv.reader(f)
150+
for row in reader:
151+
_, mean_score, sd_score, max_score, arg_max = row
152+
if float(max_score) > best_score:
153+
best_score = float(max_score)
154+
best_params = arg_max
155+
156+
# Test the load params function
157+
for num in range(1, 4 + 1):
158+
best = dojo.load_params(params_class=dojo.FSMParams,
159+
filename=self.temporary_file.name,
160+
num=num)
161+
self.assertEqual(len(best), num)
162+
163+
for parameters in best:
164+
self.assertIsInstance(parameters, dojo.FSMParams)
165+
166+
self.assertEqual(best[0].__repr__(), best_params)
167+
168+
169+
def test_score_with_sample_count_and_weights(self):
170+
name = "score"
171+
turns = 10
172+
noise = 0
173+
repetitions = 5
174+
num_states = 2
175+
mutation_rate = .1
176+
opponents = [s() for s in axl.demo_strategies]
177+
size = 10
178+
179+
objective = dojo.prepare_objective(name=name,
180+
turns=turns,
181+
noise=noise,
182+
repetitions=repetitions)
183+
184+
population = dojo.Population(params_class=dojo.FSMParams,
185+
params_args=(num_states, mutation_rate),
186+
size=size,
187+
objective=objective,
188+
output_filename=self.temporary_file.name,
189+
opponents=opponents,
190+
sample_count=2, # Randomly sample 2 opponents at each step
191+
weights=[5, 1, 1, 1, 1],
192+
bottleneck=2,
193+
processes=1)
194+
195+
generations = 4
196+
axl.seed(0)
197+
population.run(generations)
198+
self.assertEqual(population.generation, 4)
199+
200+
# Manually read from tempo file to find best strategy
201+
best_score, best_params = 0, None
202+
with open(self.temporary_file.name, "r") as f:
203+
reader = csv.reader(f)
204+
for row in reader:
205+
_, mean_score, sd_score, max_score, arg_max = row
206+
if float(max_score) > best_score:
207+
best_score = float(max_score)
208+
best_params = arg_max
209+
210+
# Test the load params function
211+
for num in range(1, 4 + 1):
212+
best = dojo.load_params(params_class=dojo.FSMParams,
213+
filename=self.temporary_file.name,
214+
num=num)
215+
self.assertEqual(len(best), num)
216+
217+
for parameters in best:
218+
self.assertIsInstance(parameters, dojo.FSMParams)
219+
220+
self.assertEqual(best[0].__repr__(), best_params)
221+
116222
def test_score_with_particular_players(self):
117223
"""
118224
These are players that are known to be difficult to pickle

0 commit comments

Comments
 (0)