Skip to content

Commit d1655d9

Browse files
committed
more fixes for adaptive/learner/balancing_learner.py
1 parent 9a82b64 commit d1655d9

File tree

1 file changed

+16
-32
lines changed

1 file changed

+16
-32
lines changed

adaptive/learner/balancing_learner.py

Lines changed: 16 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,7 @@
1313
from adaptive.utils import cache_latest, named_product, restore
1414

1515

16-
def dispatch(
17-
child_functions: List[Callable], arg: Any,
18-
) -> Union[int, np.float64, float]:
16+
def dispatch(child_functions: List[Callable], arg: Any,) -> Union[Any]:
1917
index, x = arg
2018
return child_functions[index](x)
2119

@@ -94,14 +92,14 @@ def __init__(
9492
self.strategy = strategy
9593

9694
@property
97-
def data(self) -> Dict[Tuple[int, int], int]:
95+
def data(self) -> Dict[Tuple[int, Any], Any]:
9896
data = {}
9997
for i, l in enumerate(self.learners):
10098
data.update({(i, p): v for p, v in l.data.items()})
10199
return data
102100

103101
@property
104-
def pending_points(self) -> Set[Tuple[int, int]]:
102+
def pending_points(self) -> Set[Tuple[int, Any]]:
105103
pending_points = set()
106104
for i, l in enumerate(self.learners):
107105
pending_points.update({(i, p) for p in l.pending_points})
@@ -140,7 +138,9 @@ def strategy(self, strategy):
140138
' strategy="npoints", or strategy="cycle" is implemented.'
141139
)
142140

143-
def _ask_and_tell_based_on_loss_improvements(self, n: int) -> Any:
141+
def _ask_and_tell_based_on_loss_improvements(
142+
self, n: int
143+
) -> Tuple[List[Tuple[int, Any]], List[float]]:
144144
selected = [] # tuples ((learner_index, point), loss_improvement)
145145
total_points = [l.npoints + len(l.pending_points) for l in self.learners]
146146
for _ in range(n):
@@ -165,11 +165,7 @@ def _ask_and_tell_based_on_loss_improvements(self, n: int) -> Any:
165165

166166
def _ask_and_tell_based_on_loss(
167167
self, n: int
168-
) -> Union[
169-
Tuple[List[Tuple[int, float]], List[np.float64]],
170-
Tuple[List[Union[Tuple[int, int], Tuple[int, float]]], List[float]],
171-
Tuple[List[Tuple[int, int]], List[float]],
172-
]:
168+
) -> Tuple[List[Tuple[int, Any]], List[float]]:
173169
selected = [] # tuples ((learner_index, point), loss_improvement)
174170
total_points = [l.npoints + len(l.pending_points) for l in self.learners]
175171
for _ in range(n):
@@ -192,11 +188,7 @@ def _ask_and_tell_based_on_loss(
192188

193189
def _ask_and_tell_based_on_npoints(
194190
self, n: int
195-
) -> Union[
196-
Tuple[List[Union[Tuple[np.int64, int], Tuple[np.int64, float]]], List[float]],
197-
Tuple[List[Tuple[np.int64, float]], List[np.float64]],
198-
Tuple[List[Tuple[np.int64, int]], List[float]],
199-
]:
191+
) -> Tuple[List[Tuple[int, Any]], List[float]]:
200192
selected = [] # tuples ((learner_index, point), loss_improvement)
201193
total_points = [l.npoints + len(l.pending_points) for l in self.learners]
202194
for _ in range(n):
@@ -214,11 +206,7 @@ def _ask_and_tell_based_on_npoints(
214206

215207
def _ask_and_tell_based_on_cycle(
216208
self, n: int
217-
) -> Union[
218-
Tuple[List[Tuple[int, float]], List[np.float64]],
219-
Tuple[List[Union[Tuple[int, int], Tuple[int, float]]], List[float]],
220-
Tuple[List[Tuple[int, int]], List[float]],
221-
]:
209+
) -> Tuple[List[Tuple[int, Any]], List[float]]:
222210
points, loss_improvements = [], []
223211
for _ in range(n):
224212
index = next(self._cycle)
@@ -229,7 +217,9 @@ def _ask_and_tell_based_on_cycle(
229217

230218
return points, loss_improvements
231219

232-
def ask(self, n: int, tell_pending: bool = True) -> Tuple[List[Any], List[float]]:
220+
def ask(
221+
self, n: int, tell_pending: bool = True
222+
) -> Tuple[List[Tuple[int, Any]], List[float]]:
233223
"""Chose points for learners."""
234224
if n == 0:
235225
return [], []
@@ -240,26 +230,20 @@ def ask(self, n: int, tell_pending: bool = True) -> Tuple[List[Any], List[float]
240230
else:
241231
return self._ask_and_tell(n)
242232

243-
def tell(
244-
self,
245-
x: Any,
246-
y: Union[int, np.float64, float, Tuple[int, int], Tuple[np.int64, int]],
247-
) -> None:
233+
def tell(self, x: Tuple[int, Any], y: Any,) -> None:
248234
index, x = x
249235
self._ask_cache.pop(index, None)
250236
self._loss.pop(index, None)
251237
self._pending_loss.pop(index, None)
252238
self.learners[index].tell(x, y)
253239

254-
def tell_pending(self, x: Any) -> None:
240+
def tell_pending(self, x: Tuple[int, Any]) -> None:
255241
index, x = x
256242
self._ask_cache.pop(index, None)
257243
self._loss.pop(index, None)
258244
self.learners[index].tell_pending(x)
259245

260-
def _losses(
261-
self, real: bool = True
262-
) -> Union[List[float], List[np.float64], List[Union[float, np.float64]]]:
246+
def _losses(self, real: bool = True) -> List[float]:
263247
losses = []
264248
loss_dict = self._loss if real else self._pending_loss
265249

@@ -271,7 +255,7 @@ def _losses(
271255
return losses
272256

273257
@cache_latest
274-
def loss(self, real: bool = True) -> Union[np.float64, float]:
258+
def loss(self, real: bool = True) -> Union[float]:
275259
losses = self._losses(real)
276260
return max(losses)
277261

0 commit comments

Comments
 (0)