@@ -96,7 +96,7 @@ def __str__(self):
9696 return output
9797
9898
99- def solve_main (objfun , x0 , args , xl , xu , npt , rhobeg , rhoend , maxfun , nruns_so_far , nf_so_far , nx_so_far , nsamples , params ,
99+ def solve_main (objfun , x0 , args , xl , xu , projections , npt , rhobeg , rhoend , maxfun , nruns_so_far , nf_so_far , nx_so_far , nsamples , params ,
100100 diagnostic_info , scaling_changes , f0_avg_old = None , f0_nsamples_old = None , do_logging = True , print_progress = False ):
101101 # Evaluate at x0 (keep nf, nx correct and check for f small)
102102 if f0_avg_old is None :
@@ -144,7 +144,7 @@ def solve_main(objfun, x0, args, xl, xu, npt, rhobeg, rhoend, maxfun, nruns_so_f
144144 nx = nx_so_far
145145
146146 # Initialise controller
147- control = Controller (objfun , x0 , args , f0_avg , num_samples_run , xl , xu , npt , rhobeg , rhoend , nf , nx , maxfun , params , scaling_changes , do_logging = do_logging )
147+ control = Controller (objfun , x0 , args , f0_avg , num_samples_run , xl , xu , projections , npt , rhobeg , rhoend , nf , nx , maxfun , params , scaling_changes , do_logging = do_logging )
148148
149149 # Initialise interpolation set
150150 number_of_samples = max (nsamples (control .delta , control .rho , 0 , nruns_so_far ), 1 )
@@ -665,7 +665,7 @@ def solve_main(objfun, x0, args, xl, xu, npt, rhobeg, rhoend, maxfun, nruns_so_f
665665 return x , f , gradmin , hessmin , nsamples , control .nf , control .nx , nruns_so_far , exit_info , diagnostic_info
666666
667667
668- def solve (objfun , x0 , args = (), bounds = None , npt = None , rhobeg = None , rhoend = 1e-8 , maxfun = None , nsamples = None , user_params = None ,
668+ def solve (objfun , x0 , args = (), bounds = None , projections = None , npt = None , rhobeg = None , rhoend = 1e-8 , maxfun = None , nsamples = None , user_params = None ,
669669 objfun_has_noise = False , seek_global_minimum = False , scaling_within_bounds = False , do_logging = True , print_progress = False ):
670670 n = len (x0 )
671671 if type (x0 ) == list :
@@ -694,7 +694,11 @@ def solve(objfun, x0, args=(), bounds=None, npt=None, rhobeg=None, rhoend=1e-8,
694694 if (xl is None or xu is None ) and scaling_within_bounds :
695695 scaling_within_bounds = False
696696 warnings .warn ("Ignoring scaling_within_bounds=True for unconstrained problem/1-sided bounds" , RuntimeWarning )
697-
697+
698+ if (projections is not None ) and scaling_within_bounds :
699+ scaling_within_bounds = False
700+ warnings .warn ("Ignoring scaling_within_bounds=True for problems with projections given" , RuntimeWarning )
701+
698702 exit_info = None
699703 if seek_global_minimum and (xl is None or xu is None ):
700704 exit_info = ExitInformation (EXIT_INPUT_ERROR , "If seeking global minimum, must specify upper and lower bounds" )
@@ -761,6 +765,9 @@ def solve(objfun, x0, args=(), bounds=None, npt=None, rhobeg=None, rhoend=1e-8,
761765 if exit_info is None and np .min (xu - xl ) < 2.0 * rhobeg :
762766 exit_info = ExitInformation (EXIT_INPUT_ERROR , "gap between lower and upper must be at least 2*rhobeg" )
763767
768+ if exit_info is None and projections is not None and type (projections ) != list :
769+ exit_info = ExitInformation (EXIT_INPUT_ERROR , "projections must be a list of functions" )
770+
764771 if maxfun <= npt :
765772 warnings .warn ("maxfun <= npt: Are you sure your budget is large enough?" , RuntimeWarning )
766773
@@ -792,12 +799,12 @@ def solve(objfun, x0, args=(), bounds=None, npt=None, rhobeg=None, rhoend=1e-8,
792799 return results
793800
794801 # Enforce lower & upper bounds on x0
795- idx = (x0 <= xl )
802+ idx = (x0 < xl )
796803 if np .any (idx ):
797804 warnings .warn ("x0 below lower bound, adjusting" , RuntimeWarning )
798805 x0 [idx ] = xl [idx ]
799806
800- idx = (x0 >= xu )
807+ idx = (x0 > xu )
801808 if np .any (idx ):
802809 warnings .warn ("x0 above upper bound, adjusting" , RuntimeWarning )
803810 x0 [idx ] = xu [idx ]
@@ -808,7 +815,7 @@ def solve(objfun, x0, args=(), bounds=None, npt=None, rhobeg=None, rhoend=1e-8,
808815 nf = 0
809816 nx = 0
810817 xmin , fmin , gradmin , hessmin , nsamples_min , nf , nx , nruns , exit_info , diagnostic_info = \
811- solve_main (objfun , x0 , args , xl , xu , npt , rhobeg , rhoend , maxfun , nruns , nf , nx , nsamples , params ,
818+ solve_main (objfun , x0 , args , xl , xu , projections , npt , rhobeg , rhoend , maxfun , nruns , nf , nx , nsamples , params ,
812819 diagnostic_info , scaling_changes , do_logging = do_logging , print_progress = print_progress )
813820
814821 # Hard restarts loop
@@ -829,11 +836,11 @@ def solve(objfun, x0, args=(), bounds=None, npt=None, rhobeg=None, rhoend=1e-8,
829836 % (fmin , nf , _rhobeg , _rhoend ))
830837 if params ("restarts.hard.use_old_fk" ):
831838 xmin2 , fmin2 , gradmin2 , hessmin2 , nsamples2 , nf , nx , nruns , exit_info , diagnostic_info = \
832- solve_main (objfun , xmin , args , xl , xu , npt , _rhobeg , _rhoend , maxfun , nruns , nf , nx , nsamples , params ,
839+ solve_main (objfun , xmin , args , xl , xu , projections , npt , _rhobeg , _rhoend , maxfun , nruns , nf , nx , nsamples , params ,
833840 diagnostic_info , scaling_changes , f0_avg_old = fmin , f0_nsamples_old = nsamples_min , do_logging = do_logging , print_progress = print_progress )
834841 else :
835842 xmin2 , fmin2 , gradmin2 , hessmin2 , nsamples2 , nf , nx , nruns , exit_info , diagnostic_info = \
836- solve_main (objfun , xmin , args , xl , xu , npt , _rhobeg , _rhoend , maxfun , nruns , nf , nx , nsamples , params ,
843+ solve_main (objfun , xmin , args , xl , xu , projections , npt , _rhobeg , _rhoend , maxfun , nruns , nf , nx , nsamples , params ,
837844 diagnostic_info , scaling_changes , do_logging = do_logging , print_progress = print_progress )
838845
839846 if fmin2 < fmin or np .isnan (fmin ):
0 commit comments