Skip to content

Commit fcae8e0

Browse files
committed
add type annotations for adaptive/learner/balancing_learner.py
1 parent 1abd243 commit fcae8e0

File tree

1 file changed

+60
-16
lines changed

1 file changed

+60
-16
lines changed

adaptive/learner/balancing_learner.py

Lines changed: 60 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -4,15 +4,25 @@
44
from contextlib import suppress
55
from functools import partial
66
from operator import itemgetter
7+
from typing import Any, Callable, Dict, List, Set, Tuple, Union
78

89
import numpy as np
10+
from numpy import float64, int64
911

12+
from adaptive.learner.average_learner import AverageLearner
1013
from adaptive.learner.base_learner import BaseLearner
14+
from adaptive.learner.learner1D import Learner1D
15+
from adaptive.learner.learner2D import Learner2D
16+
from adaptive.learner.learnerND import LearnerND
17+
from adaptive.learner.sequence_learner import SequenceLearner, _IgnoreFirstArgument
1118
from adaptive.notebook_integration import ensure_holoviews
1219
from adaptive.utils import cache_latest, named_product, restore
1320

1421

15-
def dispatch(child_functions, arg):
22+
def dispatch(
23+
child_functions: Union[List[Callable], List[partial], List[_IgnoreFirstArgument]],
24+
arg: Any,
25+
) -> Union[int, float64, float]:
1626
index, x = arg
1727
return child_functions[index](x)
1828

