@@ -11,10 +11,42 @@ use crate::mode::Mode;
1111use crate :: state:: State ;
1212use crate :: { ObjectiveFunction , ParallelObjectiveFunction } ;
1313
14+ pub trait Constraints : Sync + std:: fmt:: Debug {
15+ fn meets_constraints ( & self , x : & DVector < f64 > ) -> bool ;
16+ fn clone_box ( & self ) -> Box < dyn Constraints > ;
17+ }
18+
19+ impl Clone for Box < dyn Constraints > {
20+ fn clone ( & self ) -> Box < dyn Constraints > {
21+ self . clone_box ( )
22+ }
23+ }
24+
25+ #[ derive( Debug , Clone ) ]
26+ pub struct Bounds {
27+ pub lower : Vec < f64 > ,
28+ pub upper : Vec < f64 > ,
29+ }
30+
31+ impl Constraints for Bounds {
32+ fn meets_constraints ( & self , x : & DVector < f64 > ) -> bool {
33+ ( 0 ..x. len ( ) ) . all ( |i| x[ i] >= self . lower [ i] && x[ i] <= self . upper [ i] )
34+ }
35+
36+ fn clone_box ( & self ) -> Box < dyn Constraints > {
37+ Box :: new ( self . clone ( ) )
38+ }
39+ }
40+
1441/// A type for sampling and evaluating points from the distribution for each generation
1542pub struct Sampler < F > {
1643 /// Number of dimensions to sample from
1744 dim : usize ,
45+ /// If set, resamples until all points satisfy the constraints
46+ constraints : Option < Box < dyn Constraints > > ,
47+ /// The maximum number of resamples.
48+ /// If this limit is hit, uses points even if they violate the constraints
49+ max_resamples : Option < usize > ,
1850 /// Number of points to sample each generation
1951 population_size : usize ,
2052 /// RNG from which all random numbers are sourced
@@ -26,9 +58,18 @@ pub struct Sampler<F> {
2658}
2759
2860impl < F > Sampler < F > {
29- pub fn new ( dim : usize , population_size : usize , objective_function : F , rng_seed : u64 ) -> Self {
61+ pub fn new (
62+ dim : usize ,
63+ constraints : Option < Box < dyn Constraints > > ,
64+ max_resamples : Option < usize > ,
65+ population_size : usize ,
66+ objective_function : F ,
67+ rng_seed : u64 ,
68+ ) -> Self {
3069 Self {
3170 dim,
71+ constraints,
72+ max_resamples,
3273 population_size,
3374 rng : ChaCha12Rng :: seed_from_u64 ( rng_seed) ,
3475 objective_function,
@@ -49,22 +90,59 @@ impl<F> Sampler<F> {
4990 let normal = Normal :: new ( 0.0 , 1.0 ) . unwrap ( ) ;
5091
5192 // Random steps in the distribution N(0, I)
52- let z = ( 0 ..self . population_size )
53- . map ( |_| {
54- DVector :: from_iterator (
55- self . dim ,
56- ( 0 ..self . dim ) . map ( |_| normal. sample ( & mut self . rng ) ) ,
57- )
58- } )
59- . collect :: < Vec < _ > > ( ) ;
60- let transform = |zk| state. cov_transform ( ) * zk;
61- let y = if parallel_update {
62- z. into_par_iter ( ) . map ( transform) . collect ( )
63- } else {
64- z. into_iter ( ) . map ( transform) . collect ( )
93+ let mut sample = |n : usize , constraints : Option < & dyn Constraints > | {
94+ let z = ( 0 ..n)
95+ . map ( |_| {
96+ DVector :: from_iterator (
97+ self . dim ,
98+ ( 0 ..self . dim ) . map ( |_| normal. sample ( & mut self . rng ) ) ,
99+ )
100+ } )
101+ . collect :: < Vec < _ > > ( ) ;
102+ let transform = |zk| state. cov_transform ( ) * zk;
103+
104+ let ok_constraints = |yk : & DVector < f64 > | match constraints {
105+ Some ( constraints) => {
106+ constraints. meets_constraints ( & to_point ( & yk, state. mean ( ) , state. sigma ( ) ) )
107+ }
108+ None => true ,
109+ } ;
110+
111+ if parallel_update {
112+ z. into_par_iter ( )
113+ . map ( transform)
114+ . filter ( ok_constraints)
115+ . collect ( )
116+ } else {
117+ z. into_iter ( )
118+ . map ( transform)
119+ . filter ( ok_constraints)
120+ . collect ( )
121+ }
65122 } ;
66123
67- // Evaluate and rank points
124+ let mut y: Vec < DVector < f64 > > = vec ! [ ] ;
125+
126+ let mut i: usize = 0 ;
127+ loop {
128+ let remain = self . population_size - y. len ( ) ;
129+ if remain == 0 {
130+ break ;
131+ }
132+
133+ let mut constraints = self . constraints . as_ref ( ) . map ( |x| x. as_ref ( ) ) ;
134+ if let Some ( max) = self . max_resamples {
135+ if i >= max {
136+ constraints = None ;
137+ }
138+ }
139+
140+ let mut new_samps: Vec < _ > = sample ( remain, constraints) ;
141+ y. append ( & mut new_samps) ;
142+
143+ i += 1 ;
144+ }
145+
68146 let mut points = evaluate_points ( y, & mut self . objective_function ) ?;
69147
70148 self . function_evals += points. len ( ) ;
@@ -138,6 +216,10 @@ pub struct EvaluatedPoint {
138216 value : f64 ,
139217}
140218
219+ fn to_point ( unscaled_step : & DVector < f64 > , mean : & DVector < f64 > , sigma : f64 ) -> DVector < f64 > {
220+ mean + sigma * unscaled_step
221+ }
222+
141223impl EvaluatedPoint {
142224 /// Returns a new `EvaluatedPoint` from the unscaled step from the mean, the mean, and the step
143225 /// size
@@ -149,7 +231,7 @@ impl EvaluatedPoint {
149231 sigma : f64 ,
150232 mut objective_function : F ,
151233 ) -> Result < Self , InvalidFunctionValueError > {
152- let point = mean + sigma * & unscaled_step ;
234+ let point = to_point ( & unscaled_step , & mean, sigma) ;
153235 let value = objective_function ( & point) ;
154236
155237 if value. is_nan ( ) {
@@ -206,7 +288,14 @@ mod tests {
206288 fn test_sample ( ) {
207289 let dim = 10 ;
208290 let population_size = 12 ;
209- let mut sampler = Sampler :: new ( dim, population_size, Box :: new ( |_: & DVector < f64 > | 0.0 ) , 1 ) ;
291+ let mut sampler = Sampler :: new (
292+ dim,
293+ None ,
294+ None ,
295+ population_size,
296+ Box :: new ( |_: & DVector < f64 > | 0.0 ) ,
297+ 1 ,
298+ ) ;
210299 let state = State :: new ( vec ! [ 0.0 ; dim] . into ( ) , 2.0 ) ;
211300
212301 let n = 5 ;
@@ -220,6 +309,8 @@ mod tests {
220309
221310 let mut sampler_nan = Sampler :: new (
222311 dim,
312+ None ,
313+ None ,
223314 population_size,
224315 Box :: new ( |_: & DVector < f64 > | f64:: NAN ) ,
225316 1 ,
@@ -228,6 +319,72 @@ mod tests {
228319 assert ! ( sampler_nan. sample( & state, Mode :: Minimize , false ) . is_err( ) ) ;
229320 }
230321
322+ #[ test]
323+ fn test_resample ( ) {
324+ let dim = 1 ;
325+ let population_size = 1 ;
326+
327+ let bounds = Bounds {
328+ lower : vec ! [ 0.0 ] ,
329+ upper : vec ! [ 1.0 ] ,
330+ } ;
331+
332+ let objective_function = |_: & DVector < f64 > | 0.0 ;
333+
334+ // No resampling: Value should be out-of-bounds
335+ {
336+ let mut sampler = Sampler :: new (
337+ dim,
338+ Some ( Box :: new ( bounds. clone ( ) ) ) ,
339+ Some ( 0 ) ,
340+ population_size,
341+ objective_function,
342+ 1 ,
343+ ) ;
344+ let state = State :: new ( vec ! [ 0.0 ; dim] . into ( ) , 2.0 ) ;
345+ let individuals = sampler. sample ( & state, Mode :: Minimize , false ) . unwrap ( ) ;
346+
347+ assert ! (
348+ individuals[ 0 ] . point[ 0 ] < bounds. lower[ 0 ]
349+ || individuals[ 0 ] . point[ 0 ] > bounds. upper[ 0 ]
350+ ) ;
351+ }
352+
353+ // With limited resampling: Value should be in bounds
354+ {
355+ let mut sampler = Sampler :: new (
356+ dim,
357+ Some ( Box :: new ( bounds. clone ( ) ) ) ,
358+ Some ( 10 ) ,
359+ population_size,
360+ objective_function,
361+ 1 ,
362+ ) ;
363+ let state = State :: new ( vec ! [ 0.0 ; dim] . into ( ) , 2.0 ) ;
364+ let individuals = sampler. sample ( & state, Mode :: Minimize , false ) . unwrap ( ) ;
365+
366+ assert ! ( individuals[ 0 ] . point[ 0 ] >= bounds. lower[ 0 ] ) ;
367+ assert ! ( individuals[ 0 ] . point[ 0 ] <= bounds. upper[ 0 ] ) ;
368+ }
369+
370+ // With unlimited resampling: Value should be in bounds
371+ {
372+ let mut sampler = Sampler :: new (
373+ dim,
374+ Some ( Box :: new ( bounds. clone ( ) ) ) ,
375+ None ,
376+ population_size,
377+ objective_function,
378+ 1 ,
379+ ) ;
380+ let state = State :: new ( vec ! [ 0.0 ; dim] . into ( ) , 2.0 ) ;
381+ let individuals = sampler. sample ( & state, Mode :: Minimize , false ) . unwrap ( ) ;
382+
383+ assert ! ( individuals[ 0 ] . point[ 0 ] >= bounds. lower[ 0 ] ) ;
384+ assert ! ( individuals[ 0 ] . point[ 0 ] <= bounds. upper[ 0 ] ) ;
385+ }
386+ }
387+
231388 fn sample_sort ( mode : Mode , expected : [ f64 ; 5 ] ) {
232389 let mut counter = 0.0 ;
233390 let function = |_: & DVector < f64 > | {
@@ -243,7 +400,7 @@ mod tests {
243400 let dim = 10 ;
244401 let population_size = expected. len ( ) ;
245402
246- let mut sampler = Sampler :: new ( dim, population_size, function, 1 ) ;
403+ let mut sampler = Sampler :: new ( dim, None , None , population_size, function, 1 ) ;
247404 let state = State :: new ( vec ! [ 0.0 ; dim] . into ( ) , 2.0 ) ;
248405
249406 let individuals = sampler. sample ( & state, mode, false ) . unwrap ( ) ;
0 commit comments