12
12
import scipy .spatial
13
13
from sortedcontainers import SortedKeyList
14
14
15
- from adaptive .learner .base_learner import BaseLearner
15
+ from adaptive .learner .base_learner import BaseLearner , uses_nth_neighbors
16
16
from adaptive .notebook_integration import ensure_holoviews , ensure_plotly
17
17
from adaptive .learner .triangulation import (
18
18
Triangulation , point_in_simplex , circumsphere ,
19
19
simplex_volume_in_embedding , fast_det )
20
20
from adaptive .utils import restore , cache_latest
21
21
22
22
23
+ def to_list (inp ):
24
+ if isinstance (inp , Iterable ):
25
+ return list (inp )
26
+ return [inp ]
27
+
28
+
23
29
def volume (simplex , ys = None ):
24
30
# Notice the parameter ys is there so you can use this volume method as
25
31
# as loss function
@@ -60,6 +66,71 @@ def default_loss(simplex, ys):
60
66
return simplex_volume_in_embedding (pts )
61
67
62
68
69
+ @uses_nth_neighbors (1 )
70
+ def triangle_loss (simplex , values , neighbors , neighbor_values ):
71
+ """
72
+ Computes the average of the volumes of the simplex combined with each
73
+ neighbouring point.
74
+
75
+ Parameters
76
+ ----------
77
+ simplex : list of tuples
78
+ Each entry is one point of the simplex.
79
+ values : list of values
80
+ The function values of each of the simplex points.
81
+ neighbors : list of tuples
82
+ The neighboring points of the simplex, ordered such that simplex[0]
83
+ exacly opposes neighbors[0], etc.
84
+ neighbor_values : list of values
85
+ The function values for each of the neighboring points.
86
+
87
+ Returns
88
+ -------
89
+ loss : float
90
+ """
91
+
92
+ neighbors = [n for n in neighbors if n is not None ]
93
+ neighbor_values = [v for v in neighbor_values if v is not None ]
94
+ if len (neighbors ) == 0 :
95
+ return 0
96
+
97
+ s = [(* x , * to_list (y )) for x , y in zip (simplex , values )]
98
+ n = [(* x , * to_list (y )) for x , y in zip (neighbors , neighbor_values )]
99
+
100
+ return sum (simplex_volume_in_embedding ([* s , neighbor ])
101
+ for neighbor in n ) / len (neighbors )
102
+
103
+
104
+ def curvature_loss_function (exploration = 0.05 ):
105
+ # XXX: add doc-string!
106
+ @uses_nth_neighbors (1 )
107
+ def curvature_loss (simplex , values , neighbors , neighbor_values ):
108
+ """Compute the curvature loss of a simplex.
109
+
110
+ Parameters
111
+ ----------
112
+ simplex : list of tuples
113
+ Each entry is one point of the simplex.
114
+ values : list of values
115
+ The function values of each of the simplex points.
116
+ neighbors : list of tuples
117
+ The neighboring points of the simplex, ordered such that simplex[0]
118
+ exacly opposes neighbors[0], etc.
119
+ neighbor_values : list of values
120
+ The function values for each of the neighboring points.
121
+
122
+ Returns
123
+ -------
124
+ loss : float
125
+ """
126
+ dim = len (simplex [0 ]) # the number of coordinates
127
+ loss_input_volume = volume (simplex )
128
+
129
+ loss_curvature = triangle_loss (simplex , values , neighbors , neighbor_values )
130
+ return (loss_curvature + exploration * loss_input_volume ** ((2 + dim ) / dim )) ** (1 / (2 + dim ))
131
+ return curvature_loss
132
+
133
+
63
134
def choose_point_in_simplex (simplex , transform = None ):
64
135
"""Choose a new point in inside a simplex.
65
136
@@ -70,9 +141,10 @@ def choose_point_in_simplex(simplex, transform=None):
70
141
Parameters
71
142
----------
72
143
simplex : numpy array
73
- The coordinates of a triangle with shape (N+1, N)
144
+ The coordinates of a triangle with shape (N+1, N).
74
145
transform : N*N matrix
75
- The multiplication to apply to the simplex before choosing the new point
146
+ The multiplication to apply to the simplex before choosing
147
+ the new point.
76
148
77
149
Returns
78
150
-------
@@ -164,6 +236,17 @@ class LearnerND(BaseLearner):
164
236
def __init__ (self , func , bounds , loss_per_simplex = None ):
165
237
self ._vdim = None
166
238
self .loss_per_simplex = loss_per_simplex or default_loss
239
+
240
+ if hasattr (self .loss_per_simplex , 'nth_neighbors' ):
241
+ if self .loss_per_simplex .nth_neighbors > 1 :
242
+ raise NotImplementedError ('The provided loss function wants '
243
+ 'next-nearest neighboring simplices for the loss computation, '
244
+ 'this feature is not yet implemented, either use '
245
+ 'nth_neightbors = 0 or 1' )
246
+ self .nth_neighbors = self .loss_per_simplex .nth_neighbors
247
+ else :
248
+ self .nth_neighbors = 0
249
+
167
250
self .data = OrderedDict ()
168
251
self .pending_points = set ()
169
252
@@ -252,14 +335,15 @@ def tri(self):
252
335
253
336
try :
254
337
self ._tri = Triangulation (self .points )
255
- self ._update_losses (set (), self ._tri .simplices )
256
- return self ._tri
257
338
except ValueError :
258
339
# A ValueError is raised if we do not have enough points or
259
340
# the provided points are coplanar, so we need more points to
260
341
# create a valid triangulation
261
342
return None
262
343
344
+ self ._update_losses (set (), self ._tri .simplices )
345
+ return self ._tri
346
+
263
347
@property
264
348
def values (self ):
265
349
"""Get the values from `data` as a numpy array."""
@@ -326,10 +410,10 @@ def tell_pending(self, point, *, simplex=None):
326
410
327
411
simplex = tuple (simplex )
328
412
simplices = [self .tri .vertex_to_simplices [i ] for i in simplex ]
329
- neighbours = set .union (* simplices )
413
+ neighbors = set .union (* simplices )
330
414
# Neighbours also includes the simplex itself
331
415
332
- for simpl in neighbours :
416
+ for simpl in neighbors :
333
417
_ , to_add = self ._try_adding_pending_point_to_simplex (point , simpl )
334
418
if to_add is None :
335
419
continue
@@ -394,6 +478,7 @@ def _pop_highest_existing_simplex(self):
394
478
# find the simplex with the highest loss, we do need to check that the
395
479
# simplex hasn't been deleted yet
396
480
while len (self ._simplex_queue ):
481
+ # XXX: Need to add check that the loss is the most recent computed loss
397
482
loss , simplex , subsimplex = self ._simplex_queue .pop (0 )
398
483
if (subsimplex is None
399
484
and simplex in self .tri .simplices
@@ -449,6 +534,35 @@ def _ask(self):
449
534
450
535
return self ._ask_best_point () # O(log N)
451
536
537
+ def _compute_loss (self , simplex ):
538
+ # get the loss
539
+ vertices = self .tri .get_vertices (simplex )
540
+ values = [self .data [tuple (v )] for v in vertices ]
541
+
542
+ # scale them to a cube with sides 1
543
+ vertices = vertices @ self ._transform
544
+ values = self ._output_multiplier * np .array (values )
545
+
546
+ if self .nth_neighbors == 0 :
547
+ # compute the loss on the scaled simplex
548
+ return float (self .loss_per_simplex (vertices , values ))
549
+
550
+ # We do need the neighbors
551
+ neighbors = self .tri .get_opposing_vertices (simplex )
552
+
553
+ neighbor_points = self .tri .get_vertices (neighbors )
554
+ neighbor_values = [self .data .get (x , None ) for x in neighbor_points ]
555
+
556
+ for i , point in enumerate (neighbor_points ):
557
+ if point is not None :
558
+ neighbor_points [i ] = point @ self ._transform
559
+
560
+ for i , value in enumerate (neighbor_values ):
561
+ if value is not None :
562
+ neighbor_values [i ] = self ._output_multiplier * value
563
+
564
+ return float (self .loss_per_simplex (vertices , values , neighbor_points , neighbor_values ))
565
+
452
566
def _update_losses (self , to_delete : set , to_add : set ):
453
567
# XXX: add the points outside the triangulation to this as well
454
568
pending_points_unbound = set ()
@@ -461,7 +575,6 @@ def _update_losses(self, to_delete: set, to_add: set):
461
575
462
576
pending_points_unbound = set (p for p in pending_points_unbound
463
577
if p not in self .data )
464
-
465
578
for simplex in to_add :
466
579
loss = self ._compute_loss (simplex )
467
580
self ._losses [simplex ] = loss
@@ -476,17 +589,20 @@ def _update_losses(self, to_delete: set, to_add: set):
476
589
self ._update_subsimplex_losses (
477
590
simplex , self ._subtriangulations [simplex ].simplices )
478
591
479
- def _compute_loss (self , simplex ):
480
- # get the loss
481
- vertices = self .tri .get_vertices (simplex )
482
- values = [self .data [tuple (v )] for v in vertices ]
592
+ if self .nth_neighbors :
593
+ points_of_added_simplices = set .union (* [set (s ) for s in to_add ])
594
+ neighbors = self .tri .get_simplices_attached_to_points (
595
+ points_of_added_simplices ) - to_add
596
+ for simplex in neighbors :
597
+ loss = self ._compute_loss (simplex )
598
+ self ._losses [simplex ] = loss
483
599
484
- # scale them to a cube with sides 1
485
- vertices = vertices @ self ._transform
486
- values = self . _output_multiplier * np . array ( values )
600
+ if simplex not in self . _subtriangulations :
601
+ self ._simplex_queue . add (( loss , simplex , None ))
602
+ continue
487
603
488
- # compute the loss on the scaled simplex
489
- return float ( self . loss_per_simplex ( vertices , values ) )
604
+ self . _update_subsimplex_losses (
605
+ simplex , self . _subtriangulations [ simplex ]. simplices )
490
606
491
607
def _recompute_all_losses (self ):
492
608
"""Recompute all losses and pending losses."""
0 commit comments