Skip to content

Commit 2bb2b3f

Browse files
authored
Merge pull request #11 from MenloSystems/feature/out-of-bounds-resample
Resample when a point is out-of-bounds
2 parents 8981e3b + 66c52d1 commit 2bb2b3f

File tree

5 files changed

+217
-25
lines changed

5 files changed

+217
-25
lines changed

src/lib.rs

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,8 @@ pub use crate::options::CMAESOptions;
128128
pub use crate::parameters::Weights;
129129
#[cfg(feature = "plotters")]
130130
pub use crate::plotting::PlotOptions;
131+
pub use crate::sampling::Bounds;
132+
pub use crate::sampling::Constraints;
131133
pub use crate::termination::TerminationReason;
132134

133135
use std::f64;
@@ -243,18 +245,21 @@ impl<F> CMAES<F> {
243245
return Err(InvalidOptionsError::Cm);
244246
}
245247

246-
// Initialize point sampler
247248
let seed = options.seed.unwrap_or_else(rand::random);
249+
250+
// Initialize constant parameters according to the options
251+
let parameters = Parameters::from_options(&options, seed);
252+
253+
// Initialize point sampler
248254
let sampler = Sampler::new(
249255
dimensions,
256+
options.constraints,
257+
options.max_resamples,
250258
options.population_size,
251259
objective_function,
252260
seed,
253261
);
254262

255-
// Initialize constant parameters according to the options
256-
let parameters = Parameters::from_options(&options, seed);
257-
258263
// Initialize variable parameters
259264
let state = State::new(options.initial_mean, options.initial_step_size);
260265

src/options.rs

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@ use std::time::Duration;
66

