Skip to content

Commit 4381a9e

Browse files
jhoofwijkbasnijholt
authored and committed
Resolve "(Learner1D) improve time complexity"
1 parent c0012a9 commit 4381a9e

File tree

3 files changed

+86
-48
lines changed

3 files changed

+86
-48
lines changed

adaptive/learner/learner1D.py

Lines changed: 84 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
import numpy as np
1010
import sortedcontainers
11+
import sortedcollections
1112

1213
from adaptive.learner.base_learner import BaseLearner
1314
from adaptive.learner.learnerND import volume
@@ -225,9 +226,6 @@ def __init__(self, function, bounds, loss_per_interval=None):
225226

226227
self.loss_per_interval = loss_per_interval or default_loss
227228

228-
# A dict storing the loss function for each interval x_n.
229-
self.losses = {}
230-
self.losses_combined = {}
231229

232230
# When the scale changes by a factor 2, the losses are
233231
# recomputed. This is tunable such that we can test
@@ -249,6 +247,10 @@ def __init__(self, function, bounds, loss_per_interval=None):
249247
self._scale = [bounds[1] - bounds[0], 0]
250248
self._oldscale = deepcopy(self._scale)
251249

250+
# A loss manager (sorted dict) storing the loss for each interval x_n.
251+
self.losses = loss_manager(self._scale[0])
252+
self.losses_combined = loss_manager(self._scale[0])
253+
252254
# The precision in 'x' below which we set losses to 0.
253255
self._dx_eps = 2 * max(np.abs(bounds)) * np.finfo(float).eps
254256

@@ -284,7 +286,10 @@ def npoints(self):
284286
@cache_latest
285287
def loss(self, real=True):
286288
losses = self.losses if real else self.losses_combined
287-
return max(losses.values()) if len(losses) > 0 else float('inf')
289+
if not losses:
290+
return np.inf
291+
max_interval, max_loss = losses.peekitem(0)
292+
return max_loss
288293

289294
def _scale_x(self, x):
290295
if x is None:
@@ -454,8 +459,7 @@ def tell(self, x, y):
454459

455460
# If the scale has increased enough, recompute all losses.
456461
if self._scale[1] > self._recompute_losses_factor * self._oldscale[1]:
457-
458-
for interval in self.losses:
462+
for interval in reversed(self.losses):
459463
self._update_interpolated_loss_in_interval(*interval)
460464

461465
self._oldscale = deepcopy(self._scale)
@@ -504,18 +508,18 @@ def tell_many(self, xs, ys, *, force=False):
504508
for neighbors in (self.neighbors, self.neighbors_combined)]
505509

506510
# The losses for the "real" intervals.
507-
self.losses = {}
511+
self.losses = loss_manager(self._scale[0])
508512
for ival in intervals:
509513
self.losses[ival] = self._get_loss_in_interval(*ival)
510514

511515
# List with "real" intervals that have interpolated intervals inside
512516
to_interpolate = []
513517

514-
self.losses_combined = {}
518+
self.losses_combined = loss_manager(self._scale[0])
515519
for ival in intervals_combined:
516520
# If this interval exists in 'losses' then copy it otherwise
517521
# calculate it.
518-
if ival in self.losses:
522+
if ival in reversed(self.losses):
519523
self.losses_combined[ival] = self.losses[ival]
520524
else:
521525
# Set all losses to inf now, later they might be updated if the
@@ -530,7 +534,7 @@ def tell_many(self, xs, ys, *, force=False):
530534
to_interpolate.append((x_left, x_right))
531535

532536
for ival in to_interpolate:
533-
if ival in self.losses:
537+
if ival in reversed(self.losses):
534538
# If this interval does not exist it should already
535539
# have an inf loss.
536540
self._update_interpolated_loss_in_interval(*ival)
@@ -566,64 +570,57 @@ def _ask_points_without_adding(self, n):
566570
if len(missing_bounds) >= n:
567571
return missing_bounds[:n], [np.inf] * n
568572

569-
def finite_loss(loss, xs):
570-
# If the loss is infinite we return the
571-
# distance between the two points.
572-
if math.isinf(loss):
573-
loss = (xs[1] - xs[0]) / self._scale[0]
574-
575-
# We round the loss to 12 digits such that losses
576-
# are equal up to numerical precision will be considered
577-
# equal.
578-
return round(loss, ndigits=12)
579-
580-
quals = [(-finite_loss(loss, x), x, 1)
581-
for x, loss in self.losses_combined.items()]
582-
583573
# Add bound intervals to quals if bounds were missing.
584574
if len(self.data) + len(self.pending_points) == 0:
585575
# We don't have any points, so return a linspace with 'n' points.
586576
return np.linspace(*self.bounds, n).tolist(), [np.inf] * n
587-
elif len(missing_bounds) > 0:
577+
578+
quals = loss_manager(self._scale[0])
579+
if len(missing_bounds) > 0:
588580
# There is at least one point in between the bounds.
589581
all_points = list(self.data.keys()) + list(self.pending_points)
590582
intervals = [(self.bounds[0], min(all_points)),
591583
(max(all_points), self.bounds[1])]
592584
for interval, bound in zip(intervals, self.bounds):
593585
if bound in missing_bounds:
594-
qual = (-finite_loss(np.inf, interval), interval, 1)
595-
quals.append(qual)
596-
597-
# Calculate how many points belong to each interval.
598-
points, loss_improvements = self._subdivide_quals(
599-
quals, n - len(missing_bounds))
600-
601-
points = missing_bounds + points
602-
loss_improvements = [np.inf] * len(missing_bounds) + loss_improvements
586+
quals[(*interval, 1)] = np.inf
603587