@@ -68,7 +78,19 @@ class BalancingLearner(BaseLearner):
6878
behave in an undefined way. Change the `strategy` in that case.
6979
"""
7080

71-
def __init__(self, learners, *, cdims=None, strategy="loss_improvements"):
81+
def __init__(
82+
self,
83+
learners: Union[
84+
List[SequenceLearner],
85+
List[AverageLearner],
86+
List[Learner2D],
87+
List[Learner1D],
88+
List[LearnerND],
89+
],
90+
*,
91+
cdims=None,
92+
strategy="loss_improvements"
93+
) -> None:
7294
self.learners = learners
7395

7496
# Naively we would make 'function' a method, but this causes problems
@@ -89,21 +111,21 @@ def __init__(self, learners, *, cdims=None, strategy="loss_improvements"):
89111
self.strategy = strategy
90112

91113
@property
92-
def data(self):
114+
def data(self) -> Dict[Tuple[int, int], int]:
93115
data = {}
94116
for i, l in enumerate(self.learners):
95117
data.update({(i, p): v for p, v in l.data.items()})
96118
return data
97119

98120
@property
99-
def pending_points(self):
121+
def pending_points(self) -> Set[Tuple[int, int]]:
100122
pending_points = set()
101123
for i, l in enumerate(self.learners):
102124
pending_points.update({(i, p) for p in l.pending_points})
103125
return pending_points
104126

105127
@property
106-
def npoints(self):
128+
def npoints(self) -> int:
107129
return sum(l.npoints for l in self.learners)
108130

109131
@property
@@ -135,7 +157,7 @@ def strategy(self, strategy):
135157
' strategy="npoints", or strategy="cycle" is implemented.'
136158
)
137159

138-
def _ask_and_tell_based_on_loss_improvements(self, n):
160+
def _ask_and_tell_based_on_loss_improvements(self, n: int) -> Any:
139161
selected = [] # tuples ((learner_index, point), loss_improvement)
140162
total_points = [l.npoints + len(l.pending_points) for l in self.learners]
141163
for _ in range(n):
@@ -158,7 +180,13 @@ def _ask_and_tell_based_on_loss_improvements(self, n):
158180
points, loss_improvements = map(list, zip(*selected))
159181
return points, loss_improvements
160182

161-
def _ask_and_tell_based_on_loss(self, n):
183+
def _ask_and_tell_based_on_loss(
184+
self, n: int
185+
) -> Union[
186+
Tuple[List[Tuple[int, float]], List[float64]],
187+
Tuple[List[Union[Tuple[int, int], Tuple[int, float]]], List[float]],
188+
Tuple[List[Tuple[int, int]], List[float]],
189+
]:
162190
selected = [] # tuples ((learner_index, point), loss_improvement)
163191
total_points = [l.npoints + len(l.pending_points) for l in self.learners]
164192
for _ in range(n):
@@ -179,7 +207,13 @@ def _ask_and_tell_based_on_loss(self, n):
179207
points, loss_improvements = map(list, zip(*selected))
180208
return points, loss_improvements
181209

182-
def _ask_and_tell_based_on_npoints(self, n):
210+
def _ask_and_tell_based_on_npoints(
211+
self, n: int
212+
) -> Union[
213+
Tuple[List[Union[Tuple[int64, int], Tuple[int64, float]]], List[float]],
214+
Tuple[List[Tuple[int64, float]], List[float64]],
215+
Tuple[List[Tuple[int64, int]], List[float]],
216+
]:
183217
selected = [] # tuples ((learner_index, point), loss_improvement)
184218
total_points = [l.npoints + len(l.pending_points) for l in self.learners]
185219
for _ in range(n):
@@ -195,7 +229,13 @@ def _ask_and_tell_based_on_npoints(self, n):
195229
points, loss_improvements = map(list, zip(*selected))
196230
return points, loss_improvements
197231

198-
def _ask_and_tell_based_on_cycle(self, n):
232+
def _ask_and_tell_based_on_cycle(
233+
self, n: int
234+
) -> Union[
235+
Tuple[List[Tuple[int, float]], List[float64]],
236+
Tuple[List[Union[Tuple[int, int], Tuple[int, float]]], List[float]],
237+
Tuple[List[Tuple[int, int]], List[float]],
238+
]:
199239
points, loss_improvements = [], []
200240
for _ in range(n):
201241
index = next(self._cycle)
@@ -206,7 +246,7 @@ def _ask_and_tell_based_on_cycle(self, n):
206246

207247
return points, loss_improvements
208248

209-
def ask(self, n, tell_pending=True):
249+
def ask(self, n: int, tell_pending: bool = True) -> Any:
210250
"""Chose points for learners."""
211251
if n == 0:
212252
return [], []
@@ -217,20 +257,24 @@ def ask(self, n, tell_pending=True):
217257
else:
218258
return self._ask_and_tell(n)
219259

220-
def tell(self, x, y):
260+
def tell(
261+
self, x: Any, y: Union[int, float64, float, Tuple[int, int], Tuple[int64, int]]
262+
) -> None:
221263
index, x = x
222264
self._ask_cache.pop(index, None)
223265
self._loss.pop(index, None)
224266
self._pending_loss.pop(index, None)
225267
self.learners[index].tell(x, y)
226268

227-
def tell_pending(self, x):
269+
def tell_pending(self, x: Any) -> None:
228270
index, x = x
229271
self._ask_cache.pop(index, None)
230272
self._loss.pop(index, None)
231273
self.learners[index].tell_pending(x)
232274

233-
def _losses(self, real=True):
275+
def _losses(
276+
self, real: bool = True
277+
) -> Union[List[float], List[float64], List[Union[float, float64]]]:
234278
losses = []
235279
loss_dict = self._loss if real else self._pending_loss
236280

@@ -242,7 +286,7 @@ def _losses(self, real=True):
242286
return losses
243287

244288
@cache_latest
245-
def loss(self, real=True):
289+
def loss(self, real: bool = True) -> Union[float64, float]:
246290
losses = self._losses(real)
247291
return max(losses)
248292

@@ -372,7 +416,7 @@ def from_product(cls, f, learner_type, learner_kwargs, combos):
372416
learners.append(learner)
373417
return cls(learners, cdims=arguments)
374418

375-
def save(self, fname, compress=True):
419+
def save(self, fname: Callable, compress: bool = True) -> None:
376420
"""Save the data of the child learners into pickle files
377421
in a directory.
378422
@@ -410,7 +454,7 @@ def save(self, fname, compress=True):
410454
for l in self.learners:
411455
l.save(fname(l), compress=compress)
412456

413-
def load(self, fname, compress=True):
457+
def load(self, fname: Callable, compress: bool = True) -> None:
414458
"""Load the data of the child learners from pickle files
415459
in a directory.
416460

0 commit comments

Comments
 (0)