15
15
from ..utils import cache_latest
16
16
17
17
18
- def uniform_loss (interval , scale , function_values ):
18
+ def uniform_loss (interval , scale , function_values , neighbors ):
19
19
"""Loss function that samples the domain uniformly.
20
20
21
21
Works with `~adaptive.Learner1D` only.
@@ -36,7 +36,7 @@ def uniform_loss(interval, scale, function_values):
36
36
return dx
37
37
38
38
39
- def default_loss (interval , scale , function_values ):
39
+ def default_loss (interval , scale , function_values , neighbors ):
40
40
"""Calculate loss on a single interval.
41
41
42
42
Currently returns the rescaled length of the interval. If one of the
@@ -70,12 +70,9 @@ def _loss_of_multi_interval(xs, ys):
70
70
return sum (vol (pts [i :i + 3 ]) for i in range (N )) / N
71
71
72
72
73
- def triangle_loss (interval , neighbours , scale , function_values ):
73
+ def triangle_loss (interval , scale , function_values , neighbors ):
74
74
x_left , x_right = interval
75
- neighbour_left , neighbour_right = neighbours
76
- xs = [neighbour_left , x_left , x_right , neighbour_right ]
77
- # The neighbours could be None if we are at the boundary, in that case we
78
- # have to filter this out
75
+ xs = [neighbors [x_left ][0 ], x_left , x_right , neighbors [x_right ][1 ]]
79
76
xs = [x for x in xs if x is not None ]
80
77
81
78
if len (xs ) <= 2 :
@@ -88,9 +85,9 @@ def triangle_loss(interval, neighbours, scale, function_values):
88
85
89
86
90
87
def get_curvature_loss (area_factor = 1 , euclid_factor = 0.02 , horizontal_factor = 0.02 ):
91
- def curvature_loss (interval , neighbours , scale , function_values ):
92
- triangle_loss_ = triangle_loss (interval , neighbours , scale , function_values )
93
- default_loss_ = default_loss (interval , scale , function_values )
88
+ def curvature_loss (interval , scale , function_values , neighbors ):
89
+ triangle_loss_ = triangle_loss (interval , scale , function_values , neighbors )
90
+ default_loss_ = default_loss (interval , scale , function_values , neighbors )
94
91
dx = (interval [1 ] - interval [0 ]) / scale [0 ]
95
92
return (area_factor * (triangle_loss_ ** 0.5 )
96
93
+ euclid_factor * default_loss_
@@ -121,6 +118,15 @@ def _get_neighbors_from_list(xs):
121
118
return sortedcontainers .SortedDict (neighbors )
122
119
123
120
121
+ def _get_intervals (x , neighbors , nn_neighbors ):
122
+ nn = nn_neighbors
123
+ i = neighbors .index (x )
124
+ start = max (0 , i - nn - 1 )
125
+ end = min (len (neighbors ), i + nn + 2 )
126
+ points = neighbors .keys ()[start :end ]
127
+ return list (zip (points , points [1 :]))
128
+
129
+
124
130
class Learner1D (BaseLearner ):
125
131
"""Learns and predicts a function 'f:ℝ → ℝ^N'.
126
132
@@ -135,6 +141,10 @@ class Learner1D(BaseLearner):
135
141
A function that returns the loss for a single interval of the domain.
136
142
If not provided, then a default is used, which uses the scaled distance
137
143
in the x-y plane as the loss. See the notes for more details.
144
+ nn_neighbors : int, default: 0
145
+ The number of neighboring intervals that the loss function
146
+ takes into account. If ``loss_per_interval`` doesn't use the neighbors
147
+ at all, then it should be 0.
138
148
139
149
Attributes
140
150
----------
@@ -145,9 +155,9 @@ class Learner1D(BaseLearner):
145
155
146
156
Notes
147
157
-----
148
- `loss_per_interval` takes 3 parameters: ``interval``, ``scale``, and
149
- ``function_values ``, and returns a scalar; the loss over the interval.
150
-
158
+ `loss_per_interval` takes 4 parameters: ``interval``, ``scale``,
159
+ ``data ``, and ``neighbors``, and returns a scalar; the loss over
160
+ the interval.
151
161
interval : (float, float)
152
162
The bounds of the interval.
153
163
scale : (float, float)
@@ -156,16 +166,18 @@ class Learner1D(BaseLearner):
156
166
function_values : dict(float → float)
157
167
A map containing evaluated function values. It is guaranteed
158
168
to have values for both of the points in 'interval'.
169
+ neighbors : dict(float → (float, float))
170
+ A map containing points as keys to its neighbors as a tuple.
159
171
"""
160
172
161
- def __init__ (self , function , bounds , loss_per_interval = None , loss_depends_on_neighbours = False ):
173
+ def __init__ (self , function , bounds , loss_per_interval = None , nn_neighbors = 0 ):
162
174
self .function = function
163
- self ._loss_depends_on_neighbours = loss_depends_on_neighbours
175
+ self .nn_neighbors = nn_neighbors
164
176
165
- if loss_depends_on_neighbours :
166
- self .loss_per_interval = loss_per_interval or get_curvature_loss ()
167
- else :
177
+ if nn_neighbors == 0 :
168
178
self .loss_per_interval = loss_per_interval or default_loss
179
+ else :
180
+ self .loss_per_interval = loss_per_interval or get_curvature_loss ()
169
181
170
182
# A dict storing the loss function for each interval x_n.
171
183
self .losses = {}
@@ -230,15 +242,8 @@ def _get_loss_in_interval(self, x_left, x_right):
230
242
return 0
231
243
232
244
# we need to compute the loss for this interval
233
- interval = (x_left , x_right )
234
- if self ._loss_depends_on_neighbours :
235
- neighbour_left = self .neighbors .get (x_left , (None , None ))[0 ]
236
- neighbour_right = self .neighbors .get (x_right , (None , None ))[1 ]
237
- neighbours = neighbour_left , neighbour_right
238
- return self .loss_per_interval (interval , neighbours ,
239
- self ._scale , self .data )
240
- else :
241
- return self .loss_per_interval (interval , self ._scale , self .data )
245
+ return self .loss_per_interval (
246
+ (x_left , x_right ), self ._scale , self .data , self .neighbors )
242
247
243
248
244
249
def _update_interpolated_loss_in_interval (self , x_left , x_right ):
@@ -271,17 +276,11 @@ def _update_losses(self, x, real=True):
271
276
272
277
if real :
273
278
# We need to update all interpolated losses in the interval
274
- # (x_left, x) and (x, x_right). Since the addition of the point
275
- # 'x' could change their loss.
276
- self ._update_interpolated_loss_in_interval (x_left , x )
277
- self ._update_interpolated_loss_in_interval (x , x_right )
278
-
279
- # if the loss depends on the neighbors we should also update those losses
280
- if self ._loss_depends_on_neighbours :
281
- neighbour_left = self .neighbors .get (x_left , (None , None ))[0 ]
282
- neighbour_right = self .neighbors .get (x_right , (None , None ))[1 ]
283
- self ._update_interpolated_loss_in_interval (neighbour_left , x_left )
284
- self ._update_interpolated_loss_in_interval (x_right , neighbour_right )
279
+ # (x_left, x), (x, x_right) and the nn_neighbors nearest
280
+ # neighboring intervals. Since the addition of the
281
+ # point 'x' could change their loss.
282
+ for ival in _get_intervals (x , self .neighbors , self .nn_neighbors ):
283
+ self ._update_interpolated_loss_in_interval (* ival )
285
284
286
285
# Since 'x' is in between (x_left, x_right),
287
286
# we get rid of the interval.
@@ -427,10 +426,8 @@ def tell_many(self, xs, ys, *, force=False):
427
426
428
427
# The the losses for the "real" intervals.
429
428
self .losses = {}
430
- for x_left , x_right in intervals :
431
- self .losses [x_left , x_right ] = (
432
- self ._get_loss_in_interval (x_left , x_right )
433
- if x_right - x_left >= self ._dx_eps else 0 )
429
+ for ival in intervals :
430
+ self .losses [ival ] = self ._get_loss_in_interval (* ival )
434
431
435
432
# List with "real" intervals that have interpolated intervals inside
436
433
to_interpolate = []
0 commit comments