Skip to content

Commit 8816287

Browse files
committed
remove 'loss_depends_on_neighbours' and replace by 'nn_neighbors'
This makes the code work with any number of neighbors. Now the new triangle loss even works with nn_neigbors=0. I also added that we pass the neighbors in all loss_per_interval functions.
1 parent 0360e79 commit 8816287

File tree

2 files changed

+49
-48
lines changed

2 files changed

+49
-48
lines changed

adaptive/learner/learner1D.py

Lines changed: 39 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
from ..utils import cache_latest
1616

1717

18-
def uniform_loss(interval, scale, function_values):
18+
def uniform_loss(interval, scale, function_values, neighbors):
1919
"""Loss function that samples the domain uniformly.
2020
2121
Works with `~adaptive.Learner1D` only.
@@ -36,7 +36,7 @@ def uniform_loss(interval, scale, function_values):
3636
return dx
3737

3838

39-
def default_loss(interval, scale, function_values):
39+
def default_loss(interval, scale, function_values, neighbors):
4040
"""Calculate loss on a single interval.
4141
4242
Currently returns the rescaled length of the interval. If one of the
@@ -70,12 +70,9 @@ def _loss_of_multi_interval(xs, ys):
7070
return sum(vol(pts[i:i+3]) for i in range(N)) / N
7171

7272

73-
def triangle_loss(interval, neighbours, scale, function_values):
73+
def triangle_loss(interval, scale, function_values, neighbors):
7474
x_left, x_right = interval
75-
neighbour_left, neighbour_right = neighbours
76-
xs = [neighbour_left, x_left, x_right, neighbour_right]
77-
# The neighbours could be None if we are at the boundary, in that case we
78-
# have to filter this out
75+
xs = [neighbors[x_left][0], x_left, x_right, neighbors[x_right][1]]
7976
xs = [x for x in xs if x is not None]
8077

8178
if len(xs) <= 2:
@@ -88,9 +85,9 @@ def triangle_loss(interval, neighbours, scale, function_values):
8885

8986

