Skip to content

Commit 215ff2f

Browse files
committed
typing fixes
1 parent a87319c commit 215ff2f

File tree

5 files changed

+62
-22
lines changed

5 files changed

+62
-22
lines changed

adaptive/_version.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,15 +3,15 @@
33
import os
44
import subprocess
55
from collections import namedtuple
6-
from typing import Dict
6+
from typing import Dict, List
77

88
from setuptools.command.build_py import build_py as build_py_orig
99
from setuptools.command.sdist import sdist as sdist_orig
1010

1111
Version = namedtuple("Version", ("release", "dev", "labels"))
1212

1313
# No public API
14-
__all__ = []
14+
__all__: List[str] = []
1515

1616
package_root = os.path.dirname(os.path.realpath(__file__))
1717
package_name = os.path.basename(package_root)

adaptive/learner/average_learner.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ class AverageLearner(BaseLearner):
3535

3636
def __init__(
3737
self,
38-
function: Callable,
38+
function: Callable[[BaseLearner], float],
3939
atol: Optional[float] = None,
4040
rtol: Optional[float] = None,
4141
min_npoints: int = 2,

adaptive/learner/balancing_learner.py

Lines changed: 55 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,18 @@
44
from contextlib import suppress
55
from functools import partial
66
from operator import itemgetter
7-
from typing import Any, Callable, Dict, List, Set, Tuple, Union
7+
from typing import (
8+
Any,
9+
Callable,
10+
Dict,
11+
List,
12+
Literal,
13+
Optional,
14+
Sequence,
15+
Set,
16+
Tuple,
17+
Union,
18+
)
819

920
import numpy as np
1021

@@ -18,6 +29,14 @@ def dispatch(child_functions: List[Callable], arg: Any) -> Union[Any]:
1829
return child_functions[index](x)
1930

2031

32+
STRATEGY_TYPE = Literal["loss_improvements", "loss", "npoints", "cycle"]
33+
34+
CDIMS_TYPE = Union[
35+
Sequence[Dict[str, Any]],
36+
Tuple[Sequence[str], Sequence[Tuple[Any, ...]]],
37+
]
38+
39+
2140
class BalancingLearner(BaseLearner):
2241
r"""Choose the optimal points from a set of learners.
2342
@@ -70,7 +89,11 @@ class BalancingLearner(BaseLearner):
7089
"""
7190

7291
def __init__(
73-
self, learners: List[BaseLearner], *, cdims=None, strategy="loss_improvements"
92+
self,
93+
learners: List[BaseLearner],
94+
*,
95+
cdims: Optional[CDIMS_TYPE] = None,
96+
strategy: STRATEGY_TYPE = "loss_improvements"
7497
) -> None:
7598
self.learners = learners
7699

@@ -89,7 +112,7 @@ def __init__(
89112
"A BalacingLearner can handle only one type" " of learners."
90113
)
91114

92-
self.strategy = strategy
115+
self.strategy: STRATEGY_TYPE = strategy
93116

94117
@property
95118
def data(self) -> Dict[Tuple[int, Any], Any]:
@@ -110,7 +133,7 @@ def npoints(self) -> int:
110133
return sum(l.npoints for l in self.learners)
111134

112135
@property
113-
def strategy(self):
136+
def strategy(self) -> STRATEGY_TYPE:
114137
"""Can be either 'loss_improvements' (default), 'loss', 'npoints', or
115138
'cycle'. The points that the `BalancingLearner` choses can be either
116139
based on: the best 'loss_improvements', the smallest total 'loss' of
@@ -121,7 +144,7 @@ def strategy(self):
121144
return self._strategy
122145

123146
@strategy.setter
124-
def strategy(self, strategy):
147+
def strategy(self, strategy: STRATEGY_TYPE) -> None:
125148
self._strategy = strategy
126149
if strategy == "loss_improvements":
127150
self._ask_and_tell = self._ask_and_tell_based_on_loss_improvements
@@ -255,11 +278,16 @@ def _losses(self, real: bool = True) -> List[float]:
255278
return losses
256279

257280
@cache_latest
258-
def loss(self, real: bool = True) -> Union[float]:
281+
def loss(self, real: bool = True) -> float:
259282
losses = self._losses(real)
260283
return max(losses)
261284

262-
def plot(self, cdims=None, plotter=None, dynamic=True):
285+
def plot(
286+
self,
287+
cdims: Optional[CDIMS_TYPE] = None,
288+
plotter: Optional[Callable[[BaseLearner], Any]] = None,
289+
dynamic: bool = True,
290+
):
263291
"""Returns a DynamicMap with sliders.
264292
265293
Parameters
@@ -332,14 +360,18 @@ def plot_function(*args):
332360
vals = {d.name: d.values for d in dm.dimensions() if d.values}
333361
return hv.HoloMap(dm.select(**vals))
334362

335-
def remove_unfinished(self):
363+
def remove_unfinished(self) -> None:
336364
"""Remove uncomputed data from the learners."""
337365
for learner in self.learners:
338366
learner.remove_unfinished()
339367

340368
@classmethod
341369
def from_product(
342-
cls, f, learner_type, learner_kwargs, combos
370+
cls,
371+
f,
372+
learner_type: BaseLearner,
373+
learner_kwargs: Dict[str, Any],
374+
combos: Dict[str, Iterable[Any]],
343375
) -> "BalancingLearner":
344376
"""Create a `BalancingLearner` with learners of all combinations of
345377
named variables’ values. The `cdims` will be set correctly, so calling
@@ -387,7 +419,11 @@ def from_product(
387419
learners.append(learner)
388420
return cls(learners, cdims=arguments)
389421

390-
def save(self, fname: Callable, compress: bool = True) -> None:
422+
def save(
423+
self,
424+
fname: Union[Callable[[BaseLearner], str], Sequence[str]],
425+
compress: bool = True,
426+
) -> None:
391427
"""Save the data of the child learners into pickle files
392428
in a directory.
393429
@@ -425,7 +461,11 @@ def save(self, fname: Callable, compress: bool = True) -> None:
425461
for l in self.learners:
426462
l.save(fname(l), compress=compress)
427463

428-
def load(self, fname: Callable, compress: bool = True) -> None:
464+
def load(
465+
self,
466+
fname: Union[Callable[[BaseLearner], str], Sequence[str]],
467+
compress: bool = True,
468+
) -> None:
429469
"""Load the data of the child learners from pickle files
430470
in a directory.
431471
@@ -449,20 +489,20 @@ def load(self, fname: Callable, compress: bool = True) -> None:
449489
for l in self.learners:
450490
l.load(fname(l), compress=compress)
451491

452-
def _get_data(self):
492+
def _get_data(self) -> List[Any]:
453493
return [l._get_data() for l in self.learners]
454494

455-
def _set_data(self, data):
495+
def _set_data(self, data: List[Any]):
456496
for l, _data in zip(self.learners, data):
457497
l._set_data(_data)
458498

459-
def __getstate__(self):
499+
def __getstate__(self) -> Tuple[List[BaseLearner], CDIMS_TYPE, STRATEGY_TYPE]:
460500
return (
461501
self.learners,
462502
self._cdims_default,
463503
self.strategy,
464504
)
465505

466-
def __setstate__(self, state):
506+
def __setstate__(self, state: Tuple[List[BaseLearner], CDIMS_TYPE, STRATEGY_TYPE]):
467507
learners, cdims, strategy = state
468508
self.__init__(learners, cdims=cdims, strategy=strategy)

adaptive/learner/triangulation.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ def fast_2d_circumcircle(points: Sequence[Point]) -> Tuple[Tuple[float, float],
8787
tuple
8888
(center point : tuple(float), radius: float)
8989
"""
90-
points = array(points)
90+
points = array(points, dtype=float)
9191
# transform to relative coordinates
9292
pts = points[1:] - points[0]
9393

@@ -113,7 +113,7 @@ def fast_2d_circumcircle(points: Sequence[Point]) -> Tuple[Tuple[float, float],
113113
def fast_3d_circumcircle(
114114
points: Sequence[Point],
115115
) -> Tuple[Tuple[float, float, float], float]:
116-
"""Compute the center and radius of the circumscribed shpere of a simplex.
116+
"""Compute the center and radius of the circumscribed sphere of a simplex.
117117
118118
Parameters
119119
----------

adaptive/utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,12 @@
55
import pickle
66
from contextlib import contextmanager
77
from itertools import product
8-
from typing import Any, Callable, Iterator
8+
from typing import Any, Callable, Dict, Iterable, Iterator
99

1010
from atomicwrites import AtomicWriter
1111

1212

13-
def named_product(**items):
13+
def named_product(**items: Dict[str, Iterable[Any]]):
1414
names = items.keys()
1515
vals = items.values()
1616
return [dict(zip(names, res)) for res in product(*vals)]

0 commit comments

Comments
 (0)