604-
return points, loss_improvements
588+
points_to_go = n - len(missing_bounds)
605589

606-
def _subdivide_quals(self, quals, n):
607590
# Calculate how many points belong to each interval.
608-
heapq.heapify(quals)
609-
610-
for _ in range(n):
611-
quality, x, n = quals[0]
612-
if abs(x[1] - x[0]) / (n + 1) <= self._dx_eps:
613-
# The interval is too small and should not be subdivided.
614-
quality = np.inf
615-
# XXX: see https://gitlab.kwant-project.org/qt/adaptive/issues/104
616-
heapq.heapreplace(quals, (quality * n / (n + 1), x, n + 1))
591+
i, i_max = 0, len(self.losses_combined)
592+
for _ in range(points_to_go):
593+
qual, loss_qual = quals.peekitem(0) if quals else (None, 0)
594+
ival, loss_ival = self.losses_combined.peekitem(i) if i < i_max else (None, 0)
595+
596+
if (qual is None
597+
or (ival is not None
598+
and self._loss(self.losses_combined, ival)
599+
>= self._loss(quals, qual))):
600+
i += 1
601+
quals[(*ival, 2)] = loss_ival / 2
602+
else:
603+
quals.pop(qual, None)
604+
*xs, n = qual
605+
quals[(*xs, n+1)] = loss_qual * n / (n+1)
617606

618607
points = list(itertools.chain.from_iterable(
619-
linspace(*interval, n) for quality, interval, n in quals))
608+
linspace(*ival, n) for (*ival, n) in quals))
620609

621610
loss_improvements = list(itertools.chain.from_iterable(
622-
itertools.repeat(-quality, n - 1)
623-
for quality, interval, n in quals))
611+
itertools.repeat(quals[x0, x1, n], n - 1)
612+
for (x0, x1, n) in quals))
613+
614+
# add the missing bounds
615+
points = missing_bounds + points
616+
loss_improvements = [np.inf] * len(missing_bounds) + loss_improvements
624617

625618
return points, loss_improvements
626619

620+
def _loss(self, mapping, ival):
621+
loss = mapping[ival]
622+
return finite_loss(ival, loss, self._scale[0])
623+
627624
def plot(self):
628625
"""Returns a plot of the evaluated data.
629626
@@ -658,3 +655,42 @@ def _get_data(self):
658655

659656
def _set_data(self, data):
660657
self.tell_many(*zip(*data.items()))
658+
659+
660+
def _fix_deepcopy(sorted_dict, x_scale):
661+
# XXX: until https://github.com/grantjenks/sortedcollections/issues/5 is fixed
662+
import types
663+
def __deepcopy__(self, memo):
664+
items = deepcopy(list(self.items()))
665+
lm = loss_manager(self.x_scale)
666+
lm.update(items)
667+
return lm
668+
sorted_dict.x_scale = x_scale
669+
sorted_dict.__deepcopy__ = types.MethodType(__deepcopy__, sorted_dict)
670+
671+
672+
def loss_manager(x_scale):
    """Create an empty item-sorted dict that maps intervals to losses.

    Items are ordered by ``(-finite_loss, interval)``, where the finite
    loss comes from ``finite_loss(interval, loss, x_scale)``. Negating
    the loss places the interval with the largest loss at index 0.

    Parameters
    ----------
    x_scale : float
        Scale handed to ``finite_loss`` for every key computation and to
        ``_fix_deepcopy`` so that deep copies keep the same scale.

    Returns
    -------
    sortedcollections.ItemSortedDict
        The new container, patched by ``_fix_deepcopy``.
    """
    def descending_loss(interval, raw_loss):
        finite, interval = finite_loss(interval, raw_loss, x_scale)
        return -finite, interval

    manager = sortedcollections.ItemSortedDict(descending_loss)
    # Give the instance a working ``__deepcopy__`` (see ``_fix_deepcopy``).
    _fix_deepcopy(manager, x_scale)
    return manager
679+
680+
681+
def finite_loss(ival, loss, x_scale):
    """Return the so-called finite loss of an interval, paired with it.

    The result makes intervals with an infinite loss sortable: such a
    loss is replaced by the width of the interval relative to
    ``x_scale``.

    Parameters
    ----------
    ival : tuple
        ``(x_left, x_right)`` or ``(x_left, x_right, n)``. The
        three-item form is used when constructing quals; its last item
        is the number of points inside the qual.
    loss : float
        The loss of the interval; may be infinite.
    x_scale : float
        Scale by which the interval width is divided.

    Returns
    -------
    tuple
        ``(rounded_loss, ival)``.
    """
    # NOTE(review): the source indentation was lost; the division by
    # ``ival[2]`` is taken to apply only to the infinite-loss case.
    if not math.isinf(loss):
        # Rounding to 12 digits makes losses that are equal up to
        # numerical precision compare as equal.
        return round(loss, ndigits=12), ival

    # Infinite loss: fall back to the distance between the two points.
    substitute = (ival[1] - ival[0]) / x_scale
    if len(ival) == 3:
        # Spread the width over the number of points inside the qual.
        substitute /= ival[2]
    return round(substitute, ndigits=12), ival

environment.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ channels:
66
dependencies:
77
- python=3.6
88
- sortedcontainers
9+
- sortedcollections
910
- scipy
1011
- holoviews
1112
- ipyparallel

setup.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ def get_version_and_cmdclass(package_name):
2626

2727
install_requires = [
2828
'scipy',
29+
'sortedcollections',
2930
'sortedcontainers',
3031
]
3132

0 commit comments

Comments
 (0)