Skip to content

Commit 712c3ac

Browse files
committed
Make Population take instances.
Ensures init kwargs get passed through.
1 parent e84fb30 commit 712c3ac

File tree

3 files changed

+37
-6
lines changed

3 files changed

+37
-6
lines changed

src/axelrod_dojo/utils.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -138,16 +138,21 @@ def crossover(self, other):
138138
pass
139139

140140

141-
def score_params(params, objective, opponents, weights=None):
141+
def score_params(params, objective, opponents,
142+
opponent_init_kwargs=None,
143+
weights=None):
142144
"""
143145
Return the overall mean score of a Params instance.
144146
"""
145147
scores_for_all_opponents = []
146148
player = params.player()
147149

148-
for opponent_class in opponents:
150+
if opponent_init_kwargs is None:
151+
opponent_init_kwargs = [{} for _ in opponents]
152+
153+
for opponent_class, init_kwargs in zip(opponents, opponent_init_kwargs):
149154
player.reset()
150-
opponent = opponent_class()
155+
opponent = opponent_class(**init_kwargs)
151156
scores_for_this_opponent = objective(player, opponent)
152157
mean_vs_opponent = mean(scores_for_this_opponent)
153158
scores_for_all_opponents.append(mean_vs_opponent)
@@ -173,8 +178,10 @@ def __init__(self, params_class, params_args, size, objective, output_filename,
173178
self.bottleneck = bottleneck
174179
if opponents is None:
175180
self.opponents = axl.short_run_time_strategies
181+
self.opponent_init_kwargs = None
176182
else:
177-
self.opponents = opponents
183+
self.opponents = [p.__class__ for p in opponents]
184+
self.opponent_init_kwargs = [p.init_kwargs for p in opponents]
178185
self.generation = 0
179186
self.params_args = params_args
180187
self.population = [params_class(*params_args) for _ in range(self.size)]
@@ -185,6 +192,7 @@ def score_all(self):
185192
self.population,
186193
repeat(self.objective),
187194
repeat(self.opponents),
195+
repeat(self.opponent_init_kwargs),
188196
repeat(self.weights))
189197
results = self.pool.starmap(score_params, starmap_params)
190198
return results

tests/integration/test_fsm.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ def test_score(self):
1717
repetitions = 5
1818
num_states = 2
1919
mutation_rate = .1
20-
opponents = axl.demo_strategies
20+
opponents = [s() for s in axl.demo_strategies]
2121
size = 10
2222

2323
objective = dojo.prepare_objective(name=name,
@@ -68,7 +68,7 @@ def test_score_with_weights(self):
6868
repetitions = 5
6969
num_states = 2
7070
mutation_rate = .1
71-
opponents = axl.demo_strategies
71+
opponents = [s() for s in axl.demo_strategies]
7272
size = 10
7373

7474
objective = dojo.prepare_objective(name=name,

tests/unit/test_utils.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -193,6 +193,29 @@ def test_score(self):
193193
expected_score = 2.0949
194194
self.assertEqual(score, expected_score)
195195

196+
def test_with_init_kwargs(self):
197+
axl.seed(0)
198+
opponents = [axl.Random]
199+
opponent_init_kwargs = [{"p": 0}] # Creating a defector
200+
objective = utils.prepare_objective()
201+
params = DummyParams()
202+
score = utils.score_params(params,
203+
objective=objective,
204+
opponent_init_kwargs=opponent_init_kwargs,
205+
opponents=opponents)
206+
expected_score = 0
207+
self.assertEqual(score, expected_score)
208+
209+
opponent_init_kwargs = [{"p": 1}] # Creating a cooperator
210+
objective = utils.prepare_objective()
211+
params = DummyParams()
212+
score = utils.score_params(params,
213+
objective=objective,
214+
opponent_init_kwargs=opponent_init_kwargs,
215+
opponents=opponents)
216+
expected_score = 3.0
217+
self.assertEqual(score, expected_score)
218+
196219
def test_score_with_weights(self):
197220
axl.seed(0)
198221
opponents = axl.demo_strategies

0 commit comments

Comments
 (0)