
Commit f68bd81

introduce 'uses_nth_neighbors' decorator and rename to 'nth_neighbors'
This abstracts the attribute 'nn_neighbors' away and makes things easier for the user: one can now just set the 'loss_per_interval' and 'nth_neighbors' will be set by default.
1 parent ae6af68 commit f68bd81
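
In code, the change described above looks roughly like this (a minimal sketch assuming the import paths used in this diff; the lambda stands in for any function of one variable):

    from adaptive import Learner1D
    from adaptive.learner.learner1D import get_curvature_loss

    # Before this commit, the neighbor count had to be passed by hand:
    #   learner = Learner1D(f, (-1, 1),
    #                       loss_per_interval=get_curvature_loss(), nn_neighbors=1)

    # After it, the loss function itself carries the information:
    loss = get_curvature_loss()        # decorated with @uses_nth_neighbors(1)
    learner = Learner1D(lambda x: x**2, (-1, 1), loss_per_interval=loss)
    assert learner.nth_neighbors == 1  # picked up from loss.nth_neighbors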

2 files changed: +86 −23 lines

adaptive/learner/learner1D.py

Lines changed: 76 additions & 13 deletions
@@ -15,6 +15,61 @@
 from ..utils import cache_latest
 
 
+def uses_nth_neighbors(n):
+    """Decorator to specify how many neighboring intervals the loss function uses.
+
+    Wraps loss functions to indicate that they expect intervals together
+    with their ``n`` nearest neighbors.
+
+    The loss function is then guaranteed to receive the data of at least the
+    ``n`` nearest neighbors (``nth_neighbors``), in a dict that tells you what
+    the neighboring points of these are, and the `~adaptive.Learner1D` will
+    make sure that the loss is updated whenever one of the ``nth_neighbors``
+    changes.
+
+    Examples
+    --------
+
+    The next function is part of the `get_curvature_loss` function.
+
+    >>> @uses_nth_neighbors(1)
+    ... def triangle_loss(interval, scale, data, neighbors):
+    ...     x_left, x_right = interval
+    ...     xs = [neighbors[x_left][0], x_left, x_right, neighbors[x_right][1]]
+    ...     # at the boundary, neighbors[<left boundary x>] is (None, <some other x>)
+    ...     xs = [x for x in xs if x is not None]
+    ...     if len(xs) <= 2:
+    ...         return (x_right - x_left) / scale[0]
+    ...
+    ...     y_scale = scale[1] or 1
+    ...     ys_scaled = [data[x] / y_scale for x in xs]
+    ...     xs_scaled = [x / scale[0] for x in xs]
+    ...     N = len(xs) - 2
+    ...     pts = [(x, y) for x, y in zip(xs_scaled, ys_scaled)]
+    ...     return sum(volume(pts[i:i+3]) for i in range(N)) / N
+
+    Or you may define a loss that favours the (local) minima of a function.
+
+    >>> @uses_nth_neighbors(1)
+    ... def local_minima_resolving_loss(interval, scale, data, neighbors):
+    ...     x_left, x_right = interval
+    ...     n_left = neighbors[x_left][0]
+    ...     n_right = neighbors[x_right][1]
+    ...     loss = (x_right - x_left) / scale[0]
+    ...
+    ...     if not ((n_left is not None and data[x_left] > data[n_left])
+    ...             or (n_right is not None and data[x_right] > data[n_right])):
+    ...         return loss * 100
+    ...
+    ...     return loss
+    """
+    def _wrapped(loss_per_interval):
+        loss_per_interval.nth_neighbors = n
+        return loss_per_interval
+    return _wrapped
+
+
+@uses_nth_neighbors(0)
 def uniform_loss(interval, scale, data, neighbors):
     """Loss function that samples the domain uniformly.
@@ -36,6 +91,7 @@ def uniform_loss(interval, scale, data, neighbors):
     return dx
 
 
+@uses_nth_neighbors(0)
 def default_loss(interval, scale, data, neighbors):
     """Calculate loss on a single interval.
@@ -70,6 +126,7 @@ def _loss_of_multi_interval(xs, ys):
     return sum(vol(pts[i:i+3]) for i in range(N)) / N
 
 
+@uses_nth_neighbors(1)
 def triangle_loss(interval, scale, data, neighbors):
     x_left, x_right = interval
     xs = [neighbors[x_left][0], x_left, x_right, neighbors[x_right][1]]
@@ -85,6 +142,7 @@ def triangle_loss(interval, scale, data, neighbors):
 
 
 def get_curvature_loss(area_factor=1, euclid_factor=0.02, horizontal_factor=0.02):
+    @uses_nth_neighbors(1)
     def curvature_loss(interval, scale, data, neighbors):
         triangle_loss_ = triangle_loss(interval, scale, data, neighbors)
         default_loss_ = default_loss(interval, scale, data, neighbors)
@@ -118,8 +176,8 @@ def _get_neighbors_from_list(xs):
     return sortedcontainers.SortedDict(neighbors)
 
 
-def _get_intervals(x, neighbors, nn_neighbors):
-    nn = nn_neighbors
+def _get_intervals(x, neighbors, nth_neighbors):
+    nn = nth_neighbors
     i = neighbors.index(x)
     start = max(0, i - nn - 1)
     end = min(len(neighbors), i + nn + 2)
@@ -141,10 +199,6 @@ class Learner1D(BaseLearner):
         A function that returns the loss for a single interval of the domain.
         If not provided, then a default is used, which uses the scaled distance
         in the x-y plane as the loss. See the notes for more details.
-    nn_neighbors : int, default: 0
-        The number of neighboring intervals that the loss function
-        takes into account. If ``loss_per_interval`` doesn't use the neighbors
-        at all, then it should be 0.
 
     Attributes
     ----------
@@ -170,16 +224,25 @@ class Learner1D(BaseLearner):
         A map containing points as keys to its neighbors as a tuple.
         At the left ``x_left`` and right ``x_right`` most boundary it has
         ``x_left: (None, float)`` and ``x_right: (float, None)``.
+
+    The `loss_per_interval` function should also have
+    an attribute `nth_neighbors` that indicates how many of the
+    intervals neighboring `interval` are used. If `loss_per_interval`
+    doesn't have such an attribute, it is assumed that it uses **no**
+    neighboring intervals. Also see the `uses_nth_neighbors` decorator.
+    **WARNING**: When modifying the `data` and `neighbors` data
+    structures directly, the learner will behave in an undefined way.
     """
 
-    def __init__(self, function, bounds, loss_per_interval=None, nn_neighbors=0):
+    def __init__(self, function, bounds, loss_per_interval=None):
         self.function = function
-        self.nn_neighbors = nn_neighbors
 
-        if nn_neighbors == 0:
-            self.loss_per_interval = loss_per_interval or default_loss
+        if hasattr(loss_per_interval, 'nth_neighbors'):
+            self.nth_neighbors = loss_per_interval.nth_neighbors
         else:
-            self.loss_per_interval = loss_per_interval or get_curvature_loss()
+            self.nth_neighbors = 0
+
+        self.loss_per_interval = loss_per_interval or default_loss
 
         # A dict storing the loss function for each interval x_n.
         self.losses = {}
@@ -278,10 +341,10 @@ def _update_losses(self, x, real=True):
 
         if real:
             # We need to update all interpolated losses in the interval
-            # (x_left, x), (x, x_right) and the nn_neighbors nearest
+            # (x_left, x), (x, x_right) and the nth_neighbors nearest
             # neighboring intervals. Since the addition of the
             # point 'x' could change their loss.
-            for ival in _get_intervals(x, self.neighbors, self.nn_neighbors):
+            for ival in _get_intervals(x, self.neighbors, self.nth_neighbors):
                 self._update_interpolated_loss_in_interval(*ival)
 
             # Since 'x' is in between (x_left, x_right),

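The `hasattr` check in `__init__` above is the entire contract: any loss function without an `nth_neighbors` attribute is treated as using no neighboring intervals. A minimal sketch of that fallback (`plain_loss` is a hypothetical loss written only for illustration):

    from adaptive import Learner1D

    def plain_loss(interval, scale, data, neighbors):
        # No 'nth_neighbors' attribute here, so the learner falls back
        # to nth_neighbors = 0 via the hasattr check in __init__.
        x_left, x_right = interval
        return (x_right - x_left) / scale[0]  # same form as uniform_loss

    learner = Learner1D(lambda x: x, (0, 1), loss_per_interval=plain_loss)
    assert learner.nth_neighbors == 0
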
adaptive/tests/test_learner1d.py

Lines changed: 10 additions & 10 deletions
@@ -347,19 +347,19 @@ def test_curvature_loss():
     def f(x):
         return np.tanh(20*x)
 
-    for n in [0, 1]:
-        learner = Learner1D(f, (-1, 1),
-                            loss_per_interval=get_curvature_loss(), nn_neighbors=n)
-        simple(learner, goal=lambda l: l.npoints > 100)
-        assert learner.npoints > 100
+    loss = get_curvature_loss()
+    assert loss.nth_neighbors == 1
+    learner = Learner1D(f, (-1, 1), loss_per_interval=loss)
+    simple(learner, goal=lambda l: l.npoints > 100)
+    assert learner.npoints > 100
 
 
 def test_curvature_loss_vectors():
     def f(x):
         return np.tanh(20*x), np.tanh(20*(x-0.4))
 
-    for n in [0, 1]:
-        learner = Learner1D(f, (-1, 1),
-                            loss_per_interval=get_curvature_loss(), nn_neighbors=n)
-        simple(learner, goal=lambda l: l.npoints > 100)
-        assert learner.npoints > 100
+    loss = get_curvature_loss()
+    assert loss.nth_neighbors == 1
+    learner = Learner1D(f, (-1, 1), loss_per_interval=loss)
+    simple(learner, goal=lambda l: l.npoints > 100)
+    assert learner.npoints > 100
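
For intuition about the renamed `_get_intervals` helper: with `nth_neighbors = 1`, an inserted point invalidates not just its own two intervals but also one extra interval on each side. A rough illustration, assuming the function returns the consecutive interval pairs over the sliced points (the slice bounds are shown in the hunk above; the return statement is cut off in this diff):

    from adaptive.learner.learner1D import (_get_intervals,
                                            _get_neighbors_from_list)

    neighbors = _get_neighbors_from_list([0.0, 0.25, 0.5, 0.75, 1.0])

    # nth_neighbors = 0: only the two intervals touching x = 0.5
    print(_get_intervals(0.5, neighbors, 0))   # [(0.25, 0.5), (0.5, 0.75)]

    # nth_neighbors = 1: widened by one interval on each side
    print(_get_intervals(0.5, neighbors, 1))
    # [(0.0, 0.25), (0.25, 0.5), (0.5, 0.75), (0.75, 1.0)]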
