Skip to content

Commit 4d54b6b

Browse files
committed
run test_average_learner.py without type warnings
1 parent 215ff2f commit 4d54b6b

File tree

4 files changed

+17
-17
lines changed

4 files changed

+17
-17
lines changed

adaptive/learner/balancing_learner.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -371,7 +371,7 @@ def from_product(
371371
f,
372372
learner_type: BaseLearner,
373373
learner_kwargs: Dict[str, Any],
374-
combos: Dict[str, Iterable[Any]],
374+
combos: Dict[str, Sequence[Any]],
375375
) -> "BalancingLearner":
376376
"""Create a `BalancingLearner` with learners of all combinations of
377377
named variables’ values. The `cdims` will be set correctly, so calling

adaptive/learner/learnerND.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -383,7 +383,7 @@ def npoints(self) -> int:
383383
return len(self.data)
384384

385385
@property
386-
def vdim(self):
386+
def vdim(self) -> int:
387387
"""Length of the output of ``learner.function``.
388388
If the output is unsized (when it's a scalar)
389389
then `vdim = 1`.
@@ -399,10 +399,10 @@ def vdim(self):
399399
return self._vdim if self._vdim is not None else 1
400400

401401
@property
402-
def bounds_are_done(self):
402+
def bounds_are_done(self) -> bool:
403403
return all(p in self.data for p in self._bounds_points)
404404

405-
def _ip(self):
405+
def _ip(self) -> interpolate.LinearNDInterpolator:
406406
"""A `scipy.interpolate.LinearNDInterpolator` instance
407407
containing the learner's data."""
408408
# XXX: take our own triangulation into account when generating the _ip
@@ -427,7 +427,7 @@ def tri(self) -> Optional[Triangulation]:
427427
return self._tri
428428

429429
@property
430-
def values(self):
430+
def values(self) -> np.ndarray:
431431
"""Get the values from `data` as a numpy array."""
432432
return np.array(list(self.data.values()), dtype=float)
433433

adaptive/runner.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -237,20 +237,20 @@ def __init__(
237237
# Error handling attributes
238238
self.retries = retries
239239
self.raise_if_retries_exceeded = raise_if_retries_exceeded
240-
self._to_retry = {}
241-
self._tracebacks = {}
240+
self._to_retry: Dict[int, int] = {}
241+
self._tracebacks: Dict[int, str] = {}
242242

243-
self._id_to_point = {}
243+
self._id_to_point: Dict[int, Any] = {}
244244
self._next_id = functools.partial(
245245
next, itertools.count()
246246
) # some unique id to be associated with each point
247247

248248
def _get_max_tasks(self) -> int:
249249
return self._max_tasks or _get_ncores(self.executor)
250250

251-
def _do_raise(self, e, i):
252-
tb = self._tracebacks[i]
253-
x = self._id_to_point[i]
251+
def _do_raise(self, e, pid):
252+
tb = self._tracebacks[pid]
253+
x = self._id_to_point[pid]
254254
raise RuntimeError(
255255
"An error occured while evaluating "
256256
f'"learner.function({x})". '
@@ -379,7 +379,7 @@ def _cleanup(self) -> None:
379379
self.end_time = time.time()
380380

381381
@property
382-
def failed(self):
382+
def failed(self) -> Set[Any]:
383383
"""Set of points that failed ``runner.retries`` times."""
384384
return set(self._tracebacks) - set(self._to_retry)
385385

@@ -398,15 +398,15 @@ def _submit(self, x):
398398
pass
399399

400400
@property
401-
def tracebacks(self):
401+
def tracebacks(self) -> List[Tuple[int, str]]:
402402
return [(self._id_to_point[pid], tb) for pid, tb in self._tracebacks.items()]
403403

404404
@property
405-
def to_retry(self):
405+
def to_retry(self) -> List[Tuple[int, int]]:
406406
return [(self._id_to_point[pid], n) for pid, n in self._to_retry.items()]
407407

408408
@property
409-
def pending_points(self):
409+
def pending_points(self) -> List[Tuple[Future, Any]]:
410410
return [
411411
(fut, self._id_to_point[pid]) for fut, pid in self._pending_tasks.items()
412412
]

adaptive/utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,12 @@
55
import pickle
66
from contextlib import contextmanager
77
from itertools import product
8-
from typing import Any, Callable, Dict, Iterable, Iterator
8+
from typing import Any, Callable, Dict, Iterator, Sequence
99

1010
from atomicwrites import AtomicWriter
1111

1212

13-
def named_product(**items: Dict[str, Iterable[Any]]):
13+
def named_product(**items: Dict[str, Sequence[Any]]):
1414
names = items.keys()
1515
vals = items.values()
1616
return [dict(zip(names, res)) for res in product(*vals)]

0 commit comments

Comments
 (0)