Commit 24ac01a

jhoofwijk authored and basnijholt committed
LearnerND scale output values before computing loss
1 parent fa4696e commit 24ac01a

3 files changed: +100 −8 lines
adaptive/learner/learnerND.py

Lines changed: 92 additions & 8 deletions
@@ -11,11 +11,10 @@
 import scipy.spatial
 
 from adaptive.learner.base_learner import BaseLearner
+from adaptive.notebook_integration import ensure_holoviews, ensure_plotly
 from adaptive.learner.triangulation import (
     Triangulation, point_in_simplex, circumsphere,
-    simplex_volume_in_embedding, fast_det
-)
-from adaptive.notebook_integration import ensure_holoviews, ensure_plotly
+    simplex_volume_in_embedding, fast_det)
 from adaptive.utils import restore, cache_latest
 
@@ -178,8 +177,14 @@ def __init__(self, func, bounds, loss_per_simplex=None):
         # triangulation of the pending points inside a specific simplex
         self._subtriangulations = dict()  # simplex → triangulation
 
-        # scale to unit
+        # scale to unit hypercube
+        # for the input
         self._transform = np.linalg.inv(np.diag(np.diff(self._bbox).flat))
+        # for the output
+        self._min_value = None
+        self._max_value = None
+        self._output_multiplier = 1  # If we do not know anything, do not scale the values
+        self._recompute_losses_factor = 1.1
 
         # create a private random number generator with fixed seed
         self._random = random.Random(1)
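
As an aside on the input half of this scaling: a minimal standalone sketch, with invented bounds that are not part of the commit, of what the _transform matrix above computes. np.diff gives the side lengths of the bounding box, and inverting the diagonal matrix of side lengths rescales every axis to length 1:

import numpy as np

bbox = np.array([(-1.0, 1.0), (0.0, 10.0)])  # hypothetical bounds
transform = np.linalg.inv(np.diag(np.diff(bbox).flat))

corners = np.array([(-1.0, 0.0), (1.0, 10.0)])  # opposite corners of the box
print(corners @ transform)  # [[-0.5  0. ]
                            #  [ 0.5  1. ]] -> every side now has length 1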
@@ -271,6 +276,7 @@ def tell(self, point, value):
         if not self.inside_bounds(point):
             return
 
+        self._update_range(value)
         if tri is not None:
             simplex = self._pending_to_simplex.get(point)
             if simplex is not None and not self._simplex_exists(simplex):
@@ -338,6 +344,7 @@ def _update_subsimplex_losses(self, simplex, new_subsimplices):
         subtriangulation = self._subtriangulations[simplex]
         for subsimplex in new_subsimplices:
             subloss = subtriangulation.volume(subsimplex) * loss_density
+            subloss = round(subloss, ndigits=8)
             heapq.heappush(self._simplex_queue,
                            (-subloss, simplex, subsimplex))

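The new round(..., ndigits=8) calls (here and again in update_losses and recompute_all_losses below) appear to guard the priority queue against floating-point noise: losses that agree to 8 digits compare as exactly equal, so the heap order cannot depend on sub-1e-8 differences introduced by the rescaling. A tiny illustration, not taken from the commit:

a = 0.1 + 0.2  # 0.30000000000000004 in binary floating point
b = 0.3
assert a != b
assert round(a, ndigits=8) == round(b, ndigits=8) == 0.3
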
@@ -448,21 +455,98 @@ def update_losses(self, to_delete: set, to_add: set):
                                       if p not in self.data)
 
         for simplex in to_add:
-            vertices = self.tri.get_vertices(simplex)
-            values = [self.data[tuple(v)] for v in vertices]
-            loss = float(self.loss_per_simplex(vertices, values))
-            self._losses[simplex] = float(loss)
+            loss = self.compute_loss(simplex)
+            self._losses[simplex] = loss
 
             for p in pending_points_unbound:
                 self._try_adding_pending_point_to_simplex(p, simplex)
 
             if simplex not in self._subtriangulations:
+                loss = round(loss, ndigits=8)
                 heapq.heappush(self._simplex_queue, (-loss, simplex, None))
                 continue
 
             self._update_subsimplex_losses(
                 simplex, self._subtriangulations[simplex].simplices)
 
+    def compute_loss(self, simplex):
+        # get the vertices of the simplex and the values at those vertices
+        vertices = self.tri.get_vertices(simplex)
+        values = [self.data[tuple(v)] for v in vertices]
+
+        # scale them to a cube with sides 1
+        vertices = vertices @ self._transform
+        values = self._output_multiplier * values
+
+        # compute the loss on the scaled simplex
+        return float(self.loss_per_simplex(vertices, values))
+
+    def recompute_all_losses(self):
+        """Recompute all losses and pending losses."""
+        # amortized O(N) complexity
+        if self.tri is None:
+            return
+
+        # reset the _simplex_queue
+        self._simplex_queue = []
+
+        # recompute all losses
+        for simplex in self.tri.simplices:
+            loss = self.compute_loss(simplex)
+            self._losses[simplex] = loss
+
+            # now distribute it around the children if they are present
+            if simplex not in self._subtriangulations:
+                loss = round(loss, ndigits=8)
+                heapq.heappush(self._simplex_queue, (-loss, simplex, None))
+                continue
+
+            self._update_subsimplex_losses(
+                simplex, self._subtriangulations[simplex].simplices)
+
+    @property
+    def _scale(self):
+        # get the output scale
+        return self._max_value - self._min_value
+
+    def _update_range(self, new_output):
+        if self._min_value is None or self._max_value is None:
+            # this is the first point, nothing to do, just set the range
+            self._min_value = np.array(new_output)
+            self._max_value = np.array(new_output)
+            self._old_scale = self._scale
+            return False
+
+        # update the observed output range
+        self._min_value = np.minimum(self._min_value, new_output)
+        self._max_value = np.maximum(self._max_value, new_output)
+
+        scale_multiplier = 1 / self._scale
+        if isinstance(scale_multiplier, float):
+            scale_multiplier = np.array([scale_multiplier], dtype=float)
+
+        # the maximum absolute value that is in the range. Because this is the
+        # largest number, it also has the largest absolute numerical error.
+        max_absolute_value_in_range = np.max(np.abs([self._min_value,
+                                                     self._max_value]), axis=0)
+        # since a float has a relative error of ~1e-15, the absolute error
+        # is the value * 1e-15
+        abs_err = 1e-15 * max_absolute_value_in_range
+        # when scaling the floats, the error gets increased as well
+        scaled_err = abs_err * scale_multiplier
+
+        allowed_numerical_error = 1e-2
+
+        # do not scale along an axis if the numerical error would get too big
+        scale_multiplier[scaled_err > allowed_numerical_error] = 1
+
+        self._output_multiplier = scale_multiplier
+
+        # if the scale has grown by more than `_recompute_losses_factor`,
+        # recompute all losses with the new multiplier
+        scale_factor = np.max(np.nan_to_num(self._scale / self._old_scale))
+        if scale_factor > self._recompute_losses_factor:
+            self._old_scale = self._scale
+            self.recompute_all_losses()
+            return True
+        return False
+
     def losses(self):
         """Get the losses of each simplex in the current triangulation, as dict

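To make the output-scaling rule in _update_range concrete, here is a standalone recomputation of the same arithmetic for a one-dimensional output; the numbers are invented for illustration and do not come from the commit:

import numpy as np

min_value, max_value = np.array([-2.0]), np.array([6.0])
scale = max_value - min_value             # 8.0
scale_multiplier = 1 / scale              # 0.125

# the largest magnitude in the range carries the largest absolute error;
# a double has a relative error of roughly 1e-15
max_abs = np.max(np.abs([min_value, max_value]), axis=0)   # 6.0
scaled_err = 1e-15 * max_abs * scale_multiplier            # 7.5e-16

# keep scaling only where the amplified error stays acceptable
scale_multiplier[scaled_err > 1e-2] = 1

values = np.array([-2.0, 1.0, 6.0])
print(values * scale_multiplier)  # [-0.25  0.125  0.75], spread over a unit range

Since recompute_all_losses only runs once the scale has grown by the factor _recompute_losses_factor (10% by default), each O(N) recompute can happen at most logarithmically often as the range grows, which seems to be what the "amortized O(N)" comment refers to.
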
adaptive/tests/test_learnernd.py

Lines changed: 2 additions & 0 deletions
@@ -7,6 +7,7 @@
 
 from .test_learners import ring_of_fire, generate_random_parametrization
 
+
 def test_faiure_case_LearnerND():
     log = [
         ('ask', 4),
@@ -25,6 +26,7 @@ def test_faiure_case_LearnerND():
     learner = LearnerND(lambda *x: x, bounds=[(-1, 1), (-1, 1), (-1, 1)])
     replay_log(learner, log)
 
+
 def test_interior_vs_bbox_gives_same_result():
     f = generate_random_parametrization(ring_of_fire)

adaptive/tests/test_learners.py

Lines changed: 6 additions & 0 deletions
@@ -362,6 +362,8 @@ def test_expected_loss_improvement_is_less_than_total_loss(learner_type, f, learner_kwargs):
 
 # XXX: This *should* pass (https://gitlab.kwant-project.org/qt/adaptive/issues/84)
 # but we xfail it now, as Learner2D will be deprecated anyway
+# The LearnerND fails sometimes, see
+# https://gitlab.kwant-project.org/qt/adaptive/merge_requests/128#note_21807
 @run_with(Learner1D, xfail(Learner2D), xfail(LearnerND))
 def test_learner_performance_is_invariant_under_scaling(learner_type, f, learner_kwargs):
     """Learners behave identically under transformations that leave
@@ -384,6 +386,10 @@ def test_learner_performance_is_invariant_under_scaling(learner_type, f, learner_kwargs):
     learner = learner_type(lambda x: yscale * f(np.array(x) / xscale),
                            **l_kwargs)
 
+    if learner_type in [Learner1D, LearnerND]:
+        learner._recompute_losses_factor = 1
+        control._recompute_losses_factor = 1
+
     npoints = random.randrange(300, 500)
 
     for n in range(npoints):

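For context on the _recompute_losses_factor = 1 lines above: with the default factor of 1.1, the scaled learner and the unscaled control can trigger recompute_all_losses after different tell calls, so their point choices may drift apart even for a scale-invariant loss. Pinning the factor to 1 makes any growth of the output scale force a recompute in both learners. A sketch of the threshold check, with invented numbers:

old_scale, new_scale = 4.0, 4.2
scale_factor = new_scale / old_scale  # 1.05
assert scale_factor > 1               # factor 1: recompute triggers now
assert not scale_factor > 1.1         # default 1.1: no recompute yet
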