
Commit 01fb120

jhoofwijk authored and basnijholt committed
Add support for neighbours in loss computation in LearnerND (#185)
* add support for neighbours in loss computation in LearnerND
* make loss function only accept one single exploration factor
* redefine the call signature of the curvature loss function
* make LearnerND curvature loss work without errors
* add function to triangulation to find opposing vertices
* add some more tests for get_opposing_vertices
* make use of the new API
* fix spaces (after comma and indentation)
* remove unused import
* remove comment
* add exception if too many neighbors
* remove trailing whitespace
1 parent 1debc15 commit 01fb120

File tree

8 files changed, +377 -82 lines changed


adaptive/learner/base_learner.py

Lines changed: 52 additions & 0 deletions
@@ -7,6 +7,58 @@
 from adaptive.utils import save, load
 
 
+def uses_nth_neighbors(n):
+    """Decorator to specify how many neighboring intervals the loss function uses.
+
+    Wraps loss functions to indicate that they expect intervals together
+    with ``n`` nearest neighbors.
+
+    The loss function will then receive the data of the N nearest neighbors
+    (``nth_neighbors``) along with the data of the interval itself in a dict.
+    The `~adaptive.Learner1D` will also make sure that the loss is updated
+    whenever one of the ``nth_neighbors`` changes.
+
+    Examples
+    --------
+
+    The next function is a part of the `curvature_loss_function` function.
+
+    >>> @uses_nth_neighbors(1)
+    ... def triangle_loss(xs, ys):
+    ...     xs = [x for x in xs if x is not None]
+    ...     ys = [y for y in ys if y is not None]
+    ...
+    ...     if len(xs) == 2:  # we do not have enough points for a triangle
+    ...         return xs[1] - xs[0]
+    ...
+    ...     N = len(xs) - 2  # number of constructed triangles
+    ...     if isinstance(ys[0], Iterable):
+    ...         pts = [(x, *y) for x, y in zip(xs, ys)]
+    ...         vol = simplex_volume_in_embedding
+    ...     else:
+    ...         pts = [(x, y) for x, y in zip(xs, ys)]
+    ...         vol = volume
+    ...     return sum(vol(pts[i:i+3]) for i in range(N)) / N
+
+    Or you may define a loss that favours the (local) minima of a function,
+    assuming that you know your function will have a single float as output.
+
+    >>> @uses_nth_neighbors(1)
+    ... def local_minima_resolving_loss(xs, ys):
+    ...     dx = xs[2] - xs[1]  # the width of the interval of interest
+    ...
+    ...     if not ((ys[0] is not None and ys[0] > ys[1])
+    ...             or (ys[3] is not None and ys[3] > ys[2])):
+    ...         return dx * 100
+    ...
+    ...     return dx
+    """
+    def _wrapped(loss_per_interval):
+        loss_per_interval.nth_neighbors = n
+        return loss_per_interval
+    return _wrapped
+
+
 class BaseLearner(metaclass=abc.ABCMeta):
     """Base class for algorithms for learning a function 'f: X → Y'.
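To make the decorator's contract concrete: a minimal sketch (not part of the commit; the target function and bounds are illustrative) that passes an nth-neighbor-aware loss to `Learner1D` through the existing `loss_per_interval` keyword.

    from adaptive import Learner1D
    from adaptive.learner.base_learner import uses_nth_neighbors

    @uses_nth_neighbors(1)
    def local_minima_resolving_loss(xs, ys):
        # With n = 1 the learner passes the interval endpoints (xs[1], xs[2])
        # plus one neighbor on each side; missing neighbors arrive as None.
        dx = xs[2] - xs[1]  # the width of the interval of interest
        if not ((ys[0] is not None and ys[0] > ys[1])
                or (ys[3] is not None and ys[3] > ys[2])):
            return dx * 100  # weight such intervals 100x, as in the docstring
        return dx

    learner = Learner1D(lambda x: x**2, bounds=(-1, 1),
                        loss_per_interval=local_minima_resolving_loss)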

adaptive/learner/learner1D.py

Lines changed: 1 addition & 53 deletions
@@ -10,65 +10,13 @@
 import sortedcontainers
 import sortedcollections
 
-from adaptive.learner.base_learner import BaseLearner
+from adaptive.learner.base_learner import BaseLearner, uses_nth_neighbors
 from adaptive.learner.learnerND import volume
 from adaptive.learner.triangulation import simplex_volume_in_embedding
 from adaptive.notebook_integration import ensure_holoviews
 from adaptive.utils import cache_latest
 
 
-def uses_nth_neighbors(n):
-    """Decorator to specify how many neighboring intervals the loss function uses.
-    [... docstring and body identical to the version added to
-    adaptive/learner/base_learner.py above; moved there verbatim ...]
-    """
-    def _wrapped(loss_per_interval):
-        loss_per_interval.nth_neighbors = n
-        return loss_per_interval
-    return _wrapped
-
-
 @uses_nth_neighbors(0)
 def uniform_loss(xs, ys):
     """Loss function that samples the domain uniformly.

adaptive/learner/learnerND.py

Lines changed: 133 additions & 17 deletions
@@ -12,14 +12,20 @@
 import scipy.spatial
 from sortedcontainers import SortedKeyList
 
-from adaptive.learner.base_learner import BaseLearner
+from adaptive.learner.base_learner import BaseLearner, uses_nth_neighbors
 from adaptive.notebook_integration import ensure_holoviews, ensure_plotly
 from adaptive.learner.triangulation import (
     Triangulation, point_in_simplex, circumsphere,
     simplex_volume_in_embedding, fast_det)
 from adaptive.utils import restore, cache_latest
 
 
+def to_list(inp):
+    if isinstance(inp, Iterable):
+        return list(inp)
+    return [inp]
+
+
 def volume(simplex, ys=None):
     # Notice the parameter ys is there so you can use this volume method
     # as a loss function
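The new `to_list` helper normalizes scalar and vector function outputs to one code path in the neighbor-aware losses below; an illustrative check (not part of the commit):

    from adaptive.learner.learnerND import to_list

    to_list(3.0)         # a bare scalar is wrapped: [3.0]
    to_list((1.0, 2.0))  # an iterable becomes a plain list: [1.0, 2.0]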
@@ -60,6 +66,71 @@ def default_loss(simplex, ys):
     return simplex_volume_in_embedding(pts)
 
 
+@uses_nth_neighbors(1)
+def triangle_loss(simplex, values, neighbors, neighbor_values):
+    """
+    Computes the average of the volumes of the simplex combined with each
+    neighbouring point.
+
+    Parameters
+    ----------
+    simplex : list of tuples
+        Each entry is one point of the simplex.
+    values : list of values
+        The function values of each of the simplex points.
+    neighbors : list of tuples
+        The neighboring points of the simplex, ordered such that simplex[0]
+        exactly opposes neighbors[0], etc.
+    neighbor_values : list of values
+        The function values for each of the neighboring points.
+
+    Returns
+    -------
+    loss : float
+    """
+
+    neighbors = [n for n in neighbors if n is not None]
+    neighbor_values = [v for v in neighbor_values if v is not None]
+    if len(neighbors) == 0:
+        return 0
+
+    s = [(*x, *to_list(y)) for x, y in zip(simplex, values)]
+    n = [(*x, *to_list(y)) for x, y in zip(neighbors, neighbor_values)]
+
+    return sum(simplex_volume_in_embedding([*s, neighbor])
+               for neighbor in n) / len(neighbors)
+
+
+def curvature_loss_function(exploration=0.05):
+    # XXX: add doc-string!
+    @uses_nth_neighbors(1)
+    def curvature_loss(simplex, values, neighbors, neighbor_values):
+        """Compute the curvature loss of a simplex.
+
+        Parameters
+        ----------
+        simplex : list of tuples
+            Each entry is one point of the simplex.
+        values : list of values
+            The function values of each of the simplex points.
+        neighbors : list of tuples
+            The neighboring points of the simplex, ordered such that
+            simplex[0] exactly opposes neighbors[0], etc.
+        neighbor_values : list of values
+            The function values for each of the neighboring points.
+
+        Returns
+        -------
+        loss : float
+        """
+        dim = len(simplex[0])  # the number of coordinates
+        loss_input_volume = volume(simplex)
+
+        loss_curvature = triangle_loss(simplex, values, neighbors, neighbor_values)
+        return (loss_curvature + exploration * loss_input_volume ** ((2 + dim) / dim)) ** (1 / (2 + dim))
+    return curvature_loss
+
+
 def choose_point_in_simplex(simplex, transform=None):
     """Choose a new point inside a simplex.
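A usage sketch for the new loss factory (not part of the commit; the `ring` function is illustrative): the returned closure is passed to `LearnerND` via the `loss_per_simplex` keyword, matching the constructor change below.

    import numpy as np
    from adaptive import LearnerND
    from adaptive.learner.learnerND import curvature_loss_function

    def ring(xy):
        # a scalar function with a sharp circular ridge, where sampling by
        # curvature concentrates points along the feature
        x, y = xy
        return x + np.exp(-(x**2 + y**2 - 0.75**2)**2 / 0.2**4)

    loss = curvature_loss_function(exploration=0.05)
    learner = LearnerND(ring, bounds=[(-1, 1), (-1, 1)], loss_per_simplex=loss)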
@@ -70,9 +141,10 @@ def choose_point_in_simplex(simplex, transform=None):
     Parameters
     ----------
     simplex : numpy array
-        The coordinates of a triangle with shape (N+1, N)
+        The coordinates of a triangle with shape (N+1, N).
     transform : N*N matrix
-        The multiplication to apply to the simplex before choosing the new point
+        The multiplication to apply to the simplex before choosing
+        the new point.
 
     Returns
    -------
@@ -164,6 +236,17 @@ class LearnerND(BaseLearner):
     def __init__(self, func, bounds, loss_per_simplex=None):
         self._vdim = None
         self.loss_per_simplex = loss_per_simplex or default_loss
+
+        if hasattr(self.loss_per_simplex, 'nth_neighbors'):
+            if self.loss_per_simplex.nth_neighbors > 1:
+                raise NotImplementedError(
+                    'The provided loss function wants next-nearest '
+                    'neighboring simplices for the loss computation; '
+                    'this feature is not yet implemented, so use '
+                    'nth_neighbors = 0 or 1.')
+            self.nth_neighbors = self.loss_per_simplex.nth_neighbors
+        else:
+            self.nth_neighbors = 0
+
         self.data = OrderedDict()
         self.pending_points = set()
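The new guard in action, as a sketch (the loss name is illustrative): requesting more than nearest neighbors fails at construction time rather than deep inside the loss computation.

    from adaptive import LearnerND
    from adaptive.learner.base_learner import uses_nth_neighbors

    @uses_nth_neighbors(2)  # next-nearest neighbors are not implemented
    def too_greedy_loss(simplex, values, neighbors, neighbor_values):
        return 0.0

    # Raises NotImplementedError: only nth_neighbors = 0 or 1 is supported.
    LearnerND(lambda xy: 0.0, bounds=[(-1, 1), (-1, 1)],
              loss_per_simplex=too_greedy_loss)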

@@ -252,14 +335,15 @@ def tri(self):
 
         try:
             self._tri = Triangulation(self.points)
-            self._update_losses(set(), self._tri.simplices)
-            return self._tri
         except ValueError:
             # A ValueError is raised if we do not have enough points or
             # the provided points are coplanar, so we need more points to
             # create a valid triangulation
             return None
 
+        self._update_losses(set(), self._tri.simplices)
+        return self._tri
+
     @property
     def values(self):
         """Get the values from `data` as a numpy array."""
@@ -326,10 +410,10 @@ def tell_pending(self, point, *, simplex=None):
 
         simplex = tuple(simplex)
         simplices = [self.tri.vertex_to_simplices[i] for i in simplex]
-        neighbours = set.union(*simplices)
+        neighbors = set.union(*simplices)
         # Neighbours also includes the simplex itself
 
-        for simpl in neighbours:
+        for simpl in neighbors:
             _, to_add = self._try_adding_pending_point_to_simplex(point, simpl)
             if to_add is None:
                 continue
@@ -394,6 +478,7 @@ def _pop_highest_existing_simplex(self):
         # find the simplex with the highest loss, we do need to check that the
         # simplex hasn't been deleted yet
         while len(self._simplex_queue):
+            # XXX: Need to add check that the loss is the most recent computed loss
             loss, simplex, subsimplex = self._simplex_queue.pop(0)
             if (subsimplex is None
                 and simplex in self.tri.simplices
@@ -449,6 +534,35 @@ def _ask(self):
 
         return self._ask_best_point()  # O(log N)
 
+    def _compute_loss(self, simplex):
+        # get the loss
+        vertices = self.tri.get_vertices(simplex)
+        values = [self.data[tuple(v)] for v in vertices]
+
+        # scale them to a cube with sides 1
+        vertices = vertices @ self._transform
+        values = self._output_multiplier * np.array(values)
+
+        if self.nth_neighbors == 0:
+            # compute the loss on the scaled simplex
+            return float(self.loss_per_simplex(vertices, values))
+
+        # we do need the neighbors
+        neighbors = self.tri.get_opposing_vertices(simplex)
+
+        neighbor_points = self.tri.get_vertices(neighbors)
+        neighbor_values = [self.data.get(x, None) for x in neighbor_points]
+
+        for i, point in enumerate(neighbor_points):
+            if point is not None:
+                neighbor_points[i] = point @ self._transform
+
+        for i, value in enumerate(neighbor_values):
+            if value is not None:
+                neighbor_values[i] = self._output_multiplier * value
+
+        return float(self.loss_per_simplex(
+            vertices, values, neighbor_points, neighbor_values))
+
     def _update_losses(self, to_delete: set, to_add: set):
         # XXX: add the points outside the triangulation to this as well
         pending_points_unbound = set()
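Given the call at the end of `_compute_loss`, a user loss declared with `uses_nth_neighbors(1)` must accept four arguments and tolerate `None` entries for boundary simplices; a minimal sketch (not part of the commit) that ignores the neighbor data and falls back to plain volume:

    from adaptive.learner.base_learner import uses_nth_neighbors
    from adaptive.learner.learnerND import volume

    @uses_nth_neighbors(1)
    def volume_only_loss(simplex, values, neighbors, neighbor_values):
        # neighbors[i] opposes simplex[i]; entries are None on the boundary
        return volume(simplex)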
@@ -461,7 +575,6 @@ def _update_losses(self, to_delete: set, to_add: set):
 
         pending_points_unbound = set(p for p in pending_points_unbound
                                      if p not in self.data)
-
         for simplex in to_add:
             loss = self._compute_loss(simplex)
             self._losses[simplex] = loss
@@ -476,17 +589,20 @@ def _update_losses(self, to_delete: set, to_add: set):
             self._update_subsimplex_losses(
                 simplex, self._subtriangulations[simplex].simplices)
 
-    def _compute_loss(self, simplex):
-        # get the loss
-        vertices = self.tri.get_vertices(simplex)
-        values = [self.data[tuple(v)] for v in vertices]
+        if self.nth_neighbors:
+            points_of_added_simplices = set.union(*[set(s) for s in to_add])
+            neighbors = self.tri.get_simplices_attached_to_points(
+                points_of_added_simplices) - to_add
+            for simplex in neighbors:
+                loss = self._compute_loss(simplex)
+                self._losses[simplex] = loss
 
-        # scale them to a cube with sides 1
-        vertices = vertices @ self._transform
-        values = self._output_multiplier * np.array(values)
+                if simplex not in self._subtriangulations:
+                    self._simplex_queue.add((loss, simplex, None))
+                    continue
 
-        # compute the loss on the scaled simplex
-        return float(self.loss_per_simplex(vertices, values))
+                self._update_subsimplex_losses(
+                    simplex, self._subtriangulations[simplex].simplices)
 
     def _recompute_all_losses(self):
         """Recompute all losses and pending losses."""
