Skip to content

Commit 2abe024

Browse files
committed
probability wheel iterator
1 parent a12e810 commit 2abe024

File tree

2 files changed

+6
-26
lines changed

2 files changed

+6
-26
lines changed

crates/radiate-selectors/src/roulette.rs

Lines changed: 5 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
1-
use radiate_core::{Chromosome, Objective, Optimize, Population, Select, pareto, random_provider};
1+
use crate::ProbabilityWheelIterator;
2+
use radiate_core::{Chromosome, Objective, Optimize, Population, Select, pareto};
23

34
#[derive(Debug, Default)]
45
pub struct RouletteSelector;
@@ -47,23 +48,8 @@ impl<C: Chromosome + Clone> Select<C> for RouletteSelector {
4748
}
4849
};
4950

50-
let mut cdf = Vec::with_capacity(fitness_values.len());
51-
let mut acc = 0.0;
52-
for &p in &fitness_values {
53-
acc += p;
54-
cdf.push(acc);
55-
}
56-
let total = *cdf.last().unwrap_or(&1.0);
57-
58-
let mut out = Vec::with_capacity(count);
59-
for _ in 0..count {
60-
let r = random_provider::random::<f32>() * total;
61-
let idx = cdf
62-
.binary_search_by(|x| x.partial_cmp(&r).unwrap())
63-
.unwrap_or_else(|i| i);
64-
out.push(population[idx].clone());
65-
}
66-
67-
out.into()
51+
ProbabilityWheelIterator::new(&fitness_values, count)
52+
.map(|idx| population[idx].clone())
53+
.collect()
6854
}
6955
}

py-radiate/tests/unit/test_metrics.py

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -52,12 +52,6 @@ def test_generation_metrics(self, random_seed):
5252
@pytest.mark.integration
5353
def test_metrics_from_events(self, random_seed):
5454
class ScoreDistributionPlotter(rd.EventHandler):
55-
"""
56-
Subscriber class to handle events and track metrics.
57-
We will use this to plot score distributions over generations then
58-
display the plot when the engine stops.
59-
"""
60-
6155
def __init__(self):
6256
super().__init__(rd.EventType.EPOCH_COMPLETE)
6357

@@ -100,7 +94,7 @@ def on_event(self, event: rd.EngineEvent) -> None:
10094
codec=rd.IntCodec.vector(num_genes, init_range=(0, 10)),
10195
fitness_func=lambda x: sum(x),
10296
objective="min",
103-
subscribe=[ScoreDistributionPlotter()]
97+
subscribe=[ScoreDistributionPlotter()],
10498
)
10599

106100
engine.run([rd.ScoreLimit(0), rd.GenerationsLimit(500)])

0 commit comments

Comments
 (0)