Skip to content

Commit 0207eef

Browse files
committed
Merge branch 'stable-0.7'
2 parents 7c34051 + 21a836e commit 0207eef

File tree

4 files changed

+29
-15
lines changed

4 files changed

+29
-15
lines changed

adaptive/learner/learner2D.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -435,6 +435,13 @@ def _fill_stack(self, stack_till=1):
435435
triangle = ip.tri.points[ip.tri.vertices[jsimplex]]
436436
point_new = choose_point_in_triangle(triangle, max_badness=5)
437437
point_new = tuple(self._unscale(point_new))
438+
439+
# np.clip results in numerical precision problems
440+
# https://gitlab.kwant-project.org/qt/adaptive/issues/132
441+
clip = lambda x, l, u: max(l, min(u, x))
442+
point_new = (clip(point_new[0], *self.bounds[0]),
443+
clip(point_new[1], *self.bounds[1]))
444+
438445
loss_new = losses[jsimplex]
439446

440447
points_new.append(point_new)

adaptive/learner/learnerND.py

Lines changed: 16 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import numpy as np
1010
from scipy import interpolate
1111
import scipy.spatial
12+
from sortedcontainers import SortedKeyList
1213

1314
from adaptive.learner.base_learner import BaseLearner
1415
from adaptive.notebook_integration import ensure_holoviews, ensure_plotly
@@ -91,7 +92,6 @@ def choose_point_in_simplex(simplex, transform=None):
9192
distance_matrix = scipy.spatial.distance.squareform(distances)
9293
i, j = np.unravel_index(np.argmax(distance_matrix),
9394
distance_matrix.shape)
94-
9595
point = (simplex[i, :] + simplex[j, :]) / 2
9696

9797
if transform is not None:
@@ -100,6 +100,15 @@ def choose_point_in_simplex(simplex, transform=None):
100100
return point
101101

102102

103+
def _simplex_evaluation_priority(key):
104+
# We round the loss to 8 digits such that losses
105+
# are equal up to numerical precision will be considered
106+
# to be equal. This is needed because we want the learner
107+
# to behave in a deterministic fashion.
108+
loss, simplex, subsimplex = key
109+
return -round(loss, ndigits=8), simplex, subsimplex or (0,)
110+
111+
103112
class LearnerND(BaseLearner):
104113
"""Learns and predicts a function 'f: ℝ^N → ℝ^M'.
105114
@@ -200,7 +209,7 @@ def __init__(self, func, bounds, loss_per_simplex=None):
200209
# so when popping an item, you should check that the simplex that has
201210
# been returned has not been deleted. This checking is done by
202211
# _pop_highest_existing_simplex
203-
self._simplex_queue = [] # heap
212+
self._simplex_queue = SortedKeyList(key=_simplex_evaluation_priority)
204213

205214
@property
206215
def npoints(self):
@@ -344,9 +353,7 @@ def _update_subsimplex_losses(self, simplex, new_subsimplices):
344353
subtriangulation = self._subtriangulations[simplex]
345354
for subsimplex in new_subsimplices:
346355
subloss = subtriangulation.volume(subsimplex) * loss_density
347-
subloss = round(subloss, ndigits=8)
348-
heapq.heappush(self._simplex_queue,
349-
(-subloss, simplex, subsimplex))
356+
self._simplex_queue.add((subloss, simplex, subsimplex))
350357

351358
def _ask_and_tell_pending(self, n=1):
352359
xs, losses = zip(*(self._ask() for _ in range(n)))
@@ -386,7 +393,7 @@ def _pop_highest_existing_simplex(self):
386393
# find the simplex with the highest loss, we do need to check that the
387394
# simplex hasn't been deleted yet
388395
while len(self._simplex_queue):
389-
loss, simplex, subsimplex = heapq.heappop(self._simplex_queue)
396+
loss, simplex, subsimplex = self._simplex_queue.pop(0)
390397
if (subsimplex is None
391398
and simplex in self.tri.simplices
392399
and simplex not in self._subtriangulations):
@@ -462,8 +469,7 @@ def _update_losses(self, to_delete: set, to_add: set):
462469
self._try_adding_pending_point_to_simplex(p, simplex)
463470

464471
if simplex not in self._subtriangulations:
465-
loss = round(loss, ndigits=8)
466-
heapq.heappush(self._simplex_queue, (-loss, simplex, None))
472+
self._simplex_queue.add((loss, simplex, None))
467473
continue
468474

469475
self._update_subsimplex_losses(
@@ -488,7 +494,7 @@ def _recompute_all_losses(self):
488494
return
489495

490496
# reset the _simplex_queue
491-
self._simplex_queue = []
497+
self._simplex_queue = SortedKeyList(key=_simplex_evaluation_priority)
492498

493499
# recompute all losses
494500
for simplex in self.tri.simplices:
@@ -497,8 +503,7 @@ def _recompute_all_losses(self):
497503

498504
# now distribute it around the the children if they are present
499505
if simplex not in self._subtriangulations:
500-
loss = round(loss, ndigits=8)
501-
heapq.heappush(self._simplex_queue, (-loss, simplex, None))
506+
self._simplex_queue.add((loss, simplex, None))
502507
continue
503508

504509
self._update_subsimplex_losses(

adaptive/tests/test_learners.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -362,9 +362,7 @@ def test_expected_loss_improvement_is_less_than_total_loss(learner_type, f, lear
362362

363363
# XXX: This *should* pass (https://gitlab.kwant-project.org/qt/adaptive/issues/84)
364364
# but we xfail it now, as Learner2D will be deprecated anyway
365-
# The LearnerND fails sometimes, see
366-
# https://gitlab.kwant-project.org/qt/adaptive/merge_requests/128#note_21807
367-
@run_with(Learner1D, xfail(Learner2D), xfail(LearnerND))
365+
@run_with(Learner1D, xfail(Learner2D), LearnerND)
368366
def test_learner_performance_is_invariant_under_scaling(learner_type, f, learner_kwargs):
369367
"""Learners behave identically under transformations that leave
370368
the loss invariant.
@@ -392,6 +390,10 @@ def test_learner_performance_is_invariant_under_scaling(learner_type, f, learner
392390

393391
npoints = random.randrange(300, 500)
394392

393+
if learner_type is LearnerND:
394+
# Because the LearnerND is slow
395+
npoints //= 10
396+
395397
for n in range(npoints):
396398
cxs, _ = control.ask(1)
397399
xs, _ = learner.ask(1)

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ def get_version_and_cmdclass(package_name):
2727
install_requires = [
2828
'scipy',
2929
'sortedcollections',
30-
'sortedcontainers',
30+
'sortedcontainers >= 2.0',
3131
]
3232

3333
extras_require = {

0 commit comments

Comments
 (0)