import heapq
import itertools
import math
+ from collections.abc import Iterable

import numpy as np
import sortedcontainers

from .base_learner import BaseLearner
+ from .learnerND import volume
+ from .triangulation import simplex_volume_in_embedding
from ..notebook_integration import ensure_holoviews
from ..utils import cache_latest


- def uniform_loss(interval, scale, function_values):
+ def uses_nth_neighbors(n):
+     """Decorator to specify how many neighboring intervals the loss function uses.
+
+     Wraps loss functions to indicate that they expect intervals together
+     with ``n`` nearest neighbors.
+
+     The loss function will then receive the data of the ``n`` nearest neighbors
+     (``nth_neighbors``) along with the data of the interval itself, as the
+     tuples ``xs`` and ``ys``.
+     The `~adaptive.Learner1D` will also make sure that the loss is updated
+     whenever one of the ``nth_neighbors`` changes.
+
+     Examples
+     --------
+
+     The next function is a part of the `curvature_loss_function` function.
+
+     >>> @uses_nth_neighbors(1)
+     ... def triangle_loss(xs, ys):
+     ...     xs = [x for x in xs if x is not None]
+     ...     ys = [y for y in ys if y is not None]
+     ...
+     ...     if len(xs) == 2:  # we do not have enough points for a triangle
+     ...         return xs[1] - xs[0]
+     ...
+     ...     N = len(xs) - 2  # number of constructed triangles
+     ...     if isinstance(ys[0], Iterable):
+     ...         pts = [(x, *y) for x, y in zip(xs, ys)]
+     ...         vol = simplex_volume_in_embedding
+     ...     else:
+     ...         pts = [(x, y) for x, y in zip(xs, ys)]
+     ...         vol = volume
+     ...     return sum(vol(pts[i:i+3]) for i in range(N)) / N
+
+     Or you may define a loss that favours the (local) minima of a function,
+     assuming that you know your function will have a single float as output.
+
+     >>> @uses_nth_neighbors(1)
+     ... def local_minima_resolving_loss(xs, ys):
+     ...     dx = xs[2] - xs[1]  # the width of the interval of interest
+     ...
+     ...     if not ((ys[0] is not None and ys[0] > ys[1])
+     ...             or (ys[3] is not None and ys[3] > ys[2])):
+     ...         return dx * 100
+     ...
+     ...     return dx
+     """
+     def _wrapped(loss_per_interval):
+         loss_per_interval.nth_neighbors = n
+         return loss_per_interval
+     return _wrapped
+
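A minimal sketch (an editor's illustration, not part of this diff) of what the
decorator amounts to: it only tags the wrapped loss callable with an
``nth_neighbors`` attribute, which `Learner1D.__init__` inspects further down.

>>> @uses_nth_neighbors(1)
... def my_loss(xs, ys):  # hypothetical loss; receives 4 xs/ys when n = 1
...     return xs[2] - xs[1]
>>> my_loss.nth_neighbors
1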
+ @uses_nth_neighbors(0)
+ def uniform_loss(xs, ys):
    """Loss function that samples the domain uniformly.

    Works with `~adaptive.Learner1D` only.
@@ -27,33 +82,58 @@ def uniform_loss(interval, scale, function_values):
    ... loss_per_interval=uniform_sampling_1d)
    >>>
    """
-     x_left, x_right = interval
-     x_scale, _ = scale
-     dx = (x_right - x_left) / x_scale
+     dx = xs[1] - xs[0]
    return dx


- def default_loss(interval, scale, function_values):
+ @uses_nth_neighbors(0)
+ def default_loss(xs, ys):
    """Calculate loss on a single interval.

    Currently returns the rescaled length of the interval. If one of the
    y-values is missing, returns 0 (so the intervals with missing data are
    never touched). This behavior should be improved later.
    """
-     x_left, x_right = interval
-     y_right, y_left = function_values[x_right], function_values[x_left]
-     x_scale, y_scale = scale
-     dx = (x_right - x_left) / x_scale
-     if y_scale == 0:
-         loss = dx
+     dx = xs[1] - xs[0]
+     if isinstance(ys[0], Iterable):
+         dy = [abs(a - b) for a, b in zip(*ys)]
+         return np.hypot(dx, dy).max()
+     else:
+         dy = ys[1] - ys[0]
+         return np.hypot(dx, dy)
+
+
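Roughly how `Learner1D` invokes it (an editor's sketch, not part of the diff):
with ``nth_neighbors == 0`` the loss receives just the two interval end points,
already rescaled by the learner, so for scalar output it is the Euclidean
length of the interpolating segment.

>>> xs = (0.0, 0.5)   # scaled x-values of one interval
>>> ys = (0.0, 1.0)   # scaled function values at those points
>>> round(float(default_loss(xs, ys)), 3)
1.118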
+ @uses_nth_neighbors(1)
+ def triangle_loss(xs, ys):
+     xs = [x for x in xs if x is not None]
+     ys = [y for y in ys if y is not None]
+
+     if len(xs) == 2:  # we do not have enough points for a triangle
+         return xs[1] - xs[0]
+
+     N = len(xs) - 2  # number of constructed triangles
+     if isinstance(ys[0], Iterable):
+         pts = [(x, *y) for x, y in zip(xs, ys)]
+         vol = simplex_volume_in_embedding
    else:
-         dy = (y_right - y_left) / y_scale
-         try:
-             len(dy)
-             loss = np.hypot(dx, dy).max()
-         except TypeError:
-             loss = math.hypot(dx, dy)
-     return loss
+         pts = [(x, y) for x, y in zip(xs, ys)]
+         vol = volume
+     return sum(vol(pts[i:i+3]) for i in range(N)) / N
+
+
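A small worked example (editor's sketch, not in the diff), assuming `volume`
imported from `learnerND` returns the area of a 2-D triangle: at the edge of
the domain the missing neighbor comes in as ``None``, so only one triangle can
be built from the remaining three points.

>>> xs = (None, 0.0, 0.5, 1.0)   # no left neighbor at the domain edge
>>> ys = (None, 0.0, 1.0, 0.0)
>>> float(triangle_loss(xs, ys))  # area of the triangle (0,0), (0.5,1), (1,0)
0.5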
+ def curvature_loss_function(area_factor=1, euclid_factor=0.02, horizontal_factor=0.02):
+     @uses_nth_neighbors(1)
+     def curvature_loss(xs, ys):
+         xs_middle = xs[1:3]
+         ys_middle = ys[1:3]
+
+         triangle_loss_ = triangle_loss(xs, ys)
+         default_loss_ = default_loss(xs_middle, ys_middle)
+         dx = xs_middle[1] - xs_middle[0]
+         return (area_factor * (triangle_loss_ ** 0.5)
+                 + euclid_factor * default_loss_
+                 + horizontal_factor * dx)
+     return curvature_loss
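A hedged usage sketch (the function ``f`` and the bounds are illustrative):
the factory is called once and the resulting closure is handed to `Learner1D`,
which picks up its ``nth_neighbors`` attribute.

>>> def f(x):
...     return x**3 - x
>>> learner = Learner1D(f, bounds=(-1, 1),
...                     loss_per_interval=curvature_loss_function())
>>> learner.nth_neighbors
1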


def linspace(x_left, x_right, n):
@@ -79,6 +159,15 @@ def _get_neighbors_from_list(xs):
    return sortedcontainers.SortedDict(neighbors)


+ def _get_intervals(x, neighbors, nth_neighbors):
+     nn = nth_neighbors
+     i = neighbors.index(x)
+     start = max(0, i - nn - 1)
+     end = min(len(neighbors), i + nn + 2)
+     points = neighbors.keys()[start:end]
+     return list(zip(points, points[1:]))
+
+
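A worked illustration (editor's sketch, not part of the diff) of which
intervals are selected around a point when ``nth_neighbors == 1``: the two
intervals touching ``x`` plus one extra interval on either side.

>>> neighbors = sortedcontainers.SortedDict(
...     {x: None for x in [0.0, 1.0, 2.0, 3.0, 4.0]})  # values unused here
>>> _get_intervals(2.0, neighbors, nth_neighbors=1)
[(0.0, 1.0), (1.0, 2.0), (2.0, 3.0), (3.0, 4.0)]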
class Learner1D(BaseLearner):
    """Learns and predicts a function 'f:ℝ → ℝ^N'.

@@ -103,21 +192,34 @@ class Learner1D(BaseLearner):

    Notes
    -----
-     `loss_per_interval` takes 3 parameters: ``interval``, ``scale``, and
-     ``function_values``, and returns a scalar; the loss over the interval.
-
-     interval : (float, float)
-         The bounds of the interval.
-     scale : (float, float)
-         The x and y scale over all the intervals, useful for rescaling the
-         interval loss.
-     function_values : dict(float → float)
-         A map containing evaluated function values. It is guaranteed
-         to have values for both of the points in 'interval'.
+     `loss_per_interval` takes 2 parameters: ``xs`` and ``ys``, and returns a
+     scalar; the loss over the interval.
+     xs : tuple of floats
+         The x-values of the interval. If `nth_neighbors` is greater than zero,
+         it also contains the x-values of the neighbors of the interval, in
+         ascending order. The interval we want to know the loss of is then the
+         middle interval. If no neighbor is available (at the edges of the
+         domain) then `None` will take the place of the x-value of the neighbor.
+     ys : tuple of function values
+         The output values of the function when evaluated at the `xs`. This is
+         either a float or a tuple of floats in the case of vector output.
+
+
+     The `loss_per_interval` function may also have an attribute `nth_neighbors`
+     that indicates how many of the neighboring intervals to `interval` are used.
+     If `loss_per_interval` doesn't have such an attribute, it's assumed that it
+     uses **no** neighboring intervals. Also see the `uses_nth_neighbors`
+     decorator for more information.
    """

    def __init__(self, function, bounds, loss_per_interval=None):
        self.function = function
+
+         if hasattr(loss_per_interval, 'nth_neighbors'):
+             self.nth_neighbors = loss_per_interval.nth_neighbors
+         else:
+             self.nth_neighbors = 0
+
        self.loss_per_interval = loss_per_interval or default_loss

        # A dict storing the loss function for each interval x_n.
@@ -176,25 +278,60 @@ def loss(self, real=True):
        losses = self.losses if real else self.losses_combined
        return max(losses.values()) if len(losses) > 0 else float('inf')

+     def _scale_x(self, x):
+         if x is None:
+             return None
+         return x / self._scale[0]
+
+     def _scale_y(self, y):
+         if y is None:
+             return None
+         y_scale = self._scale[1] or 1
+         return y / y_scale
+
+     def _get_point_by_index(self, ind):
+         if ind < 0 or ind >= len(self.neighbors):
+             return None
+         return self.neighbors.keys()[ind]
+
+     def _get_loss_in_interval(self, x_left, x_right):
+         assert x_left is not None and x_right is not None
+
+         if x_right - x_left < self._dx_eps:
+             return 0
+
+         nn = self.nth_neighbors
+         i = self.neighbors.index(x_left)
+         start = i - nn
+         end = i + nn + 2
+
+         xs = [self._get_point_by_index(i) for i in range(start, end)]
+         ys = [self.data.get(x, None) for x in xs]
+
+         xs_scaled = tuple(self._scale_x(x) for x in xs)
+         ys_scaled = tuple(self._scale_y(y) for y in ys)
+
+         # we need to compute the loss for this interval
+         return self.loss_per_interval(xs_scaled, ys_scaled)
+
    def _update_interpolated_loss_in_interval(self, x_left, x_right):
-         if x_left is not None and x_right is not None:
-             dx = x_right - x_left
-             if dx < self._dx_eps:
-                 loss = 0
-             else:
-                 loss = self.loss_per_interval((x_left, x_right),
-                                               self._scale, self.data)
-             self.losses[x_left, x_right] = loss
-
-             # Iterate over all interpolated intervals in between
-             # x_left and x_right and set the newly interpolated loss.
-             a, b = x_left, None
-             while b != x_right:
-                 b = self.neighbors_combined[a][1]
-                 self.losses_combined[a, b] = (b - a) * loss / dx
-                 a = b
+         if x_left is None or x_right is None:
+             return
+
+         loss = self._get_loss_in_interval(x_left, x_right)
+         self.losses[x_left, x_right] = loss
+
+         # Iterate over all interpolated intervals in between
+         # x_left and x_right and set the newly interpolated loss.
+         a, b = x_left, None
+         dx = x_right - x_left
+         while b != x_right:
+             b = self.neighbors_combined[a][1]
+             self.losses_combined[a, b] = (b - a) * loss / dx
+             a = b

    def _update_losses(self, x, real=True):
+         """Update all losses that depend on x"""
        # When we add a new point x, we should update the losses
        # (x_left, x_right) are the "real" neighbors of 'x'.
        x_left, x_right = self._find_neighbors(x, self.neighbors)
@@ -207,10 +344,11 @@ def _update_losses(self, x, real=True):

        if real:
            # We need to update all interpolated losses in the interval
-             # (x_left, x) and (x, x_right). Since the addition of the point
-             # 'x' could change their loss.
-             self._update_interpolated_loss_in_interval(x_left, x)
-             self._update_interpolated_loss_in_interval(x, x_right)
+             # (x_left, x), (x, x_right) and the nth_neighbors nearest
+             # neighboring intervals. Since the addition of the
+             # point 'x' could change their loss.
+             for ival in _get_intervals(x, self.neighbors, self.nth_neighbors):
+                 self._update_interpolated_loss_in_interval(*ival)

        # Since 'x' is in between (x_left, x_right),
        # we get rid of the interval.
@@ -284,6 +422,9 @@ def tell(self, x, y):
        if x in self.data:
            # The point is already evaluated before
            return
+         if y is None:
+             raise TypeError("Y-value may not be None, use learner.tell_pending(x) "
+                             "to indicate that this value is currently being calculated")

        # either it is a float/int, if not, try casting to a np.array
        if not isinstance(y, (float, int)):
@@ -356,10 +497,8 @@ def tell_many(self, xs, ys, *, force=False):

        # The losses for the "real" intervals.
        self.losses = {}
-         for x_left, x_right in intervals:
-             self.losses[x_left, x_right] = (
-                 self.loss_per_interval((x_left, x_right), self._scale, self.data)
-                 if x_right - x_left >= self._dx_eps else 0)
+         for ival in intervals:
+             self.losses[ival] = self._get_loss_in_interval(*ival)

        # List with "real" intervals that have interpolated intervals inside
        to_interpolate = []