Skip to content

Commit 0360e79

Browse files
jhoofwijkbasnijholt
authored andcommitted
added a curvature_loss function to learner1D
1 parent ccba17d commit 0360e79

File tree

3 files changed

+114
-22
lines changed

3 files changed

+114
-22
lines changed

adaptive/learner/learner1D.py

Lines changed: 90 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,14 @@
33
import heapq
44
import itertools
55
import math
6+
from collections import Iterable
67

78
import numpy as np
89
import sortedcontainers
910

1011
from .base_learner import BaseLearner
12+
from .learnerND import volume
13+
from .triangulation import simplex_volume_in_embedding
1114
from ..notebook_integration import ensure_holoviews
1215
from ..utils import cache_latest
1316

@@ -56,6 +59,45 @@ def default_loss(interval, scale, function_values):
5659
return loss
5760

5861

62+
def _loss_of_multi_interval(xs, ys):
63+
N = len(xs) - 2
64+
if isinstance(ys[0], Iterable):
65+
pts = [(x, *y) for x, y in zip(xs, ys)]
66+
vol = simplex_volume_in_embedding
67+
else:
68+
pts = [(x, y) for x, y in zip(xs, ys)]
69+
vol = volume
70+
return sum(vol(pts[i:i+3]) for i in range(N)) / N
71+
72+
73+
def triangle_loss(interval, neighbours, scale, function_values):
74+
x_left, x_right = interval
75+
neighbour_left, neighbour_right = neighbours
76+
xs = [neighbour_left, x_left, x_right, neighbour_right]
77+
# The neighbours could be None if we are at the boundary, in that case we
78+
# have to filter this out
79+
xs = [x for x in xs if x is not None]
80+
81+
if len(xs) <= 2:
82+
return (x_right - x_left) / scale[0]
83+
else:
84+
y_scale = scale[1] or 1
85+
ys_scaled = [function_values[x] / y_scale for x in xs]
86+
xs_scaled = [x / scale[0] for x in xs]
87+
return _loss_of_multi_interval(xs_scaled, ys_scaled)
88+
89+
90+
def get_curvature_loss(area_factor=1, euclid_factor=0.02, horizontal_factor=0.02):
91+
def curvature_loss(interval, neighbours, scale, function_values):
92+
triangle_loss_ = triangle_loss(interval, neighbours, scale, function_values)
93+
default_loss_ = default_loss(interval, scale, function_values)
94+
dx = (interval[1] - interval[0]) / scale[0]
95+
return (area_factor * (triangle_loss_**0.5)
96+
+ euclid_factor * default_loss_
97+
+ horizontal_factor * dx)
98+
return curvature_loss
99+
100+
59101
def linspace(x_left, x_right, n):
60102
"""This is equivalent to
61103
'np.linspace(x_left, x_right, n, endpoint=False)[1:]',
@@ -116,9 +158,14 @@ class Learner1D(BaseLearner):
116158
to have values for both of the points in 'interval'.
117159
"""
118160

119-
def __init__(self, function, bounds, loss_per_interval=None):
161+
def __init__(self, function, bounds, loss_per_interval=None, loss_depends_on_neighbours=False):
120162
self.function = function
121-
self.loss_per_interval = loss_per_interval or default_loss
163+
self._loss_depends_on_neighbours = loss_depends_on_neighbours
164+
165+
if loss_depends_on_neighbours:
166+
self.loss_per_interval = loss_per_interval or get_curvature_loss()
167+
else:
168+
self.loss_per_interval = loss_per_interval or default_loss
122169

123170
# A dict storing the loss function for each interval x_n.
124171
self.losses = {}
@@ -176,25 +223,42 @@ def loss(self, real=True):
176223
losses = self.losses if real else self.losses_combined
177224
return max(losses.values()) if len(losses) > 0 else float('inf')
178225

226+
def _get_loss_in_interval(self, x_left, x_right):
227+
assert x_left is not None and x_right is not None
228+
229+
if x_right - x_left < self._dx_eps:
230+
return 0
231+
232+
# we need to compute the loss for this interval
233+
interval = (x_left, x_right)
234+
if self._loss_depends_on_neighbours:
235+
neighbour_left = self.neighbors.get(x_left, (None, None))[0]
236+
neighbour_right = self.neighbors.get(x_right, (None, None))[1]
237+
neighbours = neighbour_left, neighbour_right
238+
return self.loss_per_interval(interval, neighbours,
239+
self._scale, self.data)
240+
else:
241+
return self.loss_per_interval(interval, self._scale, self.data)
242+
243+
179244
def _update_interpolated_loss_in_interval(self, x_left, x_right):
180-
if x_left is not None and x_right is not None:
181-
dx = x_right - x_left
182-
if dx < self._dx_eps:
183-
loss = 0
184-
else:
185-
loss = self.loss_per_interval((x_left, x_right),
186-
self._scale, self.data)
187-
self.losses[x_left, x_right] = loss
188-
189-
# Iterate over all interpolated intervals in between
190-
# x_left and x_right and set the newly interpolated loss.
191-
a, b = x_left, None
192-
while b != x_right:
193-
b = self.neighbors_combined[a][1]
194-
self.losses_combined[a, b] = (b - a) * loss / dx
195-
a = b
245+
if x_left is None or x_right is None:
246+
return
247+
248+
loss = self._get_loss_in_interval(x_left, x_right)
249+
self.losses[x_left, x_right] = loss
250+
251+
# Iterate over all interpolated intervals in between
252+
# x_left and x_right and set the newly interpolated loss.
253+
a, b = x_left, None
254+
dx = x_right - x_left
255+
while b != x_right:
256+
b = self.neighbors_combined[a][1]
257+
self.losses_combined[a, b] = (b - a) * loss / dx
258+
a = b
196259

