@@ -227,10 +227,10 @@ def vdim(self):
227
227
def bounds_are_done (self ):
228
228
return all (p in self .data for p in self ._bounds_points )
229
229
230
- def ip (self ):
230
+ def _ip (self ):
231
231
"""A `scipy.interpolate.LinearNDInterpolator` instance
232
232
containing the learner's data."""
233
- # XXX: take our own triangulation into account when generating the ip
233
+ # XXX: take our own triangulation into account when generating the _ip
234
234
return interpolate .LinearNDInterpolator (self .points , self .values )
235
235
236
236
@property
@@ -242,7 +242,7 @@ def tri(self):
242
242
243
243
try :
244
244
self ._tri = Triangulation (self .points )
245
- self .update_losses (set (), self ._tri .simplices )
245
+ self ._update_losses (set (), self ._tri .simplices )
246
246
return self ._tri
247
247
except ValueError :
248
248
# A ValueError is raised if we do not have enough points or
@@ -283,7 +283,7 @@ def tell(self, point, value):
283
283
simplex = None
284
284
to_delete , to_add = tri .add_point (
285
285
point , simplex , transform = self ._transform )
286
- self .update_losses (to_delete , to_add )
286
+ self ._update_losses (to_delete , to_add )
287
287
288
288
def _simplex_exists (self , simplex ):
289
289
simplex = tuple (sorted (simplex ))
@@ -441,7 +441,7 @@ def _ask(self):
441
441
442
442
return self ._ask_best_point () # O(log N)
443
443
444
- def update_losses (self , to_delete : set , to_add : set ):
444
+ def _update_losses (self , to_delete : set , to_add : set ):
445
445
# XXX: add the points outside the triangulation to this as well
446
446
pending_points_unbound = set ()
447
447
@@ -455,7 +455,7 @@ def update_losses(self, to_delete: set, to_add: set):
455
455
if p not in self .data )
456
456
457
457
for simplex in to_add :
458
- loss = self .compute_loss (simplex )
458
+ loss = self ._compute_loss (simplex )
459
459
self ._losses [simplex ] = loss
460
460
461
461
for p in pending_points_unbound :
@@ -469,7 +469,7 @@ def update_losses(self, to_delete: set, to_add: set):
469
469
self ._update_subsimplex_losses (
470
470
simplex , self ._subtriangulations [simplex ].simplices )
471
471
472
- def compute_loss (self , simplex ):
472
+ def _compute_loss (self , simplex ):
473
473
# get the loss
474
474
vertices = self .tri .get_vertices (simplex )
475
475
values = [self .data [tuple (v )] for v in vertices ]
@@ -481,7 +481,7 @@ def compute_loss(self, simplex):
481
481
# compute the loss on the scaled simplex
482
482
return float (self .loss_per_simplex (vertices , values ))
483
483
484
- def recompute_all_losses (self ):
484
+ def _recompute_all_losses (self ):
485
485
"""Recompute all losses and pending losses."""
486
486
# amortized O(N) complexity
487
487
if self .tri is None :
@@ -492,7 +492,7 @@ def recompute_all_losses(self):
492
492
493
493
# recompute all losses
494
494
for simplex in self .tri .simplices :
495
- loss = self .compute_loss (simplex )
495
+ loss = self ._compute_loss (simplex )
496
496
self ._losses [simplex ] = loss
497
497
498
498
# now distribute it around the the children if they are present
@@ -543,27 +543,14 @@ def _update_range(self, new_output):
543
543
scale_factor = np .max (np .nan_to_num (self ._scale / self ._old_scale ))
544
544
if scale_factor > self ._recompute_losses_factor :
545
545
self ._old_scale = self ._scale
546
- self .recompute_all_losses ()
546
+ self ._recompute_all_losses ()
547
547
return True
548
548
return False
549
549
550
- def losses (self ):
551
- """Get the losses of each simplex in the current triangulation, as dict
552
-
553
- Returns
554
- -------
555
- losses : dict
556
- the key is a simplex, the value is the loss of this simplex
557
- """
558
- # XXX could be a property
559
- if self .tri is None :
560
- return dict ()
561
-
562
- return self ._losses
563
-
564
550
@cache_latest
565
551
def loss (self , real = True ):
566
- losses = self .losses () # XXX: compute pending loss if real == False
552
+ # XXX: compute pending loss if real == False
553
+ losses = self ._losses if self .tri is not None else dict ()
567
554
return max (losses .values ()) if losses else float ('inf' )
568
555
569
556
def remove_unfinished (self ):
@@ -607,7 +594,7 @@ def plot(self, n=None, tri_alpha=0):
607
594
xs = ys = np .linspace (0 , 1 , n )
608
595
xs = xs * (x [1 ] - x [0 ]) + x [0 ]
609
596
ys = ys * (y [1 ] - y [0 ]) + y [0 ]
610
- z = self .ip ()(xs [:, None ], ys [None , :]).squeeze ()
597
+ z = self ._ip ()(xs [:, None ], ys [None , :]).squeeze ()
611
598
612
599
im = hv .Image (np .rot90 (z ), bounds = lbrt )
613
600
@@ -656,7 +643,7 @@ def plot_slice(self, cut_mapping, n=None):
656
643
for i in range (self .ndim )]
657
644
ind = next (i for i in range (self .ndim ) if i not in cut_mapping )
658
645
x = values [ind ]
659
- y = self .ip ()(* values )
646
+ y = self ._ip ()(* values )
660
647
p = hv .Path ((x , y ))
661
648
662
649
# Plot with 5% margins such that the boundary points are visible
@@ -686,7 +673,7 @@ def plot_slice(self, cut_mapping, n=None):
686
673
lbrt = np .reshape (lbrt , (2 , 2 )).T .flatten ().tolist ()
687
674
688
675
if len (self .data ) >= 4 :
689
- z = self .ip ()(* values ).squeeze ()
676
+ z = self ._ip ()(* values ).squeeze ()
690
677
im = hv .Image (np .rot90 (z ), bounds = lbrt )
691
678
else :
692
679
im = hv .Image ([], bounds = lbrt )
0 commit comments