8
8
9
9
import numpy as np
10
10
import sortedcontainers
11
+ import sortedcollections
11
12
12
13
from adaptive .learner .base_learner import BaseLearner
13
14
from adaptive .learner .learnerND import volume
@@ -225,9 +226,6 @@ def __init__(self, function, bounds, loss_per_interval=None):
225
226
226
227
self .loss_per_interval = loss_per_interval or default_loss
227
228
228
- # A dict storing the loss function for each interval x_n.
229
- self .losses = {}
230
- self .losses_combined = {}
231
229
232
230
# When the scale changes by a factor 2, the losses are
233
231
# recomputed. This is tunable such that we can test
@@ -249,6 +247,10 @@ def __init__(self, function, bounds, loss_per_interval=None):
249
247
self ._scale = [bounds [1 ] - bounds [0 ], 0 ]
250
248
self ._oldscale = deepcopy (self ._scale )
251
249
250
+ # A LossManager storing the loss function for each interval x_n.
251
+ self .losses = loss_manager (self ._scale [0 ])
252
+ self .losses_combined = loss_manager (self ._scale [0 ])
253
+
252
254
# The precision in 'x' below which we set losses to 0.
253
255
self ._dx_eps = 2 * max (np .abs (bounds )) * np .finfo (float ).eps
254
256
@@ -284,7 +286,10 @@ def npoints(self):
284
286
@cache_latest
285
287
def loss (self , real = True ):
286
288
losses = self .losses if real else self .losses_combined
287
- return max (losses .values ()) if len (losses ) > 0 else float ('inf' )
289
+ if not losses :
290
+ return np .inf
291
+ max_interval , max_loss = losses .peekitem (0 )
292
+ return max_loss
288
293
289
294
def _scale_x (self , x ):
290
295
if x is None :
@@ -454,8 +459,7 @@ def tell(self, x, y):
454
459
455
460
# If the scale has increased enough, recompute all losses.
456
461
if self ._scale [1 ] > self ._recompute_losses_factor * self ._oldscale [1 ]:
457
-
458
- for interval in self .losses :
462
+ for interval in reversed (self .losses ):
459
463
self ._update_interpolated_loss_in_interval (* interval )
460
464
461
465
self ._oldscale = deepcopy (self ._scale )
@@ -504,18 +508,18 @@ def tell_many(self, xs, ys, *, force=False):
504
508
for neighbors in (self .neighbors , self .neighbors_combined )]
505
509
506
510
# The the losses for the "real" intervals.
507
- self .losses = {}
511
+ self .losses = loss_manager ( self . _scale [ 0 ])
508
512
for ival in intervals :
509
513
self .losses [ival ] = self ._get_loss_in_interval (* ival )
510
514
511
515
# List with "real" intervals that have interpolated intervals inside
512
516
to_interpolate = []
513
517
514
- self .losses_combined = {}
518
+ self .losses_combined = loss_manager ( self . _scale [ 0 ])
515
519
for ival in intervals_combined :
516
520
# If this interval exists in 'losses' then copy it otherwise
517
521
# calculate it.
518
- if ival in self .losses :
522
+ if ival in reversed ( self .losses ) :
519
523
self .losses_combined [ival ] = self .losses [ival ]
520
524
else :
521
525
# Set all losses to inf now, later they might be udpdated if the
@@ -530,7 +534,7 @@ def tell_many(self, xs, ys, *, force=False):
530
534
to_interpolate .append ((x_left , x_right ))
531
535
532
536
for ival in to_interpolate :
533
- if ival in self .losses :
537
+ if ival in reversed ( self .losses ) :
534
538
# If this interval does not exist it should already
535
539
# have an inf loss.
536
540
self ._update_interpolated_loss_in_interval (* ival )
@@ -566,64 +570,57 @@ def _ask_points_without_adding(self, n):
566
570
if len (missing_bounds ) >= n :
567
571
return missing_bounds [:n ], [np .inf ] * n
568
572
569
- def finite_loss (loss , xs ):
570
- # If the loss is infinite we return the
571
- # distance between the two points.
572
- if math .isinf (loss ):
573
- loss = (xs [1 ] - xs [0 ]) / self ._scale [0 ]
574
-
575
- # We round the loss to 12 digits such that losses
576
- # are equal up to numerical precision will be considered
577
- # equal.
578
- return round (loss , ndigits = 12 )
579
-
580
- quals = [(- finite_loss (loss , x ), x , 1 )
581
- for x , loss in self .losses_combined .items ()]
582
-
583
573
# Add bound intervals to quals if bounds were missing.
584
574
if len (self .data ) + len (self .pending_points ) == 0 :
585
575
# We don't have any points, so return a linspace with 'n' points.
586
576
return np .linspace (* self .bounds , n ).tolist (), [np .inf ] * n
587
- elif len (missing_bounds ) > 0 :
577
+
578
+ quals = loss_manager (self ._scale [0 ])
579
+ if len (missing_bounds ) > 0 :
588
580
# There is at least one point in between the bounds.
589
581
all_points = list (self .data .keys ()) + list (self .pending_points )
590
582
intervals = [(self .bounds [0 ], min (all_points )),
591
583
(max (all_points ), self .bounds [1 ])]
592
584
for interval , bound in zip (intervals , self .bounds ):
593
585
if bound in missing_bounds :
594
- qual = (- finite_loss (np .inf , interval ), interval , 1 )
595
- quals .append (qual )
596
-
597
- # Calculate how many points belong to each interval.
598
- points , loss_improvements = self ._subdivide_quals (
599
- quals , n - len (missing_bounds ))
600
-
601
- points = missing_bounds + points
602
- loss_improvements = [np .inf ] * len (missing_bounds ) + loss_improvements
586
+ quals [(* interval , 1 )] = np .inf
603
587
604
- return points , loss_improvements
588
+ points_to_go = n - len ( missing_bounds )
605
589
606
- def _subdivide_quals (self , quals , n ):
607
590
# Calculate how many points belong to each interval.
608
- heapq .heapify (quals )
609
-
610
- for _ in range (n ):
611
- quality , x , n = quals [0 ]
612
- if abs (x [1 ] - x [0 ]) / (n + 1 ) <= self ._dx_eps :
613
- # The interval is too small and should not be subdivided.
614
- quality = np .inf
615
- # XXX: see https://gitlab.kwant-project.org/qt/adaptive/issues/104
616
- heapq .heapreplace (quals , (quality * n / (n + 1 ), x , n + 1 ))
591
+ i , i_max = 0 , len (self .losses_combined )
592
+ for _ in range (points_to_go ):
593
+ qual , loss_qual = quals .peekitem (0 ) if quals else (None , 0 )
594
+ ival , loss_ival = self .losses_combined .peekitem (i ) if i < i_max else (None , 0 )
595
+
596
+ if (qual is None
597
+ or (ival is not None
598
+ and self ._loss (self .losses_combined , ival )
599
+ >= self ._loss (quals , qual ))):
600
+ i += 1
601
+ quals [(* ival , 2 )] = loss_ival / 2
602
+ else :
603
+ quals .pop (qual , None )
604
+ * xs , n = qual
605
+ quals [(* xs , n + 1 )] = loss_qual * n / (n + 1 )
617
606
618
607
points = list (itertools .chain .from_iterable (
619
- linspace (* interval , n ) for quality , interval , n in quals ))
608
+ linspace (* ival , n ) for ( * ival , n ) in quals ))
620
609
621
610
loss_improvements = list (itertools .chain .from_iterable (
622
- itertools .repeat (- quality , n - 1 )
623
- for quality , interval , n in quals ))
611
+ itertools .repeat (quals [x0 , x1 , n ], n - 1 )
612
+ for (x0 , x1 , n ) in quals ))
613
+
614
+ # add the missing bounds
615
+ points = missing_bounds + points
616
+ loss_improvements = [np .inf ] * len (missing_bounds ) + loss_improvements
624
617
625
618
return points , loss_improvements
626
619
620
+ def _loss (self , mapping , ival ):
621
+ loss = mapping [ival ]
622
+ return finite_loss (ival , loss , self ._scale [0 ])
623
+
627
624
def plot (self ):
628
625
"""Returns a plot of the evaluated data.
629
626
@@ -658,3 +655,42 @@ def _get_data(self):
658
655
659
656
def _set_data (self , data ):
660
657
self .tell_many (* zip (* data .items ()))
658
+
659
+
660
+ def _fix_deepcopy (sorted_dict , x_scale ):
661
+ # XXX: until https://github.com/grantjenks/sortedcollections/issues/5 is fixed
662
+ import types
663
+ def __deepcopy__ (self , memo ):
664
+ items = deepcopy (list (self .items ()))
665
+ lm = loss_manager (self .x_scale )
666
+ lm .update (items )
667
+ return lm
668
+ sorted_dict .x_scale = x_scale
669
+ sorted_dict .__deepcopy__ = types .MethodType (__deepcopy__ , sorted_dict )
670
+
671
+
672
+ def loss_manager (x_scale ):
673
+ def sort_key (ival , loss ):
674
+ loss , ival = finite_loss (ival , loss , x_scale )
675
+ return - loss , ival
676
+ sorted_dict = sortedcollections .ItemSortedDict (sort_key )
677
+ _fix_deepcopy (sorted_dict , x_scale )
678
+ return sorted_dict
679
+
680
+
681
+ def finite_loss (ival , loss , x_scale ):
682
+ """Get the socalled finite_loss of an interval in order to be able to
683
+ sort intervals that have infinite loss."""
684
+ # If the loss is infinite we return the
685
+ # distance between the two points.
686
+ if math .isinf (loss ):
687
+ loss = (ival [1 ] - ival [0 ]) / x_scale
688
+ if len (ival ) == 3 :
689
+ # Used when constructing quals. Last item is
690
+ # the number of points inside the qual.
691
+ loss /= ival [2 ]
692
+
693
+ # We round the loss to 12 digits such that losses
694
+ # are equal up to numerical precision will be considered
695
+ # equal.
696
+ return round (loss , ndigits = 12 ), ival
0 commit comments