Skip to content

Commit fd82c07

Browse files
committed
fixup: Generalize Bounds as Constraints (3/?)
1 parent e3f29e1 commit fd82c07

File tree

3 files changed

+36
-16
lines changed

3 files changed

+36
-16
lines changed

src/lib.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,7 @@ pub use crate::mode::Mode;
126126
pub use crate::objective_function::{ObjectiveFunction, ParallelObjectiveFunction};
127127
pub use crate::options::CMAESOptions;
128128
pub use crate::parameters::Weights;
129+
pub use crate::sampling::Constraints;
129130
pub use crate::sampling::Bounds;
130131
#[cfg(feature = "plotters")]
131132
pub use crate::plotting::PlotOptions;
@@ -252,7 +253,7 @@ impl<F> CMAES<F> {
252253
// Initialize point sampler
253254
let sampler = Sampler::new(
254255
dimensions,
255-
options.bounds,
256+
options.constraints,
256257
options.max_resamples,
257258
options.population_size,
258259
objective_function,

src/options.rs

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ use std::time::Duration;
77
use crate::mode::Mode;
88
use crate::parameters::Weights;
99
use crate::sampling::Bounds;
10+
use crate::sampling::Constraints;
1011
#[cfg(feature = "plotters")]
1112
use crate::PlotOptions;
1213
use crate::CMAES;
@@ -53,9 +54,9 @@ pub struct CMAESOptions {
5354
/// [`BIPOP`][crate::restart::BIPOP] restart strategies (see the [`restart`][crate::restart]
5455
/// module).
5556
pub population_size: usize,
56-
/// The bounds within which to search. Resamples until all samples are in bounds or
57+
/// The constraints for the search. Resamples until all samples satisfy the constraints or
5758
/// `max_resamples` are hit. Default value is `None`.
58-
pub bounds: Option<Bounds>,
59+
pub constraints: Option<Box<dyn Constraints>>,
5960
/// How many times to resample points in order to stay inside `bounds`.
6061
/// `None` disables the limit. Defaults to 10.
6162
pub max_resamples: Option<usize>,
@@ -149,7 +150,7 @@ impl CMAESOptions {
149150
initial_mean,
150151
initial_step_size,
151152
population_size: 4 + (3.0 * (dimensions as f64).ln()).floor() as usize,
152-
bounds: None,
153+
constraints: None,
153154
max_resamples: Some(10),
154155
weights: Weights::Negative,
155156
parallel_update: false,
@@ -197,8 +198,15 @@ impl CMAESOptions {
197198
}
198199

199200
/// Changes the bounds from the defalt value. Vector length must match the number of dimensions.
201+
/// Convenience method for constraints().
200202
pub fn bounds(mut self, lower: Vec<f64>, upper: Vec<f64>) -> Self {
201-
self.bounds = Some(Bounds{lower, upper});
203+
self.constraints = Some(Box::new(Bounds{lower, upper}));
204+
self
205+
}
206+
207+
/// Changes the constraints from the defalt value.
208+
pub fn constraints(mut self, constraints: Box<dyn Constraints>) -> Self {
209+
self.constraints = Some(constraints);
202210
self
203211
}
204212

src/sampling.rs

Lines changed: 22 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,15 @@ use crate::mode::Mode;
1111
use crate::state::State;
1212
use crate::{ObjectiveFunction, ParallelObjectiveFunction};
1313

14-
pub trait Constraints : Sync {
14+
pub trait Constraints : Sync + std::fmt::Debug {
1515
fn meets_constraints(&self, x: &DVector<f64>) -> bool;
16+
fn clone_box(&self) -> Box<dyn Constraints>;
17+
}
18+
19+
impl Clone for Box<dyn Constraints> {
20+
fn clone(&self) -> Box<dyn Constraints> {
21+
self.clone_box()
22+
}
1623
}
1724

1825
#[derive(Debug, Clone)]
@@ -25,16 +32,20 @@ impl Constraints for Bounds {
2532
fn meets_constraints(&self, x: &DVector<f64>) -> bool {
2633
(0..x.len()).all(|i| x[i] >= self.lower[i] && x[i] <= self.upper[i])
2734
}
35+
36+
fn clone_box(&self) -> Box<dyn Constraints> {
37+
Box::new(self.clone())
38+
}
2839
}
2940

3041
/// A type for sampling and evaluating points from the distribution for each generation
3142
pub struct Sampler<F> {
3243
/// Number of dimensions to sample from
3344
dim: usize,
34-
/// If set, resamples until all points are within bounds
45+
/// If set, resamples until all points satisfy the constraints
3546
constraints: Option<Box<dyn Constraints>>,
3647
/// The maximum number of resamples.
37-
/// If this limit is hit, uses points even if they are outside the bounds
48+
/// If this limit is hit, uses points even if they violate the constraints
3849
max_resamples: Option<usize>,
3950
/// Number of points to sample each generation
4051
population_size: usize,
@@ -47,10 +58,10 @@ pub struct Sampler<F> {
4758
}
4859

4960
impl<F> Sampler<F> {
50-
pub fn new(dim: usize, bounds: Option<Box<dyn Constraints>>, max_resamples: Option<usize>, population_size: usize, objective_function: F, rng_seed: u64) -> Self {
61+
pub fn new(dim: usize, constraints: Option<Box<dyn Constraints>>, max_resamples: Option<usize>, population_size: usize, objective_function: F, rng_seed: u64) -> Self {
5162
Self {
5263
dim,
53-
constraints: bounds,
64+
constraints,
5465
max_resamples,
5566
population_size,
5667
rng: ChaCha12Rng::seed_from_u64(rng_seed),
@@ -85,15 +96,15 @@ impl<F> Sampler<F> {
8596

8697
match constraints {
8798
Some(constraints) => {
88-
let in_bounds = |yk: &DVector<f64>| {
99+
let ok_constraints = |yk: &DVector<f64>| {
89100
let point = to_point(&yk, state.mean(), state.sigma());
90101
constraints.meets_constraints(&point)
91102
};
92103

93104
if parallel_update {
94-
z.into_par_iter().map(transform).filter(in_bounds).collect()
105+
z.into_par_iter().map(transform).filter(ok_constraints).collect()
95106
} else {
96-
z.into_iter().map(transform).filter(in_bounds).collect()
107+
z.into_iter().map(transform).filter(ok_constraints).collect()
97108
}
98109
},
99110
None => {
@@ -115,14 +126,14 @@ impl<F> Sampler<F> {
115126
break;
116127
}
117128

118-
let mut bounds = self.constraints.as_ref().map(|x| x.as_ref());
129+
let mut constraints = self.constraints.as_ref().map(|x| x.as_ref());
119130
if let Some(max) = self.max_resamples {
120131
if i >= max {
121-
bounds = None;
132+
constraints = None;
122133
}
123134
}
124135

125-
let mut new_samps: Vec<_> = sample(remain, bounds);
136+
let mut new_samps: Vec<_> = sample(remain, constraints);
126137
y.append(&mut new_samps);
127138

128139
i += 1;

0 commit comments

Comments
 (0)