197260
def _update_losses(self, x, real=True):
261+
"""Update all losses that depend on x"""
198262
# When we add a new point x, we should update the losses
199263
# (x_left, x_right) are the "real" neighbors of 'x'.
200264
x_left, x_right = self._find_neighbors(x, self.neighbors)
@@ -212,6 +276,13 @@ def _update_losses(self, x, real=True):
212276
self._update_interpolated_loss_in_interval(x_left, x)
213277
self._update_interpolated_loss_in_interval(x, x_right)
214278

279+
# if the loss depends on the neighbors we should also update those losses
280+
if self._loss_depends_on_neighbours:
281+
neighbour_left = self.neighbors.get(x_left, (None, None))[0]
282+
neighbour_right = self.neighbors.get(x_right, (None, None))[1]
283+
self._update_interpolated_loss_in_interval(neighbour_left, x_left)
284+
self._update_interpolated_loss_in_interval(x_right, neighbour_right)
285+
215286
# Since 'x' is in between (x_left, x_right),
216287
# we get rid of the interval.
217288
self.losses.pop((x_left, x_right), None)
@@ -358,7 +429,7 @@ def tell_many(self, xs, ys, *, force=False):
358429
self.losses = {}
359430
for x_left, x_right in intervals:
360431
self.losses[x_left, x_right] = (
361-
self.loss_per_interval((x_left, x_right), self._scale, self.data)
432+
self._get_loss_in_interval(x_left, x_right)
362433
if x_right - x_left >= self._dx_eps else 0)
363434

364435
# List with "real" intervals that have interpolated intervals inside

adaptive/learner/triangulation.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -229,7 +229,9 @@ def simplex_volume_in_embedding(vertices) -> float:
229229
coeff = - (-2) ** (num_verts-1) * factorial(num_verts-1) ** 2
230230
vol_square = np.linalg.det(sq_dists_mat) / coeff
231231

232-
if vol_square <= 0:
232+
if vol_square < 0:
233+
if abs(vol_square) < 1e-15:
234+
return 0
233235
raise ValueError('Provided vertices do not form a simplex')
234236

235237
return np.sqrt(vol_square)

adaptive/tests/test_learner1d.py

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import numpy as np
55

66
from ..learner import Learner1D
7+
from ..learner.learner1D import get_curvature_loss
78
from ..runner import simple
89

910

@@ -120,9 +121,9 @@ def test_termination_on_discontinuities():
120121
smallest_interval = min(abs(a - b) for a, b in learner.losses.keys())
121122
assert smallest_interval >= np.finfo(float).eps
122123

123-
learner = _run_on_discontinuity(0.5E3, (-1E3, 1E3))
124+
learner = _run_on_discontinuity(0.5e3, (-1e3, 1e3))
124125
smallest_interval = min(abs(a - b) for a, b in learner.losses.keys())
125-
assert smallest_interval >= 0.5E3 * np.finfo(float).eps
126+
assert smallest_interval >= 0.5e3 * np.finfo(float).eps
126127

127128

128129
def test_order_adding_points():
@@ -340,3 +341,21 @@ def _random_run(learner, learner2, scale_doubling=True):
340341
learner2 = Learner1D(f, bounds=(-1, 1))
341342
_random_run(learner, learner2, scale_doubling=True)
342343
test_equal(learner, learner2)
344+
345+
346+
def test_curvature_loss():
347+
def f(x):
348+
return np.tanh(20*x)
349+
350+
learner = Learner1D(f, (-1, 1), loss_per_interval=get_curvature_loss(), loss_depends_on_neighbours=True)
351+
simple(learner, goal=lambda l: l.npoints > 100)
352+
# assert this is reached without error
353+
354+
355+
def test_curvature_loss_vectors():
356+
def f(x):
357+
return np.tanh(20*x), np.tanh(20*(x-0.4))
358+
359+
learner = Learner1D(f, (-1, 1), loss_per_interval=get_curvature_loss(), loss_depends_on_neighbours=True)
360+
simple(learner, goal=lambda l: l.npoints > 100)
361+
assert learner.npoints > 100

0 commit comments

Comments
 (0)