Skip to content

Commit 8edaf6d

Browse files
committed
male the scale always a float, like happens in the Learner1D and Leearner2D
1 parent 4387449 commit 8edaf6d

File tree

1 file changed

+4
-9
lines changed

1 file changed

+4
-9
lines changed

adaptive/learner/learnerND.py

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -513,32 +513,27 @@ def _recompute_all_losses(self):
513513
@property
514514
def _scale(self):
515515
# get the output scale
516-
scale = self._max_value - self._min_value
517-
if isinstance(scale, np.ndarray):
518-
scale[scale == 0] = 1
519-
elif scale == 0:
520-
scale = 1
521-
return scale
516+
return np.max(self._max_value - self._min_value)
522517

523518
def _update_range(self, new_output):
524519
if self._min_value is None or self._max_value is None:
525520
# this is the first point, nothing to do, just set the range
526521
self._min_value = np.array(new_output)
527522
self._max_value = np.array(new_output)
528-
self._old_scale = self._scale
523+
self._old_scale = self._scale or 1
529524
return False
530525

531526
# if range in one or more directions is doubled, then update all losses
532527
self._min_value = np.minimum(self._min_value, new_output)
533528
self._max_value = np.maximum(self._max_value, new_output)
534529

535-
scale_multiplier = 1 / self._scale
530+
scale_multiplier = 1 / (self._scale or 1)
536531
if isinstance(scale_multiplier, float):
537532
scale_multiplier = np.array([scale_multiplier], dtype=float)
538533

539534
# the maximum absolute value that is in the range. Because this is the
540535
# largest number, this also has the largest absolute numerical error.
541-
max_absolute_value_in_range = np.max(np.abs([self._min_value, self._max_value]), axis=0)
536+
max_absolute_value_in_range = np.max(np.abs([self._min_value, self._max_value]))
542537
# since a float has a relative error of 1e-15, the absolute error is the value * 1e-15
543538
abs_err = 1e-15 * max_absolute_value_in_range
544539
# when scaling the floats, the error gets increased.

0 commit comments

Comments
 (0)