3
3
import heapq
4
4
import itertools
5
5
import math
6
+ from collections import Iterable
6
7
7
8
import numpy as np
8
9
import sortedcontainers
9
10
10
11
from .base_learner import BaseLearner
12
+ from .learnerND import volume
13
+ from .triangulation import simplex_volume_in_embedding
11
14
from ..notebook_integration import ensure_holoviews
12
15
from ..utils import cache_latest
13
16
@@ -56,6 +59,45 @@ def default_loss(interval, scale, function_values):
56
59
return loss
57
60
58
61
62
+ def _loss_of_multi_interval (xs , ys ):
63
+ N = len (xs ) - 2
64
+ if isinstance (ys [0 ], Iterable ):
65
+ pts = [(x , * y ) for x , y in zip (xs , ys )]
66
+ vol = simplex_volume_in_embedding
67
+ else :
68
+ pts = [(x , y ) for x , y in zip (xs , ys )]
69
+ vol = volume
70
+ return sum (vol (pts [i :i + 3 ]) for i in range (N )) / N
71
+
72
+
73
+ def triangle_loss (interval , neighbours , scale , function_values ):
74
+ x_left , x_right = interval
75
+ neighbour_left , neighbour_right = neighbours
76
+ xs = [neighbour_left , x_left , x_right , neighbour_right ]
77
+ # The neighbours could be None if we are at the boundary, in that case we
78
+ # have to filter this out
79
+ xs = [x for x in xs if x is not None ]
80
+
81
+ if len (xs ) <= 2 :
82
+ return (x_right - x_left ) / scale [0 ]
83
+ else :
84
+ y_scale = scale [1 ] or 1
85
+ ys_scaled = [function_values [x ] / y_scale for x in xs ]
86
+ xs_scaled = [x / scale [0 ] for x in xs ]
87
+ return _loss_of_multi_interval (xs_scaled , ys_scaled )
88
+
89
+
90
+ def get_curvature_loss (area_factor = 1 , euclid_factor = 0.02 , horizontal_factor = 0.02 ):
91
+ def curvature_loss (interval , neighbours , scale , function_values ):
92
+ triangle_loss_ = triangle_loss (interval , neighbours , scale , function_values )
93
+ default_loss_ = default_loss (interval , scale , function_values )
94
+ dx = (interval [1 ] - interval [0 ]) / scale [0 ]
95
+ return (area_factor * (triangle_loss_ ** 0.5 )
96
+ + euclid_factor * default_loss_
97
+ + horizontal_factor * dx )
98
+ return curvature_loss
99
+
100
+
59
101
def linspace (x_left , x_right , n ):
60
102
"""This is equivalent to
61
103
'np.linspace(x_left, x_right, n, endpoint=False)[1:]',
@@ -116,9 +158,14 @@ class Learner1D(BaseLearner):
116
158
to have values for both of the points in 'interval'.
117
159
"""
118
160
119
- def __init__ (self , function , bounds , loss_per_interval = None ):
161
+ def __init__ (self , function , bounds , loss_per_interval = None , loss_depends_on_neighbours = False ):
120
162
self .function = function
121
- self .loss_per_interval = loss_per_interval or default_loss
163
+ self ._loss_depends_on_neighbours = loss_depends_on_neighbours
164
+
165
+ if loss_depends_on_neighbours :
166
+ self .loss_per_interval = loss_per_interval or get_curvature_loss ()
167
+ else :
168
+ self .loss_per_interval = loss_per_interval or default_loss
122
169
123
170
# A dict storing the loss function for each interval x_n.
124
171
self .losses = {}
@@ -176,25 +223,42 @@ def loss(self, real=True):
176
223
losses = self .losses if real else self .losses_combined
177
224
return max (losses .values ()) if len (losses ) > 0 else float ('inf' )
178
225
226
+ def _get_loss_in_interval (self , x_left , x_right ):
227
+ assert x_left is not None and x_right is not None
228
+
229
+ if x_right - x_left < self ._dx_eps :
230
+ return 0
231
+
232
+ # we need to compute the loss for this interval
233
+ interval = (x_left , x_right )
234
+ if self ._loss_depends_on_neighbours :
235
+ neighbour_left = self .neighbors .get (x_left , (None , None ))[0 ]
236
+ neighbour_right = self .neighbors .get (x_right , (None , None ))[1 ]
237
+ neighbours = neighbour_left , neighbour_right
238
+ return self .loss_per_interval (interval , neighbours ,
239
+ self ._scale , self .data )
240
+ else :
241
+ return self .loss_per_interval (interval , self ._scale , self .data )
242
+
243
+
179
244
def _update_interpolated_loss_in_interval (self , x_left , x_right ):
180
- if x_left is not None and x_right is not None :
181
- dx = x_right - x_left
182
- if dx < self ._dx_eps :
183
- loss = 0
184
- else :
185
- loss = self .loss_per_interval ((x_left , x_right ),
186
- self ._scale , self .data )
187
- self .losses [x_left , x_right ] = loss
188
-
189
- # Iterate over all interpolated intervals in between
190
- # x_left and x_right and set the newly interpolated loss.
191
- a , b = x_left , None
192
- while b != x_right :
193
- b = self .neighbors_combined [a ][1 ]
194
- self .losses_combined [a , b ] = (b - a ) * loss / dx
195
- a = b
245
+ if x_left is None or x_right is None :
246
+ return
247
+
248
+ loss = self ._get_loss_in_interval (x_left , x_right )
249
+ self .losses [x_left , x_right ] = loss
250
+
251
+ # Iterate over all interpolated intervals in between
252
+ # x_left and x_right and set the newly interpolated loss.
253
+ a , b = x_left , None
254
+ dx = x_right - x_left
255
+ while b != x_right :
256
+ b = self .neighbors_combined [a ][1 ]
257
+ self .losses_combined [a , b ] = (b - a ) * loss / dx
258
+ a = b
196
259
197
260
def _update_losses (self , x , real = True ):
261
+ """Update all losses that depend on x"""
198
262
# When we add a new point x, we should update the losses
199
263
# (x_left, x_right) are the "real" neighbors of 'x'.
200
264
x_left , x_right = self ._find_neighbors (x , self .neighbors )
@@ -212,6 +276,13 @@ def _update_losses(self, x, real=True):
212
276
self ._update_interpolated_loss_in_interval (x_left , x )
213
277
self ._update_interpolated_loss_in_interval (x , x_right )
214
278
279
+ # if the loss depends on the neighbors we should also update those losses
280
+ if self ._loss_depends_on_neighbours :
281
+ neighbour_left = self .neighbors .get (x_left , (None , None ))[0 ]
282
+ neighbour_right = self .neighbors .get (x_right , (None , None ))[1 ]
283
+ self ._update_interpolated_loss_in_interval (neighbour_left , x_left )
284
+ self ._update_interpolated_loss_in_interval (x_right , neighbour_right )
285
+
215
286
# Since 'x' is in between (x_left, x_right),
216
287
# we get rid of the interval.
217
288
self .losses .pop ((x_left , x_right ), None )
@@ -358,7 +429,7 @@ def tell_many(self, xs, ys, *, force=False):
358
429
self .losses = {}
359
430
for x_left , x_right in intervals :
360
431
self .losses [x_left , x_right ] = (
361
- self .loss_per_interval (( x_left , x_right ), self . _scale , self . data )
432
+ self ._get_loss_in_interval ( x_left , x_right )
362
433
if x_right - x_left >= self ._dx_eps else 0 )
363
434
364
435
# List with "real" intervals that have interpolated intervals inside
0 commit comments