9087
def get_curvature_loss(area_factor=1, euclid_factor=0.02, horizontal_factor=0.02):
91-
def curvature_loss(interval, neighbours, scale, function_values):
92-
triangle_loss_ = triangle_loss(interval, neighbours, scale, function_values)
93-
default_loss_ = default_loss(interval, scale, function_values)
88+
def curvature_loss(interval, scale, function_values, neighbors):
89+
triangle_loss_ = triangle_loss(interval, scale, function_values, neighbors)
90+
default_loss_ = default_loss(interval, scale, function_values, neighbors)
9491
dx = (interval[1] - interval[0]) / scale[0]
9592
return (area_factor * (triangle_loss_**0.5)
9693
+ euclid_factor * default_loss_
@@ -121,6 +118,15 @@ def _get_neighbors_from_list(xs):
121118
return sortedcontainers.SortedDict(neighbors)
122119

123120

121+
def _get_intervals(x, neighbors, nn_neighbors):
122+
nn = nn_neighbors
123+
i = neighbors.index(x)
124+
start = max(0, i - nn - 1)
125+
end = min(len(neighbors), i + nn + 2)
126+
points = neighbors.keys()[start:end]
127+
return list(zip(points, points[1:]))
128+
129+
124130
class Learner1D(BaseLearner):
125131
"""Learns and predicts a function 'f:ℝ → ℝ^N'.
126132
@@ -135,6 +141,10 @@ class Learner1D(BaseLearner):
135141
A function that returns the loss for a single interval of the domain.
136142
If not provided, then a default is used, which uses the scaled distance
137143
in the x-y plane as the loss. See the notes for more details.
144+
nn_neighbors : int, default: 0
145+
The number of neighboring intervals that the loss function
146+
takes into account. If ``loss_per_interval`` doesn't use the neighbors
147+
at all, then it should be 0.
138148
139149
Attributes
140150
----------
@@ -145,9 +155,9 @@ class Learner1D(BaseLearner):
145155
146156
Notes
147157
-----
148-
`loss_per_interval` takes 3 parameters: ``interval``, ``scale``, and
149-
``function_values``, and returns a scalar; the loss over the interval.
150-
158+
`loss_per_interval` takes 4 parameters: ``interval``, ``scale``,
159+
``data``, and ``neighbors``, and returns a scalar; the loss over
160+
the interval.
151161
interval : (float, float)
152162
The bounds of the interval.
153163
scale : (float, float)
@@ -156,16 +166,18 @@ class Learner1D(BaseLearner):
156166
function_values : dict(float → float)
157167
A map containing evaluated function values. It is guaranteed
158168
to have values for both of the points in 'interval'.
169+
neighbors : dict(float → (float, float))
170+
A map containing points as keys to its neighbors as a tuple.
159171
"""
160172

161-
def __init__(self, function, bounds, loss_per_interval=None, loss_depends_on_neighbours=False):
173+
def __init__(self, function, bounds, loss_per_interval=None, nn_neighbors=0):
162174
self.function = function
163-
self._loss_depends_on_neighbours = loss_depends_on_neighbours
175+
self.nn_neighbors = nn_neighbors
164176

165-
if loss_depends_on_neighbours:
166-
self.loss_per_interval = loss_per_interval or get_curvature_loss()
167-
else:
177+
if nn_neighbors == 0:
168178
self.loss_per_interval = loss_per_interval or default_loss
179+
else:
180+
self.loss_per_interval = loss_per_interval or get_curvature_loss()
169181

170182
# A dict storing the loss function for each interval x_n.
171183
self.losses = {}
@@ -230,15 +242,8 @@ def _get_loss_in_interval(self, x_left, x_right):
230242
return 0
231243

232244
# we need to compute the loss for this interval
233-
interval = (x_left, x_right)
234-
if self._loss_depends_on_neighbours:
235-
neighbour_left = self.neighbors.get(x_left, (None, None))[0]
236-
neighbour_right = self.neighbors.get(x_right, (None, None))[1]
237-
neighbours = neighbour_left, neighbour_right
238-
return self.loss_per_interval(interval, neighbours,
239-
self._scale, self.data)
240-
else:
241-
return self.loss_per_interval(interval, self._scale, self.data)
245+
return self.loss_per_interval(
246+
(x_left, x_right), self._scale, self.data, self.neighbors)
242247

243248

244249
def _update_interpolated_loss_in_interval(self, x_left, x_right):
@@ -271,17 +276,11 @@ def _update_losses(self, x, real=True):
271276

272277
if real:
273278
# We need to update all interpolated losses in the interval
274-
# (x_left, x) and (x, x_right). Since the addition of the point
275-
# 'x' could change their loss.
276-
self._update_interpolated_loss_in_interval(x_left, x)
277-
self._update_interpolated_loss_in_interval(x, x_right)
278-
279-
# if the loss depends on the neighbors we should also update those losses
280-
if self._loss_depends_on_neighbours:
281-
neighbour_left = self.neighbors.get(x_left, (None, None))[0]
282-
neighbour_right = self.neighbors.get(x_right, (None, None))[1]
283-
self._update_interpolated_loss_in_interval(neighbour_left, x_left)
284-
self._update_interpolated_loss_in_interval(x_right, neighbour_right)
279+
# (x_left, x), (x, x_right) and the nn_neighbors nearest
280+
# neighboring intervals. Since the addition of the
281+
# point 'x' could change their loss.
282+
for ival in _get_intervals(x, self.neighbors, self.nn_neighbors):
283+
self._update_interpolated_loss_in_interval(*ival)
285284

286285
# Since 'x' is in between (x_left, x_right),
287286
# we get rid of the interval.
@@ -427,10 +426,8 @@ def tell_many(self, xs, ys, *, force=False):
427426

428427
# The the losses for the "real" intervals.
429428
self.losses = {}
430-
for x_left, x_right in intervals:
431-
self.losses[x_left, x_right] = (
432-
self._get_loss_in_interval(x_left, x_right)
433-
if x_right - x_left >= self._dx_eps else 0)
429+
for ival in intervals:
430+
self.losses[ival] = self._get_loss_in_interval(*ival)
434431

435432
# List with "real" intervals that have interpolated intervals inside
436433
to_interpolate = []

adaptive/tests/test_learner1d.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -347,15 +347,19 @@ def test_curvature_loss():
347347
def f(x):
348348
return np.tanh(20*x)
349349

350-
learner = Learner1D(f, (-1, 1), loss_per_interval=get_curvature_loss(), loss_depends_on_neighbours=True)
351-
simple(learner, goal=lambda l: l.npoints > 100)
352-
# assert this is reached without error
350+
for n in [0, 1]:
351+
learner = Learner1D(f, (-1, 1),
352+
loss_per_interval=get_curvature_loss(), nn_neighbors=n)
353+
simple(learner, goal=lambda l: l.npoints > 100)
354+
assert learner.npoints > 100
353355

354356

355357
def test_curvature_loss_vectors():
356358
def f(x):
357359
return np.tanh(20*x), np.tanh(20*(x-0.4))
358360

359-
learner = Learner1D(f, (-1, 1), loss_per_interval=get_curvature_loss(), loss_depends_on_neighbours=True)
360-
simple(learner, goal=lambda l: l.npoints > 100)
361-
assert learner.npoints > 100
361+
for n in [0, 1]:
362+
learner = Learner1D(f, (-1, 1),
363+
loss_per_interval=get_curvature_loss(), nn_neighbors=n)
364+
simple(learner, goal=lambda l: l.npoints > 100)
365+
assert learner.npoints > 100

0 commit comments

Comments
 (0)