Skip to content

Commit a8cefd6

Browse files
committed
simplify _update_range
1 parent 8edaf6d commit a8cefd6

File tree

1 file changed

+11
-13
lines changed

1 file changed

+11
-13
lines changed

adaptive/learner/learnerND.py

Lines changed: 11 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -483,7 +483,7 @@ def _compute_loss(self, simplex):
483483

484484
# scale them to a cube with sides 1
485485
vertices = vertices @ self._transform
486-
values = self._output_multiplier * values
486+
values = self._output_multiplier * np.array(values)
487487

488488
# compute the loss on the scaled simplex
489489
return float(self.loss_per_simplex(vertices, values))
@@ -513,40 +513,38 @@ def _recompute_all_losses(self):
513513
@property
514514
def _scale(self):
515515
# get the output scale
516-
return np.max(self._max_value - self._min_value)
516+
return self._max_value - self._min_value
517517

518518
def _update_range(self, new_output):
519519
if self._min_value is None or self._max_value is None:
520520
# this is the first point, nothing to do, just set the range
521-
self._min_value = np.array(new_output)
522-
self._max_value = np.array(new_output)
521+
self._min_value = np.min(new_output)
522+
self._max_value = np.max(new_output)
523523
self._old_scale = self._scale or 1
524524
return False
525525

526526
# if range in one or more directions is doubled, then update all losses
527-
self._min_value = np.minimum(self._min_value, new_output)
528-
self._max_value = np.maximum(self._max_value, new_output)
527+
self._min_value = min(self._min_value, np.min(new_output))
528+
self._max_value = max(self._max_value, np.max(new_output))
529529

530530
scale_multiplier = 1 / (self._scale or 1)
531-
if isinstance(scale_multiplier, float):
532-
scale_multiplier = np.array([scale_multiplier], dtype=float)
533531

534532
# the maximum absolute value that is in the range. Because this is the
535533
# largest number, this also has the largest absolute numerical error.
536-
max_absolute_value_in_range = np.max(np.abs([self._min_value, self._max_value]))
534+
max_absolute_value_in_range = max(abs(self._min_value),
535+
abs(self._max_value))
537536
# since a float has a relative error of 1e-15, the absolute error is the value * 1e-15
538537
abs_err = 1e-15 * max_absolute_value_in_range
539538
# when scaling the floats, the error gets increased.
540539
scaled_err = abs_err * scale_multiplier
541540

542-
allowed_numerical_error = 1e-2
543-
544541
# do not scale along the axis if the numerical error gets too big
545-
scale_multiplier[scaled_err > allowed_numerical_error] = 1
542+
if scaled_err > 1e-2: # allowed_numerical_error = 1e-2
543+
scale_multiplier = 1
546544

547545
self._output_multiplier = scale_multiplier
548546

549-
scale_factor = np.max(self._scale / self._old_scale)
547+
scale_factor = self._scale / self._old_scale
550548
if scale_factor > self._recompute_losses_factor:
551549
self._old_scale = self._scale
552550
self._recompute_all_losses()

0 commit comments

Comments
 (0)