
Commit 7408ed5

Merge branch '119-add-second-order-loss-to-adaptive' into 'master'
Resolve "(Learner1D) add possibility to use the direct neighbors in the loss"

Closes #119

See merge request qt/adaptive!131
2 parents ccba17d + fde1774 commit 7408ed5

8 files changed: +310, -70 lines

adaptive/learner/learner1D.py

Lines changed: 192 additions & 53 deletions
@@ -3,16 +3,71 @@
 import heapq
 import itertools
 import math
+from collections import Iterable
 
 import numpy as np
 import sortedcontainers
 
 from .base_learner import BaseLearner
+from .learnerND import volume
+from .triangulation import simplex_volume_in_embedding
 from ..notebook_integration import ensure_holoviews
 from ..utils import cache_latest
 
 
-def uniform_loss(interval, scale, function_values):
+def uses_nth_neighbors(n):
+    """Decorator to specify how many neighboring intervals the loss function uses.
+
+    Wraps loss functions to indicate that they expect intervals together
+    with ``n`` nearest neighbors.
+
+    The loss function will then receive the data of the N nearest neighbors
+    (``nth_neighbors``) along with the data of the interval itself in a dict.
+    The `~adaptive.Learner1D` will also make sure that the loss is updated
+    whenever one of the ``nth_neighbors`` changes.
+
+    Examples
+    --------
+
+    The next function is a part of the `curvature_loss_function` function.
+
+    >>> @uses_nth_neighbors(1)
+    ... def triangle_loss(xs, ys):
+    ...     xs = [x for x in xs if x is not None]
+    ...     ys = [y for y in ys if y is not None]
+    ...
+    ...     if len(xs) == 2:  # we do not have enough points for a triangle
+    ...         return xs[1] - xs[0]
+    ...
+    ...     N = len(xs) - 2  # number of constructed triangles
+    ...     if isinstance(ys[0], Iterable):
+    ...         pts = [(x, *y) for x, y in zip(xs, ys)]
+    ...         vol = simplex_volume_in_embedding
+    ...     else:
+    ...         pts = [(x, y) for x, y in zip(xs, ys)]
+    ...         vol = volume
+    ...     return sum(vol(pts[i:i+3]) for i in range(N)) / N
+
+    Or you may define a loss that favours the (local) minima of a function,
+    assuming that you know your function will have a single float as output.
+
+    >>> @uses_nth_neighbors(1)
+    ... def local_minima_resolving_loss(xs, ys):
+    ...     dx = xs[2] - xs[1]  # the width of the interval of interest
+    ...
+    ...     if not ((ys[0] is not None and ys[0] > ys[1])
+    ...             or (ys[3] is not None and ys[3] > ys[2])):
+    ...         return dx * 100
+    ...
+    ...     return dx
+    """
+    def _wrapped(loss_per_interval):
+        loss_per_interval.nth_neighbors = n
+        return loss_per_interval
+    return _wrapped
+
+
+@uses_nth_neighbors(0)
+def uniform_loss(xs, ys):
     """Loss function that samples the domain uniformly.
 
     Works with `~adaptive.Learner1D` only.
@@ -27,33 +82,58 @@ def uniform_loss(interval, scale, function_values):
     ... loss_per_interval=uniform_sampling_1d)
     >>>
     """
-    x_left, x_right = interval
-    x_scale, _ = scale
-    dx = (x_right - x_left) / x_scale
+    dx = xs[1] - xs[0]
     return dx
 
 
-def default_loss(interval, scale, function_values):
+@uses_nth_neighbors(0)
+def default_loss(xs, ys):
     """Calculate loss on a single interval.
 
     Currently returns the rescaled length of the interval. If one of the
     y-values is missing, returns 0 (so the intervals with missing data are
     never touched). This behavior should be improved later.
     """
-    x_left, x_right = interval
-    y_right, y_left = function_values[x_right], function_values[x_left]
-    x_scale, y_scale = scale
-    dx = (x_right - x_left) / x_scale
-    if y_scale == 0:
-        loss = dx
+    dx = xs[1] - xs[0]
+    if isinstance(ys[0], Iterable):
+        dy = [abs(a - b) for a, b in zip(*ys)]
+        return np.hypot(dx, dy).max()
+    else:
+        dy = ys[1] - ys[0]
+        return np.hypot(dx, dy)
+
+
+@uses_nth_neighbors(1)
+def triangle_loss(xs, ys):
+    xs = [x for x in xs if x is not None]
+    ys = [y for y in ys if y is not None]
+
+    if len(xs) == 2:  # we do not have enough points for a triangle
+        return xs[1] - xs[0]
+
+    N = len(xs) - 2  # number of constructed triangles
+    if isinstance(ys[0], Iterable):
+        pts = [(x, *y) for x, y in zip(xs, ys)]
+        vol = simplex_volume_in_embedding
     else:
-        dy = (y_right - y_left) / y_scale
-        try:
-            len(dy)
-            loss = np.hypot(dx, dy).max()
-        except TypeError:
-            loss = math.hypot(dx, dy)
-    return loss
+        pts = [(x, y) for x, y in zip(xs, ys)]
+        vol = volume
+    return sum(vol(pts[i:i+3]) for i in range(N)) / N
+
+
+def curvature_loss_function(area_factor=1, euclid_factor=0.02, horizontal_factor=0.02):
+    @uses_nth_neighbors(1)
+    def curvature_loss(xs, ys):
+        xs_middle = xs[1:3]
+        ys_middle = ys[1:3]
+
+        triangle_loss_ = triangle_loss(xs, ys)
+        default_loss_ = default_loss(xs_middle, ys_middle)
+        dx = xs_middle[1] - xs_middle[0]
+        return (area_factor * (triangle_loss_**0.5)
+                + euclid_factor * default_loss_
+                + horizontal_factor * dx)
+    return curvature_loss
 
 
 def linspace(x_left, x_right, n):
@@ -79,6 +159,15 @@ def _get_neighbors_from_list(xs):
     return sortedcontainers.SortedDict(neighbors)
 
 
+def _get_intervals(x, neighbors, nth_neighbors):
+    nn = nth_neighbors
+    i = neighbors.index(x)
+    start = max(0, i - nn - 1)
+    end = min(len(neighbors), i + nn + 2)
+    points = neighbors.keys()[start:end]
+    return list(zip(points, points[1:]))
+
+
 class Learner1D(BaseLearner):
     """Learns and predicts a function 'f:ℝ → ℝ^N'.
 
@@ -103,21 +192,34 @@ class Learner1D(BaseLearner):
 
     Notes
     -----
-    `loss_per_interval` takes 3 parameters: ``interval``, ``scale``, and
-    ``function_values``, and returns a scalar; the loss over the interval.
-
-    interval : (float, float)
-        The bounds of the interval.
-    scale : (float, float)
-        The x and y scale over all the intervals, useful for rescaling the
-        interval loss.
-    function_values : dict(float → float)
-        A map containing evaluated function values. It is guaranteed
-        to have values for both of the points in 'interval'.
+    `loss_per_interval` takes 2 parameters: ``xs`` and ``ys``, and returns a
+    scalar; the loss over the interval.
+    xs : tuple of floats
+        The x values of the interval, if `nth_neighbors` is greater than zero it
+        also contains the x-values of the neighbors of the interval, in ascending
+        order. The interval we want to know the loss of is then the middle
+        interval. If no neighbor is available (at the edges of the domain) then
+        `None` will take the place of the x-value of the neighbor.
+    ys : tuple of function values
+        The output values of the function when evaluated at the `xs`. This is
+        either a float or a tuple of floats in the case of vector output.
+
+
+    The `loss_per_interval` function may also have an attribute `nth_neighbors`
+    that indicates how many of the neighboring intervals to `interval` are used.
+    If `loss_per_interval` doesn't have such an attribute, it's assumed that it
+    uses **no** neighboring intervals. Also see the `uses_nth_neighbors`
+    decorator for more information.
     """
 
     def __init__(self, function, bounds, loss_per_interval=None):
         self.function = function
+
+        if hasattr(loss_per_interval, 'nth_neighbors'):
+            self.nth_neighbors = loss_per_interval.nth_neighbors
+        else:
+            self.nth_neighbors = 0
+
         self.loss_per_interval = loss_per_interval or default_loss
 
         # A dict storing the loss function for each interval x_n.
@@ -176,25 +278,60 @@ def loss(self, real=True):
         losses = self.losses if real else self.losses_combined
         return max(losses.values()) if len(losses) > 0 else float('inf')
 
+    def _scale_x(self, x):
+        if x is None:
+            return None
+        return x / self._scale[0]
+
+    def _scale_y(self, y):
+        if y is None:
+            return None
+        y_scale = self._scale[1] or 1
+        return y / y_scale
+
+    def _get_point_by_index(self, ind):
+        if ind < 0 or ind >= len(self.neighbors):
+            return None
+        return self.neighbors.keys()[ind]
+
+    def _get_loss_in_interval(self, x_left, x_right):
+        assert x_left is not None and x_right is not None
+
+        if x_right - x_left < self._dx_eps:
+            return 0
+
+        nn = self.nth_neighbors
+        i = self.neighbors.index(x_left)
+        start = i - nn
+        end = i + nn + 2
+
+        xs = [self._get_point_by_index(i) for i in range(start, end)]
+        ys = [self.data.get(x, None) for x in xs]
+
+        xs_scaled = tuple(self._scale_x(x) for x in xs)
+        ys_scaled = tuple(self._scale_y(y) for y in ys)
+
+        # we need to compute the loss for this interval
+        return self.loss_per_interval(xs_scaled, ys_scaled)
+
     def _update_interpolated_loss_in_interval(self, x_left, x_right):
-        if x_left is not None and x_right is not None:
-            dx = x_right - x_left
-            if dx < self._dx_eps:
-                loss = 0
-            else:
-                loss = self.loss_per_interval((x_left, x_right),
-                                              self._scale, self.data)
-            self.losses[x_left, x_right] = loss
-
-            # Iterate over all interpolated intervals in between
-            # x_left and x_right and set the newly interpolated loss.
-            a, b = x_left, None
-            while b != x_right:
-                b = self.neighbors_combined[a][1]
-                self.losses_combined[a, b] = (b - a) * loss / dx
-                a = b
+        if x_left is None or x_right is None:
+            return
+
+        loss = self._get_loss_in_interval(x_left, x_right)
+        self.losses[x_left, x_right] = loss
+
+        # Iterate over all interpolated intervals in between
+        # x_left and x_right and set the newly interpolated loss.
+        a, b = x_left, None
+        dx = x_right - x_left
+        while b != x_right:
+            b = self.neighbors_combined[a][1]
+            self.losses_combined[a, b] = (b - a) * loss / dx
+            a = b
 
     def _update_losses(self, x, real=True):
+        """Update all losses that depend on x"""
         # When we add a new point x, we should update the losses
         # (x_left, x_right) are the "real" neighbors of 'x'.
         x_left, x_right = self._find_neighbors(x, self.neighbors)
@@ -207,10 +344,11 @@ def _update_losses(self, x, real=True):
 
         if real:
             # We need to update all interpolated losses in the interval
-            # (x_left, x) and (x, x_right). Since the addition of the point
-            # 'x' could change their loss.
-            self._update_interpolated_loss_in_interval(x_left, x)
-            self._update_interpolated_loss_in_interval(x, x_right)
+            # (x_left, x), (x, x_right) and the nth_neighbors nearest
+            # neighboring intervals. Since the addition of the
+            # point 'x' could change their loss.
+            for ival in _get_intervals(x, self.neighbors, self.nth_neighbors):
+                self._update_interpolated_loss_in_interval(*ival)
 
         # Since 'x' is in between (x_left, x_right),
         # we get rid of the interval.
@@ -284,6 +422,9 @@ def tell(self, x, y):
         if x in self.data:
            # The point is already evaluated before
            return
+        if y is None:
+            raise TypeError("Y-value may not be None, use learner.tell_pending(x) "
+                            "to indicate that this value is currently being calculated")
 
         # either it is a float/int, if not, try casting to a np.array
         if not isinstance(y, (float, int)):
@@ -356,10 +497,8 @@ def tell_many(self, xs, ys, *, force=False):
 
         # Set the losses for the "real" intervals.
         self.losses = {}
-        for x_left, x_right in intervals:
-            self.losses[x_left, x_right] = (
-                self.loss_per_interval((x_left, x_right), self._scale, self.data)
-                if x_right - x_left >= self._dx_eps else 0)
+        for ival in intervals:
+            self.losses[ival] = self._get_loss_in_interval(*ival)
 
         # List with "real" intervals that have interpolated intervals inside
         to_interpolate = []
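
With these learner1D.py changes, a loss function advertises how many neighboring intervals it needs and `Learner1D` forwards the corresponding ``xs``/``ys`` tuples to it. The snippet below is a minimal usage sketch based on the API added in this commit; the sampled functions and bounds are only illustrative and are not part of the diff.

import adaptive
from adaptive.learner.learner1D import curvature_loss_function, uses_nth_neighbors

# A loss that also looks at one neighboring interval on each side.
loss = curvature_loss_function()
learner = adaptive.Learner1D(lambda x: x**3 - x, bounds=(-1, 1),
                             loss_per_interval=loss)

# A custom loss declares its neighbor count via the decorator; with 0 it only
# receives the interval's own points: xs = (x_left, x_right), ys = (y_left, y_right).
@uses_nth_neighbors(0)
def interval_width_loss(xs, ys):
    return xs[1] - xs[0]

learner2 = adaptive.Learner1D(lambda x: x**2, bounds=(-1, 1),
                              loss_per_interval=interval_width_loss)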

adaptive/learner/learnerND.py

Lines changed: 5 additions & 4 deletions
@@ -13,18 +13,19 @@
 
 from ..notebook_integration import ensure_holoviews, ensure_plotly
 from .triangulation import (Triangulation, point_in_simplex,
-                            circumsphere, simplex_volume_in_embedding)
+                            circumsphere, simplex_volume_in_embedding,
+                            fast_det)
 from ..utils import restore, cache_latest
 
 
 def volume(simplex, ys=None):
     # Notice the parameter ys is there so you can use this volume method as
     # a loss function
-    matrix = np.array(np.subtract(simplex[:-1], simplex[-1]), dtype=float)
-    dim = len(simplex) - 1
+    matrix = np.subtract(simplex[:-1], simplex[-1], dtype=float)
 
     # See https://www.jstor.org/stable/2315353
-    vol = np.abs(np.linalg.det(matrix)) / np.math.factorial(dim)
+    dim = len(simplex) - 1
+    vol = np.abs(fast_det(matrix)) / np.math.factorial(dim)
     return vol
 
 
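For reference, a quick sanity check of the refactored `volume` helper (an illustrative snippet, not part of the commit): the function computes |det(matrix)| / dim!, so a two-dimensional simplex with unit legs should come out as 0.5.

from adaptive.learner.learnerND import volume

triangle = [(0, 0), (1, 0), (0, 1)]  # 2-D simplex (a triangle) with unit legs
# matrix = [[0, -1], [1, -1]], |det(matrix)| = 1, dim! = 2! = 2
print(volume(triangle))  # prints 0.5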