15
15
from ..utils import cache_latest
16
16
17
17
18
+ def uses_nth_neighbors (n ):
19
+ """Decorator to specify how many neighboring intervals the loss function uses.
20
+
21
+ Wraps loss functions to indicate that they expect intervals together
22
+ with ``n`` nearest neighbors
23
+
24
+ The loss function is then guaranteed to receive the data of at least the
25
+ N nearest neighbors (``nth_neighbors``) in a dict that tells you what the
26
+ neighboring points of these are. And the `~adaptive.Learner1D` will
27
+ then make sure that the loss is updated whenever one of the
28
+ ``nth_neighbors`` changes.
29
+
30
+ Examples
31
+ --------
32
+
33
+ The next function is a part of the `get_curvature_loss` function.
34
+
35
+ >>> @uses_nth_neighbors(1)
36
+ ... def triangle_loss(interval, scale, data, neighbors):
37
+ ... x_left, x_right = interval
38
+ ... xs = [neighbors[x_left][0], x_left, x_right, neighbors[x_right][1]]
39
+ ... # at the boundary, neighbors[<left boundary x>] is (None, <some other x>)
40
+ ... xs = [x for x in xs if x is not None]
41
+ ... if len(xs) <= 2:
42
+ ... return (x_right - x_left) / scale[0]
43
+ ...
44
+ ... y_scale = scale[1] or 1
45
+ ... ys_scaled = [data[x] / y_scale for x in xs]
46
+ ... xs_scaled = [x / scale[0] for x in xs]
47
+ ... N = len(xs) - 2
48
+ ... pts = [(x, y) for x, y in zip(xs_scaled, ys_scaled)]
49
+ ... return sum(volume(pts[i:i+3]) for i in range(N)) / N
50
+
51
+ Or you may define a loss that favours the (local) minima of a function.
52
+
53
+ >>> @uses_nth_neighbors(1)
54
+ ... def local_minima_resolving_loss(interval, scale, data, neighbors):
55
+ ... x_left, x_right = interval
56
+ ... n_left = neighbors[x_left][0]
57
+ ... n_right = neighbors[x_right][1]
58
+ ... loss = (x_right - x_left) / scale[0]
59
+ ...
60
+ ... if not ((n_left is not None and data[x_left] > data[n_left])
61
+ ... or (n_right is not None and data[x_right] > data[n_right])):
62
+ ... return loss * 100
63
+ ...
64
+ ... return loss
65
+ """
66
+ def _wrapped (loss_per_interval ):
67
+ loss_per_interval .nth_neighbors = n
68
+ return loss_per_interval
69
+ return _wrapped
70
+
71
+
72
+ @uses_nth_neighbors (0 )
18
73
def uniform_loss (interval , scale , data , neighbors ):
19
74
"""Loss function that samples the domain uniformly.
20
75
@@ -36,6 +91,7 @@ def uniform_loss(interval, scale, data, neighbors):
36
91
return dx
37
92
38
93
94
+ @uses_nth_neighbors (0 )
39
95
def default_loss (interval , scale , data , neighbors ):
40
96
"""Calculate loss on a single interval.
41
97
@@ -70,6 +126,7 @@ def _loss_of_multi_interval(xs, ys):
70
126
return sum (vol (pts [i :i + 3 ]) for i in range (N )) / N
71
127
72
128
129
+ @uses_nth_neighbors (1 )
73
130
def triangle_loss (interval , scale , data , neighbors ):
74
131
x_left , x_right = interval
75
132
xs = [neighbors [x_left ][0 ], x_left , x_right , neighbors [x_right ][1 ]]
@@ -85,6 +142,7 @@ def triangle_loss(interval, scale, data, neighbors):
85
142
86
143
87
144
def get_curvature_loss (area_factor = 1 , euclid_factor = 0.02 , horizontal_factor = 0.02 ):
145
+ @uses_nth_neighbors (1 )
88
146
def curvature_loss (interval , scale , data , neighbors ):
89
147
triangle_loss_ = triangle_loss (interval , scale , data , neighbors )
90
148
default_loss_ = default_loss (interval , scale , data , neighbors )
@@ -118,8 +176,8 @@ def _get_neighbors_from_list(xs):
118
176
return sortedcontainers .SortedDict (neighbors )
119
177
120
178
121
- def _get_intervals (x , neighbors , nn_neighbors ):
122
- nn = nn_neighbors
179
+ def _get_intervals (x , neighbors , nth_neighbors ):
180
+ nn = nth_neighbors
123
181
i = neighbors .index (x )
124
182
start = max (0 , i - nn - 1 )
125
183
end = min (len (neighbors ), i + nn + 2 )
@@ -141,10 +199,6 @@ class Learner1D(BaseLearner):
141
199
A function that returns the loss for a single interval of the domain.
142
200
If not provided, then a default is used, which uses the scaled distance
143
201
in the x-y plane as the loss. See the notes for more details.
144
- nn_neighbors : int, default: 0
145
- The number of neighboring intervals that the loss function
146
- takes into account. If ``loss_per_interval`` doesn't use the neighbors
147
- at all, then it should be 0.
148
202
149
203
Attributes
150
204
----------
@@ -170,16 +224,25 @@ class Learner1D(BaseLearner):
170
224
A map containing points as keys to its neighbors as a tuple.
171
225
At the left ``x_left`` and right ``x_left`` most boundary it has
172
226
``x_left: (None, float)`` and ``x_right: (float, None)``.
227
+
228
+ The `loss_per_interval` function should also have
229
+ an attribute `nth_neighbors` that indicates how many of the neighboring
230
+ intervals to `interval` are used. If `loss_per_interval` doesn't
231
+ have such an attribute, it's assumed that is uses **no** neighboring
232
+ intervals. Also see the `uses_nth_neighbors` decorator.
233
+ **WARNING**: When modifying the `data` and `neighbors` datastructures
234
+ the learner will behave in an undefined way.
173
235
"""
174
236
175
- def __init__ (self , function , bounds , loss_per_interval = None , nn_neighbors = 0 ):
237
+ def __init__ (self , function , bounds , loss_per_interval = None ):
176
238
self .function = function
177
- self .nn_neighbors = nn_neighbors
178
239
179
- if nn_neighbors == 0 :
180
- self .loss_per_interval = loss_per_interval or default_loss
240
+ if hasattr ( loss_per_interval , 'nth_neighbors' ) :
241
+ self .nth_neighbors = loss_per_interval . nth_neighbors
181
242
else :
182
- self .loss_per_interval = loss_per_interval or get_curvature_loss ()
243
+ self .nth_neighbors = 0
244
+
245
+ self .loss_per_interval = loss_per_interval or default_loss
183
246
184
247
# A dict storing the loss function for each interval x_n.
185
248
self .losses = {}
@@ -278,10 +341,10 @@ def _update_losses(self, x, real=True):
278
341
279
342
if real :
280
343
# We need to update all interpolated losses in the interval
281
- # (x_left, x), (x, x_right) and the nn_neighbors nearest
344
+ # (x_left, x), (x, x_right) and the nth_neighbors nearest
282
345
# neighboring intervals. Since the addition of the
283
346
# point 'x' could change their loss.
284
- for ival in _get_intervals (x , self .neighbors , self .nn_neighbors ):
347
+ for ival in _get_intervals (x , self .neighbors , self .nth_neighbors ):
285
348
self ._update_interpolated_loss_in_interval (* ival )
286
349
287
350
# Since 'x' is in between (x_left, x_right),
0 commit comments