@@ -23,7 +23,7 @@ class TargetSpace(object):
2323 >>> y = space.register_point(x)
2424 >>> assert self.max_point()['max_val'] == y
2525 """
26- def __init__ (self , target_func , pbounds , random_state = None ):
26+ def __init__ (self , target_func , pbounds , constraint = None , random_state = None ):
2727 """
2828 Parameters
2929 ----------
@@ -57,6 +57,16 @@ def __init__(self, target_func, pbounds, random_state=None):
5757 # keep track of unique points we have seen so far
5858 self ._cache = {}
5959
60+
61+ self ._constraint = constraint
62+
63+ if constraint is not None :
64+ # preallocated memory for constraint fulfillement
65+ if constraint .lb .size == 1 :
66+ self ._constraint_values = np .empty (shape = (0 ), dtype = float )
67+ else :
68+ self ._constraint_values = np .empty (shape = (0 , constraint .lb .size ), dtype = float )
69+
6070 def __contains__ (self , x ):
6171 return _hashable (x ) in self ._cache
6272
@@ -87,6 +97,15 @@ def keys(self):
8797 @property
8898 def bounds (self ):
8999 return self ._bounds
100+
101+ @property
102+ def constraint (self ):
103+ return self ._constraint
104+
105+ @property
106+ def constraint_values (self ):
107+ if self ._constraint is not None :
108+ return self ._constraint_values
90109
91110 def params_to_array (self , params ):
92111 try :
@@ -123,7 +142,7 @@ def _as_array(self, x):
123142 "expected number of parameters ({})." .format (len (self .keys )))
124143 return x
125144
126- def register (self , params , target ):
145+ def register (self , params , target , constraint_value = None ):
127146 """
128147 Append a point and its target value to the known data.
129148
@@ -160,12 +179,19 @@ def register(self, params, target):
160179 if x in self :
161180 raise KeyError ('Data point {} is not unique' .format (x ))
162181
163- # Insert data into unique dictionary
164- self ._cache [_hashable (x .ravel ())] = target
165182
166183 self ._params = np .concatenate ([self ._params , x .reshape (1 , - 1 )])
167184 self ._target = np .concatenate ([self ._target , [target ]])
168185
186+ if constraint_value is None :
187+ # Insert data into unique dictionary
188+ self ._cache [_hashable (x .ravel ())] = target
189+ else :
190+ # Insert data into unique dictionary
191+ self ._cache [_hashable (x .ravel ())] = (target , constraint_value )
192+ self ._constraint_values = np .concatenate ([self ._constraint_values ,
193+ [constraint_value ]])
194+
169195 def probe (self , params ):
170196 """
171197 Evaulates a single point x, to obtain the value y and then records them
@@ -188,12 +214,19 @@ def probe(self, params):
188214 x = self ._as_array (params )
189215
190216 try :
191- target = self ._cache [_hashable (x )]
217+ return self ._cache [_hashable (x )]
192218 except KeyError :
193219 params = dict (zip (self ._keys , x ))
194220 target = self .target_func (** params )
195- self .register (x , target )
196- return target
221+
222+ if self ._constraint is None :
223+ self .register (x , target )
224+ return target
225+ else :
226+ constraint_value = self ._constraint .eval (** params )
227+ self .register (x , target , constraint_value )
228+ return target , constraint_value
229+
197230
198231 def random_sample (self ):
199232 """
@@ -218,26 +251,71 @@ def random_sample(self):
218251 return data .ravel ()
219252
220253 def max (self ):
221- """Get maximum target value found and corresponding parameters."""
222- try :
223- res = {
224- 'target' : self .target .max (),
225- 'params' : dict (
226- zip (self .keys , self .params [self .target .argmax ()])
227- )
228- }
229- except ValueError :
230- res = {}
231- return res
254+ """Get maximum target value found and corresponding parameters.
255+
256+ If there is a constraint present, the maximum value that fulfills the
257+ constraint is returned."""
258+ if self ._constraint is None :
259+ try :
260+ res = {
261+ 'target' : self .target .max (),
262+ 'params' : dict (
263+ zip (self .keys , self .params [self .target .argmax ()])
264+ )
265+ }
266+ except ValueError :
267+ res = {}
268+ return res
269+ else :
270+ allowed = self ._constraint .allowed (self ._constraint_values )
271+ if allowed .any ():
272+ # Getting of all points that fulfill the constraints, find the
273+ # one with the maximum value for the target function.
274+ sorted = np .argsort (self .target )
275+ idx = sorted [allowed [sorted ]][- 1 ]
276+ # there must be a better way to do this, right?
277+ res = {
278+ 'target' : self .target [idx ],
279+ 'params' : dict (
280+ zip (self .keys , self .params [idx ])
281+ ),
282+ 'constraint' : self ._constraint_values [idx ]
283+ }
284+ else :
285+ res = {
286+ 'target' : None ,
287+ 'params' : None ,
288+ 'constraint' : None
289+ }
290+ return res
232291
233292 def res (self ):
234- """Get all target values found and corresponding parametes."""
235- params = [dict (zip (self .keys , p )) for p in self .params ]
293+ """Get all target values and constraint fulfillment for all parameters.
294+ """
295+ if self ._constraint is None :
296+ params = [dict (zip (self .keys , p )) for p in self .params ]
236297
237- return [
238- {"target" : target , "params" : param }
239- for target , param in zip (self .target , params )
240- ]
298+ return [
299+ {"target" : target , "params" : param }
300+ for target , param in zip (self .target , params )
301+ ]
302+ else :
303+ params = [dict (zip (self .keys , p )) for p in self .params ]
304+
305+ return [
306+ {
307+ "target" : target ,
308+ "constraint" : constraint_value ,
309+ "params" : param ,
310+ "allowed" : allowed
311+ }
312+ for target , constraint_value , param , allowed in zip (
313+ self .target ,
314+ self ._constraint_values ,
315+ params ,
316+ self ._constraint .allowed (self ._constraint_values )
317+ )
318+ ]
241319
242320 def set_bounds (self , new_bounds ):
243321 """
@@ -251,101 +329,3 @@ def set_bounds(self, new_bounds):
251329 for row , key in enumerate (self .keys ):
252330 if key in new_bounds :
253331 self ._bounds [row ] = new_bounds [key ]
254-
255-
256- class ConstrainedTargetSpace (TargetSpace ):
257- """
258- Expands TargetSpace to incorporate constraints.
259- """
260- def __init__ (self ,
261- target_func ,
262- constraint : ConstraintModel ,
263- pbounds ,
264- random_state = None ):
265- super ().__init__ (target_func , pbounds , random_state )
266-
267- self ._constraint = constraint
268-
269- # preallocated memory for constraint fulfillement
270- if constraint .lb .size == 1 :
271- self ._constraint_values = np .empty (shape = (0 ), dtype = float )
272- else :
273- self ._constraint_values = np .empty (shape = (0 , constraint .lb .size ), dtype = float )
274-
275- @property
276- def constraint (self ):
277- return self ._constraint
278-
279- @property
280- def constraint_values (self ):
281- return self ._constraint_values
282-
283- def register (self , params , target , constraint_value ):
284- x = self ._as_array (params )
285- if x in self :
286- raise KeyError ('Data point {} is not unique' .format (x ))
287-
288- # Insert data into unique dictionary
289- self ._cache [_hashable (x .ravel ())] = (target , constraint_value )
290-
291- self ._params = np .concatenate ([self ._params , x .reshape (1 , - 1 )])
292- self ._target = np .concatenate ([self ._target , [target ]])
293- self ._constraint_values = np .concatenate ([self ._constraint_values ,
294- [constraint_value ]])
295-
296- def probe (self , params ):
297- x = self ._as_array (params )
298-
299- try :
300- return self ._cache [_hashable (x )]
301- except KeyError :
302- params = dict (zip (self ._keys , x ))
303- target = self .target_func (** params )
304- constraint_value = self ._constraint .eval (** params )
305- self .register (x , target , constraint_value )
306- return target , constraint_value
307-
308- def max (self ):
309- """Get maximum target value found and corresponding parametes provided
310- that they fulfill the constraints."""
311- allowed = self ._constraint .allowed (self ._constraint_values )
312- if allowed .any ():
313- # Getting of all points that fulfill the constraints, find the
314- # one with the maximum value for the target function.
315- sorted = np .argsort (self .target )
316- idx = sorted [allowed [sorted ]][- 1 ]
317- # there must be a better way to do this, right?
318- res = {
319- 'target' : self .target [idx ],
320- 'params' : dict (
321- zip (self .keys , self .params [idx ])
322- ),
323- 'constraint' : self ._constraint_values [idx ]
324- }
325- else :
326- res = {
327- 'target' : None ,
328- 'params' : None ,
329- 'constraint' : None
330- }
331- return res
332-
333- def res (self ):
334- """Get all target values and constraint fulfillment for all parameters.
335- """
336- params = [dict (zip (self .keys , p )) for p in self .params ]
337-
338- return [
339- {
340- "target" : target ,
341- "constraint" : constraint_value ,
342- "params" : param ,
343- "allowed" : allowed
344- }
345- for target , constraint_value , param , allowed in zip (
346- self .target ,
347- self ._constraint_values ,
348- params ,
349- self ._constraint .allowed (self ._constraint_values )
350- )
351- ]
0 commit comments