import scipy.spatial

from adaptive.learner.base_learner import BaseLearner
+from adaptive.notebook_integration import ensure_holoviews, ensure_plotly
from adaptive.learner.triangulation import (
    Triangulation, point_in_simplex, circumsphere,
-    simplex_volume_in_embedding, fast_det
-)
-from adaptive.notebook_integration import ensure_holoviews, ensure_plotly
+    simplex_volume_in_embedding, fast_det)
from adaptive.utils import restore, cache_latest


@@ -178,8 +177,14 @@ def __init__(self, func, bounds, loss_per_simplex=None):
        # triangulation of the pending points inside a specific simplex
        self._subtriangulations = dict()  # simplex → triangulation

-        # scale to unit
+        # scale to unit hypercube
+        # for the input
        self._transform = np.linalg.inv(np.diag(np.diff(self._bbox).flat))
+        # for the output
+        self._min_value = None
+        self._max_value = None
+        self._output_multiplier = 1  # if we do not know anything, do not scale the values
+        self._recompute_losses_factor = 1.1

        # create a private random number generator with fixed seed
        self._random = random.Random(1)
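The input-scaling matrix built above maps the bounding box onto a unit hypercube (up to a translation): multiplying points by `self._transform` divides every coordinate by the width of its bounding interval. A minimal sketch of that construction, using an illustrative bounding box that is an assumption for the example, not taken from the commit:

```python
import numpy as np

# Hypothetical bounds: x in [-1, 1] (width 2), y in [0, 4] (width 4).
bbox = [(-1, 1), (0, 4)]

# Same construction as in __init__: the inverse of a diagonal matrix of widths,
# i.e. a diagonal matrix holding 1 / width per axis.
transform = np.linalg.inv(np.diag(np.diff(bbox).flat))

vertices = np.array([(-1.0, 0.0), (1.0, 0.0), (1.0, 4.0)])
print(vertices @ transform)
# [[-0.5  0. ]
#  [ 0.5  0. ]
#  [ 0.5  1. ]]
# Every edge now spans at most 1 along each axis; only the scale changes,
# and the missing translation does not affect simplex volumes or losses.
```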
@@ -271,6 +276,7 @@ def tell(self, point, value):
        if not self.inside_bounds(point):
            return

+        self._update_range(value)
        if tri is not None:
            simplex = self._pending_to_simplex.get(point)
            if simplex is not None and not self._simplex_exists(simplex):
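`_update_range` (defined further down in this commit) widens the tracked output range before the point is processed any further, and triggers a full loss recomputation once that range has grown by more than `_recompute_losses_factor` (1.1, i.e. 10%) since the last rescaling. A small numeric illustration of the trigger, with made-up scale values:

```python
import numpy as np

old_scale = np.array([2.0])   # output range recorded at the last rescaling
new_scale = np.array([2.3])   # range after the newest value arrived

recompute_losses_factor = 1.1
scale_factor = np.max(np.nan_to_num(new_scale / old_scale))
print(scale_factor)                              # ~1.15
print(scale_factor > recompute_losses_factor)    # True -> recompute_all_losses()
```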
@@ -338,6 +344,7 @@ def _update_subsimplex_losses(self, simplex, new_subsimplices):
        subtriangulation = self._subtriangulations[simplex]
        for subsimplex in new_subsimplices:
            subloss = subtriangulation.volume(subsimplex) * loss_density
+            subloss = round(subloss, ndigits=8)
            heapq.heappush(self._simplex_queue,
                           (-subloss, simplex, subsimplex))

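Rounding the subloss before pushing it onto the heap is presumably meant to keep the priority queue stable against floating-point noise: sublosses that are equal up to the last few digits now compare equal, so the pop order is decided by the deterministic tie-breakers in the tuple rather than by rounding error (a hedged reading; the commit itself states no rationale). A self-contained illustration, with placeholder simplex labels:

```python
import heapq

loss_a = 0.1 + 0.2   # 0.30000000000000004, equal to 0.3 up to float noise
loss_b = 0.3

queue = []
heapq.heappush(queue, (-round(loss_a, ndigits=8), "simplex_a"))
heapq.heappush(queue, (-round(loss_b, ndigits=8), "simplex_b"))

# Both entries now carry the identical priority -0.3, so which one is popped
# first is determined by the tie-breaking element, not by a few ulps of noise.
print(heapq.heappop(queue))   # (-0.3, 'simplex_a')
```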
@@ -448,21 +455,98 @@ def update_losses(self, to_delete: set, to_add: set):
                                     if p not in self.data)

        for simplex in to_add:
-            vertices = self.tri.get_vertices(simplex)
-            values = [self.data[tuple(v)] for v in vertices]
-            loss = float(self.loss_per_simplex(vertices, values))
-            self._losses[simplex] = float(loss)
+            loss = self.compute_loss(simplex)
+            self._losses[simplex] = loss

            for p in pending_points_unbound:
                self._try_adding_pending_point_to_simplex(p, simplex)

            if simplex not in self._subtriangulations:
+                loss = round(loss, ndigits=8)
                heapq.heappush(self._simplex_queue, (-loss, simplex, None))
                continue

            self._update_subsimplex_losses(
                simplex, self._subtriangulations[simplex].simplices)

+    def compute_loss(self, simplex):
+        # get the vertices and the values of the simplex
+        vertices = self.tri.get_vertices(simplex)
+        values = [self.data[tuple(v)] for v in vertices]
+
+        # scale them to a hypercube with sides of length 1
+        vertices = vertices @ self._transform
+        values = self._output_multiplier * values
+
+        # compute the loss on the scaled simplex
+        return float(self.loss_per_simplex(vertices, values))
+
+    def recompute_all_losses(self):
+        """Recompute all losses and pending losses."""
+        # amortized O(N) complexity
+        if self.tri is None:
+            return
+
+        # reset the _simplex_queue
+        self._simplex_queue = []
+
+        # recompute all losses
+        for simplex in self.tri.simplices:
+            loss = self.compute_loss(simplex)
+            self._losses[simplex] = loss
+
+            # now distribute it among the children, if any are present
+            if simplex not in self._subtriangulations:
+                loss = round(loss, ndigits=8)
+                heapq.heappush(self._simplex_queue, (-loss, simplex, None))
+                continue
+
+            self._update_subsimplex_losses(
+                simplex, self._subtriangulations[simplex].simplices)
+
+    @property
+    def _scale(self):
+        # the range spanned by the observed output values
+        return self._max_value - self._min_value
+
+    def _update_range(self, new_output):
+        if self._min_value is None or self._max_value is None:
+            # this is the first point; nothing to do but set the range
+            self._min_value = np.array(new_output)
+            self._max_value = np.array(new_output)
+            self._old_scale = self._scale
+            return False
+
+        # if the range grows too much in any direction, all losses are recomputed below
+        self._min_value = np.minimum(self._min_value, new_output)
+        self._max_value = np.maximum(self._max_value, new_output)
+
+        scale_multiplier = 1 / self._scale
+        if isinstance(scale_multiplier, float):
+            scale_multiplier = np.array([scale_multiplier], dtype=float)
+
+        # the maximum absolute value in the range; being the largest number,
+        # it also carries the largest absolute numerical error
+        max_absolute_value_in_range = np.max(np.abs([self._min_value, self._max_value]), axis=0)
+        # a double has a relative error of ~1e-15, so the absolute error is roughly value * 1e-15
+        abs_err = 1e-15 * max_absolute_value_in_range
+        # rescaling the values rescales this error along with them
+        scaled_err = abs_err * scale_multiplier
+
+        allowed_numerical_error = 1e-2
+
+        # do not scale along an axis if the numerical error would become too large
+        scale_multiplier[scaled_err > allowed_numerical_error] = 1
+
+        self._output_multiplier = scale_multiplier
+
+        scale_factor = np.max(np.nan_to_num(self._scale / self._old_scale))
+        if scale_factor > self._recompute_losses_factor:
+            self._old_scale = self._scale
+            self.recompute_all_losses()
+            return True
+        return False
+
    def losses(self):
        """Get the losses of each simplex in the current triangulation, as dict

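`compute_loss` hands `loss_per_simplex` a simplex that is rescaled on both sides: vertices are mapped towards the unit hypercube via `self._transform`, and values are multiplied by `self._output_multiplier`, which `_update_range` keeps close to `1 / (max - min)` of the observed outputs. The guard with `allowed_numerical_error` refuses to rescale an output axis when doing so would amplify floating-point error too much, e.g. when the observed spread is tiny compared to the magnitude of the values. A standalone sketch of that guard, with assumed numbers:

```python
import numpy as np

# Illustrative outputs around 1e15 that vary by only ~1.0.
min_value = np.array([1.0e15])
max_value = np.array([1.0e15 + 1.0])

scale = max_value - min_value        # observed output range: 1.0
scale_multiplier = 1 / scale         # would rescale the outputs to ~unit size

# A double carries ~15-16 significant digits, so the largest value in the
# range has an absolute error of roughly value * 1e-15 (here ~1.0).
max_abs = np.max(np.abs([min_value, max_value]), axis=0)
scaled_err = (1e-15 * max_abs) * scale_multiplier

allowed_numerical_error = 1e-2
scale_multiplier[scaled_err > allowed_numerical_error] = 1   # keep this axis unscaled
print(scale_multiplier)   # [1.] -> rescaling would only amplify rounding noise
```

With well-behaved outputs (say a spread of 2 around 0) the multiplier stays at `1 / scale`, so the losses are computed on values of order one and the loss heuristics treat input and output axes on a comparable footing.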