77
use crate::mode::Mode;
88
use crate::parameters::Weights;
9+
use crate::sampling::Bounds;
10+
use crate::sampling::Constraints;
911
#[cfg(feature = "plotters")]
1012
use crate::PlotOptions;
1113
use crate::CMAES;
@@ -52,6 +54,12 @@ pub struct CMAESOptions {
5254
/// [`BIPOP`][crate::restart::BIPOP] restart strategies (see the [`restart`][crate::restart]
5355
/// module).
5456
pub population_size: usize,
57+
/// The constraints for the search. Resamples until all samples satisfy the constraints or
58+
/// `max_resamples` are hit. Default value is `None`.
59+
pub constraints: Option<Box<dyn Constraints>>,
60+
/// How many times to resample points in order to stay inside `bounds`.
61+
/// `None` disables the limit. Defaults to 10.
62+
pub max_resamples: Option<usize>,
5563
/// The distribution to use when assigning weights to individuals. Default value is
5664
/// [`Weights::Negative`].
5765
pub weights: Weights,
@@ -142,6 +150,8 @@ impl CMAESOptions {
142150
initial_mean,
143151
initial_step_size,
144152
population_size: 4 + (3.0 * (dimensions as f64).ln()).floor() as usize,
153+
constraints: None,
154+
max_resamples: Some(10),
145155
weights: Weights::Negative,
146156
parallel_update: false,
147157
cm: 1.0,
@@ -187,6 +197,25 @@ impl CMAESOptions {
187197
self
188198
}
189199

200+
/// Changes the bounds from the defalt value. Vector length must match the number of dimensions.
201+
/// Convenience method for constraints().
202+
pub fn bounds(mut self, lower: Vec<f64>, upper: Vec<f64>) -> Self {
203+
self.constraints = Some(Box::new(Bounds { lower, upper }));
204+
self
205+
}
206+
207+
/// Changes the constraints from the defalt value.
208+
pub fn constraints(mut self, constraints: Box<dyn Constraints>) -> Self {
209+
self.constraints = Some(constraints);
210+
self
211+
}
212+
213+
/// Changes the maximum number of resamples from the default value.
214+
pub fn max_resamples(mut self, max_resamples: Option<usize>) -> Self {
215+
self.max_resamples = max_resamples;
216+
self
217+
}
218+
190219
/// Changes the weight distribution from the default value. See [`Weights`] for
191220
/// possible distributions.
192221
pub fn weights(mut self, weights: Weights) -> Self {

src/restart/mod.rs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -317,7 +317,9 @@ impl Restarter {
317317
let seed = self.rng.gen();
318318

319319
// Apply default configuration (may be overridden by individual restart strategies)
320-
let mut options = self.default_options.clone()
320+
let mut options = self
321+
.default_options
322+
.clone()
321323
.initial_mean(initial_mean)
322324
.mode(self.mode)
323325
.parallel_update(self.parallel_update)

src/restart/options.rs

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,7 @@
44
use std::ops::RangeInclusive;
55
use std::time::Duration;
66

7-
8-
use super::{DEFAULT_INITIAL_STEP_SIZE, RestartStrategy, Restarter};
7+
use super::{RestartStrategy, Restarter, DEFAULT_INITIAL_STEP_SIZE};
98
use crate::{CMAESOptions, Mode};
109

1110
/// Represents invalid options for a `Restarter`.

src/sampling.rs

Lines changed: 175 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,42 @@ use crate::mode::Mode;
1111
use crate::state::State;
1212
use crate::{ObjectiveFunction, ParallelObjectiveFunction};
1313

14+
pub trait Constraints: Sync + std::fmt::Debug {
15+
fn meets_constraints(&self, x: &DVector<f64>) -> bool;
16+
fn clone_box(&self) -> Box<dyn Constraints>;
17+
}
18+
19+
impl Clone for Box<dyn Constraints> {
20+
fn clone(&self) -> Box<dyn Constraints> {
21+
self.clone_box()
22+
}
23+
}
24+
25+
#[derive(Debug, Clone)]
26+
pub struct Bounds {
27+
pub lower: Vec<f64>,
28+
pub upper: Vec<f64>,
29+
}
30+
31+
impl Constraints for Bounds {
32+
fn meets_constraints(&self, x: &DVector<f64>) -> bool {
33+
(0..x.len()).all(|i| x[i] >= self.lower[i] && x[i] <= self.upper[i])
34+
}
35+
36+
fn clone_box(&self) -> Box<dyn Constraints> {
37+
Box::new(self.clone())
38+
}
39+
}
40+
1441
/// A type for sampling and evaluating points from the distribution for each generation
1542
pub struct Sampler<F> {
1643
/// Number of dimensions to sample from
1744
dim: usize,
45+
/// If set, resamples until all points satisfy the constraints
46+
constraints: Option<Box<dyn Constraints>>,
47+
/// The maximum number of resamples.
48+
/// If this limit is hit, uses points even if they violate the constraints
49+
max_resamples: Option<usize>,
1850
/// Number of points to sample each generation
1951
population_size: usize,
2052
/// RNG from which all random numbers are sourced
@@ -26,9 +58,18 @@ pub struct Sampler<F> {
2658
}
2759

2860
impl<F> Sampler<F> {
29-
pub fn new(dim: usize, population_size: usize, objective_function: F, rng_seed: u64) -> Self {
61+
pub fn new(
62+
dim: usize,
63+
constraints: Option<Box<dyn Constraints>>,
64+
max_resamples: Option<usize>,
65+
population_size: usize,
66+
objective_function: F,
67+
rng_seed: u64,
68+
) -> Self {
3069
Self {
3170
dim,
71+
constraints,
72+
max_resamples,
3273
population_size,
3374
rng: ChaCha12Rng::seed_from_u64(rng_seed),
3475
objective_function,
@@ -49,22 +90,59 @@ impl<F> Sampler<F> {
4990
let normal = Normal::new(0.0, 1.0).unwrap();
5091

5192
// Random steps in the distribution N(0, I)
52-
let z = (0..self.population_size)
53-
.map(|_| {
54-
DVector::from_iterator(
55-
self.dim,
56-
(0..self.dim).map(|_| normal.sample(&mut self.rng)),
57-
)
58-
})
59-
.collect::<Vec<_>>();
60-
let transform = |zk| state.cov_transform() * zk;
61-
let y = if parallel_update {
62-
z.into_par_iter().map(transform).collect()
63-
} else {
64-
z.into_iter().map(transform).collect()
93+
let mut sample = |n: usize, constraints: Option<&dyn Constraints>| {
94+
let z = (0..n)
95+
.map(|_| {
96+
DVector::from_iterator(
97+
self.dim,
98+
(0..self.dim).map(|_| normal.sample(&mut self.rng)),
99+
)
100+
})
101+
.collect::<Vec<_>>();
102+
let transform = |zk| state.cov_transform() * zk;
103+
104+
let ok_constraints = |yk: &DVector<f64>| match constraints {
105+
Some(constraints) => {
106+
constraints.meets_constraints(&to_point(&yk, state.mean(), state.sigma()))
107+
}
108+
None => true,
109+
};
110+
111+
if parallel_update {
112+
z.into_par_iter()
113+
.map(transform)
114+
.filter(ok_constraints)
115+
.collect()
116+
} else {
117+
z.into_iter()
118+
.map(transform)
119+
.filter(ok_constraints)
120+
.collect()
121+
}
65122
};
66123

67-
// Evaluate and rank points
124+
let mut y: Vec<DVector<f64>> = vec![];
125+
126+
let mut i: usize = 0;
127+
loop {
128+
let remain = self.population_size - y.len();
129+
if remain == 0 {
130+
break;
131+
}
132+
133+
let mut constraints = self.constraints.as_ref().map(|x| x.as_ref());
134+
if let Some(max) = self.max_resamples {
135+
if i >= max {
136+
constraints = None;
137+
}
138+
}
139+
140+
let mut new_samps: Vec<_> = sample(remain, constraints);
141+
y.append(&mut new_samps);
142+
143+
i += 1;
144+
}
145+
68146
let mut points = evaluate_points(y, &mut self.objective_function)?;
69147

70148
self.function_evals += points.len();
@@ -138,6 +216,10 @@ pub struct EvaluatedPoint {
138216
value: f64,
139217
}
140218

219+
fn to_point(unscaled_step: &DVector<f64>, mean: &DVector<f64>, sigma: f64) -> DVector<f64> {
220+
mean + sigma * unscaled_step
221+
}
222+
141223
impl EvaluatedPoint {
142224
/// Returns a new `EvaluatedPoint` from the unscaled step from the mean, the mean, and the step
143225
/// size
@@ -149,7 +231,7 @@ impl EvaluatedPoint {
149231
sigma: f64,
150232
mut objective_function: F,
151233
) -> Result<Self, InvalidFunctionValueError> {
152-
let point = mean + sigma * &unscaled_step;
234+
let point = to_point(&unscaled_step, &mean, sigma);
153235
let value = objective_function(&point);
154236

155237
if value.is_nan() {
@@ -206,7 +288,14 @@ mod tests {
206288
fn test_sample() {
207289
let dim = 10;
208290
let population_size = 12;
209-
let mut sampler = Sampler::new(dim, population_size, Box::new(|_: &DVector<f64>| 0.0), 1);
291+
let mut sampler = Sampler::new(
292+
dim,
293+
None,
294+
None,
295+
population_size,
296+
Box::new(|_: &DVector<f64>| 0.0),
297+
1,
298+
);
210299
let state = State::new(vec![0.0; dim].into(), 2.0);
211300

212301
let n = 5;
@@ -220,6 +309,8 @@ mod tests {
220309

221310
let mut sampler_nan = Sampler::new(
222311
dim,
312+
None,
313+
None,
223314
population_size,
224315
Box::new(|_: &DVector<f64>| f64::NAN),
225316
1,
@@ -228,6 +319,72 @@ mod tests {
228319
assert!(sampler_nan.sample(&state, Mode::Minimize, false).is_err());
229320
}
230321

322+
#[test]
323+
fn test_resample() {
324+
let dim = 1;
325+
let population_size = 1;
326+
327+
let bounds = Bounds {
328+
lower: vec![0.0],
329+
upper: vec![1.0],
330+
};
331+
332+
let objective_function = |_: &DVector<f64>| 0.0;
333+
334+
// No resampling: Value should be out-of-bounds
335+
{
336+
let mut sampler = Sampler::new(
337+
dim,
338+
Some(Box::new(bounds.clone())),
339+
Some(0),
340+
population_size,
341+
objective_function,
342+
1,
343+
);
344+
let state = State::new(vec![0.0; dim].into(), 2.0);
345+
let individuals = sampler.sample(&state, Mode::Minimize, false).unwrap();
346+
347+
assert!(
348+
individuals[0].point[0] < bounds.lower[0]
349+
|| individuals[0].point[0] > bounds.upper[0]
350+
);
351+
}
352+
353+
// With limited resampling: Value should be in bounds
354+
{
355+
let mut sampler = Sampler::new(
356+
dim,
357+
Some(Box::new(bounds.clone())),
358+
Some(10),
359+
population_size,
360+
objective_function,
361+
1,
362+
);
363+
let state = State::new(vec![0.0; dim].into(), 2.0);
364+
let individuals = sampler.sample(&state, Mode::Minimize, false).unwrap();
365+
366+
assert!(individuals[0].point[0] >= bounds.lower[0]);
367+
assert!(individuals[0].point[0] <= bounds.upper[0]);
368+
}
369+
370+
// With unlimited resampling: Value should be in bounds
371+
{
372+
let mut sampler = Sampler::new(
373+
dim,
374+
Some(Box::new(bounds.clone())),
375+
None,
376+
population_size,
377+
objective_function,
378+
1,
379+
);
380+
let state = State::new(vec![0.0; dim].into(), 2.0);
381+
let individuals = sampler.sample(&state, Mode::Minimize, false).unwrap();
382+
383+
assert!(individuals[0].point[0] >= bounds.lower[0]);
384+
assert!(individuals[0].point[0] <= bounds.upper[0]);
385+
}
386+
}
387+
231388
fn sample_sort(mode: Mode, expected: [f64; 5]) {
232389
let mut counter = 0.0;
233390
let function = |_: &DVector<f64>| {
@@ -243,7 +400,7 @@ mod tests {
243400
let dim = 10;
244401
let population_size = expected.len();
245402

246-
let mut sampler = Sampler::new(dim, population_size, function, 1);
403+
let mut sampler = Sampler::new(dim, None, None, population_size, function, 1);
247404
let state = State::new(vec![0.0; dim].into(), 2.0);
248405

249406
let individuals = sampler.sample(&state, mode, false).unwrap();

0 commit comments

Comments
 (0)