Skip to content

Commit 451e9bd

Browse files
committed
Add a variety of progress bars to the fingerprint.
Most are straight forward, for the creation of the points I had to refactor things slightly (as it's a product of two lists there wasn't a straightforward way of using tqdm on it). Closes #768.
1 parent e454a80 commit 451e9bd

File tree

2 files changed

+47
-16
lines changed

2 files changed

+47
-16
lines changed

axelrod/fingerprint.py

Lines changed: 41 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import axelrod as axl
22
import numpy as np
33
import matplotlib.pyplot as plt
4+
import tqdm
45
from axelrod.strategy_transformers import JossAnnTransformer, DualTransformer
56
from axelrod.interaction_utils import compute_final_score_per_turn, read_interactions_from_file
67
from axelrod import on_windows
@@ -11,7 +12,7 @@
1112
Point = namedtuple('Point', 'x y')
1213

1314

14-
def create_points(step):
15+
def create_points(step, progress_bar=True):
1516
"""Creates a set of Points over the unit square.
1617
1718
A Point has coordinates (x, y). This function constructs points that are
@@ -23,15 +24,29 @@ def create_points(step):
2324
step : float
2425
The separation between each Point. Smaller steps will produce more
2526
Points with coordinates that will be closer together.
27+
progress_bar : bool
28+
Whether or not to create a progress bar which will be updated
2629
2730
Returns
2831
----------
2932
points : list
3033
of Point objects with coordinates (x, y)
3134
"""
3235
num = int((1 / step) // 1) + 1
33-
points = [Point(j, k) for j in np.linspace(0, 1, num)
34-
for k in np.linspace(0, 1, num)]
36+
37+
if progress_bar:
38+
p_bar = tqdm.tqdm(total=num ** 2, desc="Generating points")
39+
40+
points = []
41+
for x in np.linspace(0, 1, num):
42+
for y in np.linspace(0, 1, num):
43+
points.append(Point(x, y))
44+
45+
if progress_bar:
46+
p_bar.update()
47+
48+
if progress_bar:
49+
p_bar.close()
3550

3651
return points
3752

@@ -76,7 +91,7 @@ def create_jossann(point, probe):
7691
return joss_ann
7792

7893
@staticmethod
79-
def create_edges(points):
94+
def create_edges(points, progress_bar=True):
8095
"""Creates a set of edges for a spatial tournament.
8196
8297
Constructs edges that correspond to `points`. All edges begin at 0, and
@@ -86,6 +101,9 @@ def create_edges(points):
86101
----------
87102
points : list
88103
of Point objects with coordinates (x, y)
104+
progress_bar : bool
105+
Whether or not to create a progress bar which will be updated
106+
89107
90108
Returns
91109
----------
@@ -95,10 +113,12 @@ def create_edges(points):
95113
corresponding probe (+1 to allow for including the Strategy and its
96114
Dual).
97115
"""
116+
if progress_bar:
117+
points = tqdm.tqdm(points, desc="Generating network edges")
98118
edges = [(0, index + 1) for index, point in enumerate(points)]
99119
return edges
100120

101-
def create_probes(self, probe, points):
121+
def create_probes(self, probe, points, progress_bar=True):
102122
"""Creates a set of probe strategies over the unit square.
103123
104124
Constructs probe strategies that correspond to points with coordinates
@@ -110,24 +130,31 @@ def create_probes(self, probe, points):
110130
A class that must be descended from axelrod.strategies.
111131
points : list
112132
of Point objects with coordinates (x, y)
133+
progress_bar : bool
134+
Whether or not to create a progress bar which will be updated
113135
114136
Returns
115137
----------
116138
probes : list
117139
A list of `JossAnnTransformer` players with parameters that
118140
correspond to point.
119141
"""
142+
if progress_bar:
143+
points = tqdm.tqdm(points, desc="Generating probes")
120144
probes = [self.create_jossann(point, probe) for point in points]
121145
return probes
122146

123-
def construct_tournament_elements(self, step):
147+
def construct_tournament_elements(self, step, progress_bar=True):
124148
"""Build the elements required for a spatial tournament
125149
126150
Parameters
127151
----------
128152
step : float
129153
The separation between each Point. Smaller steps will
130154
produce more Points that will be closer together.
155+
progress_bar : bool
156+
Whether or not to create a progress bar which will be updated
157+
131158
132159
Returns
133160
----------
@@ -142,11 +169,11 @@ def construct_tournament_elements(self, step):
142169
original player, the second is the dual, the rest are the probes.
143170
144171
"""
145-
probe_points = create_points(step)
146-
self.points = probe_points
147-
edges = self.create_edges(probe_points)
172+
self.points = create_points(step, progress_bar=progress_bar)
173+
edges = self.create_edges(self.points, progress_bar=progress_bar)
174+
probe_players = self.create_probes(self.probe, self.points,
175+
progress_bar=progress_bar)
148176

149-
probe_players = self.create_probes(self.probe, probe_points)
150177
tournament_players = [self.strategy()] + probe_players
151178

152179
return edges, tournament_players
@@ -221,7 +248,8 @@ def fingerprint(self, turns=50, repetitions=10, step=0.01, processes=None,
221248
outputfile = NamedTemporaryFile(mode='w')
222249
filename = outputfile.name
223250

224-
edges, tourn_players = self.construct_tournament_elements(step)
251+
edges, tourn_players = self.construct_tournament_elements(step,
252+
progress_bar=progress_bar)
225253
self.step = step
226254
self.spatial_tournament = axl.SpatialTournament(tourn_players,
227255
turns=turns,
@@ -235,7 +263,8 @@ def fingerprint(self, turns=50, repetitions=10, step=0.01, processes=None,
235263
if in_memory:
236264
self.interactions = self.spatial_tournament.interactions_dict
237265
else:
238-
self.interactions = read_interactions_from_file(filename)
266+
self.interactions = read_interactions_from_file(filename,
267+
progress_bar=progress_bar)
239268

240269
self.data = self.generate_data(self.interactions, self.points, edges)
241270
return self.data

axelrod/tests/unit/test_fingerprint.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -33,22 +33,24 @@ def test_init(self):
3333
self.assertEqual(fingerprint.probe, probe)
3434

3535
def test_create_points(self):
36-
test_points = create_points(0.5)
36+
test_points = create_points(0.5, progress_bar=False)
3737
self.assertEqual(test_points, self.expected_points)
3838

3939
def test_create_probes(self):
4040
af = AshlockFingerprint(self.strategy, self.probe)
41-
probes = af.create_probes(self.probe, self.expected_points)
41+
probes = af.create_probes(self.probe, self.expected_points,
42+
progress_bar=False)
4243
self.assertEqual(len(probes), 9)
4344

4445
def test_create_edges(self):
4546
af = AshlockFingerprint(self.strategy, self.probe)
46-
edges = af.create_edges(self.expected_points)
47+
edges = af.create_edges(self.expected_points, progress_bar=False)
4748
self.assertEqual(edges, self.expected_edges)
4849

4950
def test_construct_tournament_elemets(self):
5051
af = AshlockFingerprint(self.strategy, self.probe)
51-
edges, tournament_players = af.construct_tournament_elements(0.5)
52+
edges, tournament_players = af.construct_tournament_elements(0.5,
53+
progress_bar=False)
5254
self.assertEqual(edges, self.expected_edges)
5355
self.assertEqual(len(tournament_players), 10)
5456
self.assertEqual(tournament_players[0].__class__, af.strategy)

0 commit comments

Comments
 (0)