Skip to content

Commit baf10da

Browse files
committed
Merge ConstrainedTargetSpace into TargetSpace
1 parent fac54c0 commit baf10da

File tree

2 files changed

+107
-127
lines changed

2 files changed

+107
-127
lines changed

bayes_opt/bayesian_optimization.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
from bayes_opt.constraint import ConstraintModel
55

6-
from .target_space import TargetSpace, ConstrainedTargetSpace
6+
from .target_space import TargetSpace
77
from .event import Events, DEFAULT_EVENTS
88
from .logger import _get_default_logger
99
from .util import UtilityFunction, acq_max, ensure_rng
@@ -107,7 +107,7 @@ def __init__(self,
107107
# Data structure containing the function to be optimized, the
108108
# bounds of its domain, and a record of the evaluations we have
109109
# done so far
110-
self._space = TargetSpace(f, pbounds, random_state)
110+
self._space = TargetSpace(f, pbounds, random_state=random_state)
111111
self.is_constrained = False
112112
else:
113113
constraint_ = ConstraintModel(
@@ -116,11 +116,11 @@ def __init__(self,
116116
constraint.ub,
117117
random_state=random_state
118118
)
119-
self._space = ConstrainedTargetSpace(
119+
self._space = TargetSpace(
120120
f,
121-
constraint_,
122121
pbounds,
123-
random_state
122+
constraint=constraint_,
123+
random_state=random_state
124124
)
125125
self.is_constrained = True
126126

bayes_opt/target_space.py

Lines changed: 102 additions & 122 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ class TargetSpace(object):
2323
>>> y = space.register_point(x)
2424
>>> assert self.max_point()['max_val'] == y
2525
"""
26-
def __init__(self, target_func, pbounds, random_state=None):
26+
def __init__(self, target_func, pbounds, constraint=None, random_state=None):
2727
"""
2828
Parameters
2929
----------
@@ -57,6 +57,16 @@ def __init__(self, target_func, pbounds, random_state=None):
5757
# keep track of unique points we have seen so far
5858
self._cache = {}
5959

60+
61+
self._constraint = constraint
62+
63+
if constraint is not None:
64+
# preallocated memory for constraint fulfillement
65+
if constraint.lb.size == 1:
66+
self._constraint_values = np.empty(shape=(0), dtype=float)
67+
else:
68+
self._constraint_values = np.empty(shape=(0, constraint.lb.size), dtype=float)
69+
6070
def __contains__(self, x):
6171
return _hashable(x) in self._cache
6272

@@ -87,6 +97,15 @@ def keys(self):
8797
@property
8898
def bounds(self):
8999
return self._bounds
100+
101+
@property
102+
def constraint(self):
103+
return self._constraint
104+
105+
@property
106+
def constraint_values(self):
107+
if self._constraint is not None:
108+
return self._constraint_values
90109

91110
def params_to_array(self, params):
92111
try:
@@ -123,7 +142,7 @@ def _as_array(self, x):
123142
"expected number of parameters ({}).".format(len(self.keys)))
124143
return x
125144

126-
def register(self, params, target):
145+
def register(self, params, target, constraint_value=None):
127146
"""
128147
Append a point and its target value to the known data.
129148
@@ -160,12 +179,19 @@ def register(self, params, target):
160179
if x in self:
161180
raise KeyError('Data point {} is not unique'.format(x))
162181

163-
# Insert data into unique dictionary
164-
self._cache[_hashable(x.ravel())] = target
165182

166183
self._params = np.concatenate([self._params, x.reshape(1, -1)])
167184
self._target = np.concatenate([self._target, [target]])
168185

186+
if constraint_value is None:
187+
# Insert data into unique dictionary
188+
self._cache[_hashable(x.ravel())] = target
189+
else:
190+
# Insert data into unique dictionary
191+
self._cache[_hashable(x.ravel())] = (target, constraint_value)
192+
self._constraint_values = np.concatenate([self._constraint_values,
193+
[constraint_value]])
194+
169195
def probe(self, params):
170196
"""
171197
Evaulates a single point x, to obtain the value y and then records them
@@ -188,12 +214,19 @@ def probe(self, params):
188214
x = self._as_array(params)
189215

190216
try:
191-
target = self._cache[_hashable(x)]
217+
return self._cache[_hashable(x)]
192218
except KeyError:
193219
params = dict(zip(self._keys, x))
194220
target = self.target_func(**params)
195-
self.register(x, target)
196-
return target
221+
222+
if self._constraint is None:
223+
self.register(x, target)
224+
return target
225+
else:
226+
constraint_value = self._constraint.eval(**params)
227+
self.register(x, target, constraint_value)
228+
return target, constraint_value
229+
197230

198231
def random_sample(self):
199232
"""
@@ -218,26 +251,71 @@ def random_sample(self):
218251
return data.ravel()
219252

220253
def max(self):
221-
"""Get maximum target value found and corresponding parameters."""
222-
try:
223-
res = {
224-
'target': self.target.max(),
225-
'params': dict(
226-
zip(self.keys, self.params[self.target.argmax()])
227-
)
228-
}
229-
except ValueError:
230-
res = {}
231-
return res
254+
"""Get maximum target value found and corresponding parameters.
255+
256+
If there is a constraint present, the maximum value that fulfills the
257+
constraint is returned."""
258+
if self._constraint is None:
259+
try:
260+
res = {
261+
'target': self.target.max(),
262+
'params': dict(
263+
zip(self.keys, self.params[self.target.argmax()])
264+
)
265+
}
266+
except ValueError:
267+
res = {}
268+
return res
269+
else:
270+
allowed = self._constraint.allowed(self._constraint_values)
271+
if allowed.any():
272+
# Getting of all points that fulfill the constraints, find the
273+
# one with the maximum value for the target function.
274+
sorted = np.argsort(self.target)
275+
idx = sorted[allowed[sorted]][-1]
276+
# there must be a better way to do this, right?
277+
res = {
278+
'target': self.target[idx],
279+
'params': dict(
280+
zip(self.keys, self.params[idx])
281+
),
282+
'constraint': self._constraint_values[idx]
283+
}
284+
else:
285+
res = {
286+
'target': None,
287+
'params': None,
288+
'constraint': None
289+
}
290+
return res
232291

233292
def res(self):
234-
"""Get all target values found and corresponding parametes."""
235-
params = [dict(zip(self.keys, p)) for p in self.params]
293+
"""Get all target values and constraint fulfillment for all parameters.
294+
"""
295+
if self._constraint is None:
296+
params = [dict(zip(self.keys, p)) for p in self.params]
236297

237-
return [
238-
{"target": target, "params": param}
239-
for target, param in zip(self.target, params)
240-
]
298+
return [
299+
{"target": target, "params": param}
300+
for target, param in zip(self.target, params)
301+
]
302+
else:
303+
params = [dict(zip(self.keys, p)) for p in self.params]
304+
305+
return [
306+
{
307+
"target": target,
308+
"constraint": constraint_value,
309+
"params": param,
310+
"allowed": allowed
311+
}
312+
for target, constraint_value, param, allowed in zip(
313+
self.target,
314+
self._constraint_values,
315+
params,
316+
self._constraint.allowed(self._constraint_values)
317+
)
318+
]
241319

242320
def set_bounds(self, new_bounds):
243321
"""
@@ -251,101 +329,3 @@ def set_bounds(self, new_bounds):
251329
for row, key in enumerate(self.keys):
252330
if key in new_bounds:
253331
self._bounds[row] = new_bounds[key]
254-
255-
256-
class ConstrainedTargetSpace(TargetSpace):
257-
"""
258-
Expands TargetSpace to incorporate constraints.
259-
"""
260-
def __init__(self,
261-
target_func,
262-
constraint: ConstraintModel,
263-
pbounds,
264-
random_state=None):
265-
super().__init__(target_func, pbounds, random_state)
266-
267-
self._constraint = constraint
268-
269-
# preallocated memory for constraint fulfillement
270-
if constraint.lb.size == 1:
271-
self._constraint_values = np.empty(shape=(0), dtype=float)
272-
else:
273-
self._constraint_values = np.empty(shape=(0, constraint.lb.size), dtype=float)
274-
275-
@property
276-
def constraint(self):
277-
return self._constraint
278-
279-
@property
280-
def constraint_values(self):
281-
return self._constraint_values
282-
283-
def register(self, params, target, constraint_value):
284-
x = self._as_array(params)
285-
if x in self:
286-
raise KeyError('Data point {} is not unique'.format(x))
287-
288-
# Insert data into unique dictionary
289-
self._cache[_hashable(x.ravel())] = (target, constraint_value)
290-
291-
self._params = np.concatenate([self._params, x.reshape(1, -1)])
292-
self._target = np.concatenate([self._target, [target]])
293-
self._constraint_values = np.concatenate([self._constraint_values,
294-
[constraint_value]])
295-
296-
def probe(self, params):
297-
x = self._as_array(params)
298-
299-
try:
300-
return self._cache[_hashable(x)]
301-
except KeyError:
302-
params = dict(zip(self._keys, x))
303-
target = self.target_func(**params)
304-
constraint_value = self._constraint.eval(**params)
305-
self.register(x, target, constraint_value)
306-
return target, constraint_value
307-
308-
def max(self):
309-
"""Get maximum target value found and corresponding parametes provided
310-
that they fulfill the constraints."""
311-
allowed = self._constraint.allowed(self._constraint_values)
312-
if allowed.any():
313-
# Getting of all points that fulfill the constraints, find the
314-
# one with the maximum value for the target function.
315-
sorted = np.argsort(self.target)
316-
idx = sorted[allowed[sorted]][-1]
317-
# there must be a better way to do this, right?
318-
res = {
319-
'target': self.target[idx],
320-
'params': dict(
321-
zip(self.keys, self.params[idx])
322-
),
323-
'constraint': self._constraint_values[idx]
324-
}
325-
else:
326-
res = {
327-
'target': None,
328-
'params': None,
329-
'constraint': None
330-
}
331-
return res
332-
333-
def res(self):
334-
"""Get all target values and constraint fulfillment for all parameters.
335-
"""
336-
params = [dict(zip(self.keys, p)) for p in self.params]
337-
338-
return [
339-
{
340-
"target": target,
341-
"constraint": constraint_value,
342-
"params": param,
343-
"allowed": allowed
344-
}
345-
for target, constraint_value, param, allowed in zip(
346-
self.target,
347-
self._constraint_values,
348-
params,
349-
self._constraint.allowed(self._constraint_values)
350-
)
351-
]

0 commit comments

Comments
 (0)