Skip to content

Commit 0e15482

Browse files
jhoofwijkbasnijholt
authored andcommitted
allow passing a convex hull to LearnerND
1 parent 1ffe220 commit 0e15482

File tree

1 file changed

+31
-16
lines changed

1 file changed

+31
-16
lines changed

adaptive/learner/learnerND.py

Lines changed: 31 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -107,9 +107,10 @@ class LearnerND(BaseLearner):
107107
func: callable
108108
The function to learn. Must take a tuple of N real
109109
parameters and return a real number or an arraylike of length M.
110-
bounds : list of 2-tuples
110+
bounds : list of 2-tuples or `scipy.spatial.ConvexHull`
111111
A list ``[(a_1, b_1), (a_2, b_2), ..., (a_n, b_n)]`` containing bounds,
112112
one pair per dimension.
113+
Or a ConvexHull that defines the boundary of the domain.
113114
loss_per_simplex : callable, optional
114115
A function that returns the loss for a simplex.
115116
If not provided, then a default is used, which uses
@@ -150,14 +151,21 @@ class LearnerND(BaseLearner):
150151
"""
151152

152153
def __init__(self, func, bounds, loss_per_simplex=None):
153-
self.ndim = len(bounds)
154154
self._vdim = None
155155
self.loss_per_simplex = loss_per_simplex or default_loss
156-
self.bounds = tuple(tuple(map(float, b)) for b in bounds)
157156
self.data = OrderedDict()
158157
self.pending_points = set()
159158

160-
self._bounds_points = list(map(tuple, itertools.product(*bounds)))
159+
if isinstance(bounds, scipy.spatial.ConvexHull):
160+
hull_points = bounds.points[bounds.vertices]
161+
self._bounds_points = sorted(list(map(tuple, hull_points)))
162+
self._bbox = tuple(zip(hull_points.min(axis=0), hull_points.max(axis=0)))
163+
self._interior = scipy.spatial.Delaunay(self._bounds_points)
164+
else:
165+
self._bounds_points = sorted(list(map(tuple, itertools.product(*bounds))))
166+
self._bbox = tuple(tuple(map(float, b)) for b in bounds)
167+
168+
self.ndim = len(self._bbox)
161169

162170
self.function = func
163171
self._tri = None
@@ -169,7 +177,7 @@ def __init__(self, func, bounds, loss_per_simplex=None):
169177
self._subtriangulations = dict() # simplex → triangulation
170178

171179
# scale to unit
172-
self._transform = np.linalg.inv(np.diag(np.diff(bounds).flat))
180+
self._transform = np.linalg.inv(np.diag(np.diff(self._bbox).flat))
173181

174182
# create a private random number generator with fixed seed
175183
self._random = random.Random(1)
@@ -275,7 +283,12 @@ def _simplex_exists(self, simplex):
275283

276284
def inside_bounds(self, point):
277285
"""Check whether a point is inside the bounds."""
278-
return all(mn <= p <= mx for p, (mn, mx) in zip(point, self.bounds))
286+
if hasattr(self, '_interior'):
287+
return self._interior.find_simplex(point, tol=1e-8) >= 0
288+
else:
289+
eps = 1e-8
290+
return all((mn - eps) <= p <= (mx + eps) for p, (mn, mx)
291+
in zip(point, self._bbox))
279292

280293
def tell_pending(self, point, *, simplex=None):
281294
point = tuple(point)
@@ -349,11 +362,13 @@ def _ask_point_without_known_simplices(self):
349362
assert not self._bounds_available
350363
# pick a random point inside the bounds
351364
# XXX: change this into picking a point based on volume loss
352-
a = np.diff(self.bounds).flat
353-
b = np.array(self.bounds)[:, 0]
354-
r = np.array([self._random.random() for _ in range(self.ndim)])
355-
p = r * a + b
356-
p = tuple(p)
365+
a = np.diff(self._bbox).flat
366+
b = np.array(self._bbox)[:, 0]
367+
p = None
368+
while p is None or not self.inside_bounds(p):
369+
r = np.array([self._random.random() for _ in range(self.ndim)])
370+
p = r * a + b
371+
p = tuple(p)
357372

358373
self.tell_pending(p)
359374
return p, np.inf
@@ -489,10 +504,10 @@ def plot(self, n=None, tri_alpha=0):
489504
if self.vdim > 1:
490505
raise NotImplementedError('holoviews currently does not support',
491506
'3D surface plots in bokeh.')
492-
if len(self.bounds) != 2:
507+
if len(self.ndim) != 2:
493508
raise NotImplementedError("Only 2D plots are implemented: You can "
494509
"plot a 2D slice with 'plot_slice'.")
495-
x, y = self.bounds
510+
x, y = self._bbox
496511
lbrt = x[0], y[0], x[1], y[1]
497512

498513
if len(self.data) >= 4:
@@ -549,7 +564,7 @@ def plot_slice(self, cut_mapping, n=None):
549564
raise NotImplementedError('multidimensional output not yet'
550565
' supported by `plot_slice`')
551566
n = n or 201
552-
values = [cut_mapping.get(i, np.linspace(*self.bounds[i], n))
567+
values = [cut_mapping.get(i, np.linspace(*self._bbox[i], n))
553568
for i in range(self.ndim)]
554569
ind = next(i for i in range(self.ndim) if i not in cut_mapping)
555570
x = values[ind]
@@ -574,9 +589,9 @@ def plot_slice(self, cut_mapping, n=None):
574589
xys = [xs[:, None], ys[None, :]]
575590
values = [cut_mapping[i] if i in cut_mapping
576591
else xys.pop(0) * (b[1] - b[0]) + b[0]
577-
for i, b in enumerate(self.bounds)]
592+
for i, b in enumerate(self._bbox)]
578593

579-
lbrt = [b for i, b in enumerate(self.bounds)
594+
lbrt = [b for i, b in enumerate(self._bbox)
580595
if i not in cut_mapping]
581596
lbrt = np.reshape(lbrt, (2, 2)).T.flatten().tolist()
582597

0 commit comments

Comments
 (0)