@@ -61,6 +61,11 @@ def __call__(self, study: optuna.study.Study, trial: FrozenTrial) -> None:
6161 )
6262
6363
64+ # Constraints
65+ def constraints (trial ):
66+ return trial .user_attrs ["constraints" ]
67+
68+
6469# Supported samplers
6570supported_samplers = ["TPE" , "CMAES" , "NSGAII" , "DUMMY" , "BOTORCH" , "QMC" ]
6671supported_pruners = ["NOP" , "SHA" , "HB" , "MED" ]
@@ -82,6 +87,7 @@ class MPIDistributedOptuna(Search):
8287 storage (Union[str, optuna.storages.BaseStorage], optional): Database used by Optuna. Defaults to ``None``.
8388 checkpoint (bool, optional): If results should be checkpointed regularly to the ``log_dir``. Defaults to ``True``.
8489 comm (MPI.Comm, optional): The MPI communicator. Defaults to ``None``.
90+ moo_lower_bounds ([type], optional): [description]. Defaults to ``None``.
8591
8692 Raises:
8793 ValueError: _description_
@@ -106,6 +112,7 @@ def __init__(
106112 checkpoint : bool = True ,
107113 n_initial_points : int = None ,
108114 comm : MPI .Comm = None ,
115+ moo_lower_bounds = None ,
109116 ** kwargs ,
110117 ):
111118 super ().__init__ (problem , evaluator , random_state , log_dir , verbose )
@@ -135,28 +142,46 @@ def __init__(
135142 2 * len (self ._problem ) if n_initial_points is None else n_initial_points
136143 )
137144
145+ # Constraints
146+ self ._moo_lower_bounds = moo_lower_bounds
147+ self ._constraints_func = None
148+ if moo_lower_bounds is not None :
149+ if len (moo_lower_bounds ) == n_objectives :
150+ self ._constraints_func = constraints
151+ else :
152+ raise ValueError (
153+ f"moo_lower_bounds should be of length { n_objectives } but is of length { len (moo_lower_bounds )} "
154+ )
155+
138156 # Setup the sampler
139157 if isinstance (sampler , optuna .samplers .BaseSampler ):
140158 pass
141159 elif isinstance (sampler , str ):
142160 sampler_seed = self ._random_state .randint (2 ** 31 )
143161 if sampler == "TPE" :
144162 sampler = optuna .samplers .TPESampler (
145- n_startup_trials = self ._n_initial_points , seed = sampler_seed
163+ n_startup_trials = self ._n_initial_points ,
164+ seed = sampler_seed ,
165+ constraints_func = self ._constraints_func ,
146166 )
147167 elif sampler == "CMAES" :
148168 sampler = optuna .samplers .CmaEsSampler (
149169 n_startup_trials = self ._n_initial_points , seed = sampler_seed
150170 )
151171 elif sampler == "NSGAII" :
152- sampler = optuna .samplers .NSGAIISampler (seed = sampler_seed )
172+ sampler = optuna .samplers .NSGAIISampler (
173+ seed = sampler_seed ,
174+ constraints_func = self ._constraints_func ,
175+ )
153176 elif sampler == "DUMMY" :
154177 sampler = optuna .samplers .RandomSampler (seed = sampler_seed )
155178 elif sampler == "BOTORCH" :
156179 from optuna .integration import BoTorchSampler
157180
158181 sampler = BoTorchSampler (
159- n_startup_trials = self ._n_initial_points , seed = sampler_seed
182+ n_startup_trials = self ._n_initial_points ,
183+ seed = sampler_seed ,
184+ constraints_func = self ._constraints_func ,
160185 )
161186 elif sampler == "QMC" :
162187 sampler = optuna .samplers .QMCSampler (seed = sampler_seed )
@@ -253,6 +278,15 @@ def objective_wrapper(trial):
253278 )
254279
255280 # TODO: optuna constraint
281+ if self ._moo_lower_bounds is not None :
282+ # https://optuna.readthedocs.io/en/stable/faq.html#how-can-i-optimize-a-model-with-some-constraints
283+ # Constraints which are considered feasible if less than or equal to zero.
284+ constraints = []
285+ for i , lbi in enumerate (self ._moo_lower_bounds ):
286+ if lbi is not None :
287+ ci = - (output ["objective" ][i ] - lbi ) # <= 0
288+ constraints .append (ci )
289+ trial .set_user_attr ("constraints" , tuple (constraints ))
256290
257291 data = {f"p:{ k } " : v for k , v in config .items ()}
258292 if isinstance (output ["objective" ], list ) or isinstance (
0 commit comments