Skip to content

Commit bb52f9b

Browse files
Fixes issue-428: TargetSpace.max() bug in constrained space (#429)
1 parent 0d5105d commit bb52f9b

File tree

2 files changed

+29
-6
lines changed

2 files changed

+29
-6
lines changed

bayes_opt/target_space.py

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ def __init__(self, target_func, pbounds, constraint=None, random_state=None,
4343
If True, the optimizer will allow duplicate points to be registered.
4444
This behavior may be desired in high noise situations where repeatedly probing
4545
the same point will give different answers. In other situations, the acquisition
46-
may occasionaly generate a duplicate point.
46+
may occasionally generate a duplicate point.
4747
"""
4848
self.random_state = ensure_rng(random_state)
4949
self._allow_duplicate_points = allow_duplicate_points
@@ -70,7 +70,7 @@ def __init__(self, target_func, pbounds, constraint=None, random_state=None,
7070
self._constraint = constraint
7171

7272
if constraint is not None:
73-
# preallocated memory for constraint fulfillement
73+
# preallocated memory for constraint fulfillment
7474
if constraint.lb.size == 1:
7575
self._constraint_values = np.empty(shape=(0), dtype=float)
7676
else:
@@ -170,7 +170,7 @@ def register(self, params, target, constraint_value=None):
170170
171171
Notes
172172
-----
173-
runs in ammortized constant time
173+
runs in amortized constant time
174174
175175
Example
176176
-------
@@ -290,17 +290,29 @@ def max(self):
290290
if target_max is None:
291291
return None
292292

293-
target_max_idx = np.where(self.target == target_max)[0][0]
293+
if self._constraint is not None:
294+
allowed = self._constraint.allowed(self._constraint_values)
295+
296+
target = self.target[allowed]
297+
params = self.params[allowed]
298+
constraint_values = self.constraint_values[allowed]
299+
else:
300+
target = self.target
301+
params = self.params
302+
constraint_values = self.constraint_values
303+
304+
target_max_idx = np.where(target == target_max)[0][0]
305+
294306

295307
res = {
296308
'target': target_max,
297309
'params': dict(
298-
zip(self.keys, self.params[target_max_idx])
310+
zip(self.keys, params[target_max_idx])
299311
)
300312
}
301313

302314
if self._constraint is not None:
303-
res['constraint'] = self._constraint_values[target_max_idx]
315+
res['constraint'] = constraint_values[target_max_idx]
304316

305317
return res
306318

tests/test_target_space.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -203,6 +203,17 @@ def test_max_with_constraint():
203203
space.probe(params={"p1": 1, "p2": 6}) # Unfeasible
204204
assert space.max() == {"params": {"p1": 2, "p2": 3}, "target": 5, "constraint": -1}
205205

206+
def test_max_with_constraint_identical_target_value():
207+
constraint = ConstraintModel(lambda p1, p2: p1-p2, -2, 2)
208+
space = TargetSpace(target_func, PBOUNDS, constraint=constraint)
209+
210+
assert space.max() == None
211+
space.probe(params={"p1": 1, "p2": 2}) # Feasible
212+
space.probe(params={"p1": 0, "p2": 5}) # Unfeasible, target value is 5, should not be selected
213+
space.probe(params={"p1": 5, "p2": 8}) # Unfeasible
214+
space.probe(params={"p1": 2, "p2": 3}) # Feasible, target value is also 5
215+
space.probe(params={"p1": 1, "p2": 6}) # Unfeasible
216+
assert space.max() == {"params": {"p1": 2, "p2": 3}, "target": 5, "constraint": -1}
206217

207218
def test_res():
208219
space = TargetSpace(target_func, PBOUNDS)

0 commit comments

Comments
 (0)