Skip to content

Commit 4bf9aff

Browse files
committed
Learner1D: return inf loss when the bounds aren't done
1 parent a94ecd0 commit 4bf9aff

File tree

1 file changed

+14
-7
lines changed

1 file changed

+14
-7
lines changed

adaptive/learner/learner1D.py

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import collections.abc
22
import itertools
33
import math
4-
from copy import deepcopy
4+
from copy import copy, deepcopy
55
from typing import Any, Callable, Dict, List, Optional, Sequence, Set, Tuple, Union
66

77
import cloudpickle
@@ -290,6 +290,7 @@ def __init__(
290290
self._dx_eps = 2 * max(np.abs(bounds)) * np.finfo(float).eps
291291

292292
self.bounds = list(bounds)
293+
self.__missing_bounds = set(self.bounds) # cache of missing bounds
293294

294295
self._vdim: Optional[int] = None
295296

@@ -325,6 +326,8 @@ def npoints(self) -> int:
325326

326327
@cache_latest
327328
def loss(self, real: bool = True) -> float:
329+
if self.__missing_bounds:
330+
return np.inf
328331
losses = self.losses if real else self.losses_combined
329332
if not losses:
330333
return np.inf
@@ -604,6 +607,15 @@ def ask(self, n: int, tell_pending: bool = True) -> Tuple[List[float], List[floa
604607

605608
return points, loss_improvements
606609

610+
def _missing_bounds(self) -> List[Real]:
611+
missing_bounds = []
612+
for b in copy(self.__missing_bounds):
613+
if b in self.data:
614+
self.__missing_bounds.remove(b)
615+
elif b not in self.pending_points:
616+
missing_bounds.append(b)
617+
return missing_bounds
618+
607619
def _ask_points_without_adding(self, n: int) -> Tuple[List[float], List[float]]:
608620
"""Return 'n' points that are expected to maximally reduce the loss.
609621
Without altering the state of the learner"""
@@ -619,12 +631,7 @@ def _ask_points_without_adding(self, n: int) -> Tuple[List[float], List[float]]:
619631
return [], []
620632

621633
# If the bounds have not been chosen yet, we choose them first.
622-
missing_bounds = [
623-
b
624-
for b in self.bounds
625-
if b not in self.data and b not in self.pending_points
626-
]
627-
634+
missing_bounds = self._missing_bounds()
628635
if len(missing_bounds) >= n:
629636
return missing_bounds[:n], [np.inf] * n
630637

0 commit comments

Comments
 (0)