Skip to content

Commit a12e810

Browse files
committed
Merge branch 'develop'
2 parents 1054529 + ae59308 commit a12e810

File tree

1 file changed

+56
-0
lines changed

1 file changed

+56
-0
lines changed

py-radiate/tests/unit/test_metrics.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,3 +48,59 @@ def test_generation_metrics(self, random_seed):
4848
assert metrics[key].min() is not None
4949
assert metrics[key].max() is not None
5050
assert metrics[key].count() is not None
51+
52+
@pytest.mark.integration
53+
def test_metrics_from_events(self, random_seed):
54+
class ScoreDistributionPlotter(rd.EventHandler):
55+
"""
56+
Subscriber class to handle events and track metrics.
57+
We will use this to plot score distributions over generations then
58+
display the plot when the engine stops.
59+
"""
60+
61+
def __init__(self):
62+
super().__init__(rd.EventType.EPOCH_COMPLETE)
63+
64+
def on_event(self, event: rd.EngineEvent) -> None:
65+
metrics = event.metrics()
66+
for key in metrics.keys():
67+
assert key in metrics
68+
if key == "scores":
69+
assert metrics[key].seq_mean() is not None
70+
assert metrics[key].seq_stddev() is not None
71+
assert metrics[key].seq_variance() is not None
72+
assert metrics[key].seq_kurt() is not None
73+
assert metrics[key].seq_skew() is not None
74+
assert metrics[key].seq_min() is not None
75+
assert metrics[key].seq_max() is not None
76+
assert (
77+
metrics[key].seq_count() == 100
78+
) # num phenotypes in population
79+
assert len(metrics[key].seq_last()) == metrics[key].seq_count()
80+
elif key == "time" or "step" in key:
81+
assert metrics[key].time_last() is not None
82+
assert metrics[key].time_sum() is not None
83+
assert metrics[key].time_mean() is not None
84+
assert metrics[key].time_stddev() is not None
85+
assert metrics[key].time_variance() is not None
86+
assert metrics[key].time_min() is not None
87+
assert metrics[key].time_max() is not None
88+
elif "step" not in key:
89+
assert metrics[key].value_last() is not None
90+
assert metrics[key].mean() is not None
91+
assert metrics[key].stddev() is not None
92+
assert metrics[key].variance() is not None
93+
assert metrics[key].skew() is not None
94+
assert metrics[key].min() is not None
95+
assert metrics[key].max() is not None
96+
assert metrics[key].count() is not None
97+
98+
num_genes = 5
99+
engine = rd.GeneticEngine(
100+
codec=rd.IntCodec.vector(num_genes, init_range=(0, 10)),
101+
fitness_func=lambda x: sum(x),
102+
objective="min",
103+
subscribe=[ScoreDistributionPlotter()]
104+
)
105+
106+
engine.run([rd.ScoreLimit(0), rd.GenerationsLimit(500)])

0 commit comments

Comments
 (0)