Skip to content

Commit e328e29

Browse files
committed
type fixes
1 parent 9c86827 commit e328e29

12 files changed

+45
-35
lines changed

adaptive/learner/average_learner.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ class AverageLearner(BaseLearner):
3535

3636
def __init__(
3737
self,
38-
function: Callable[[BaseLearner], float],
38+
function: Callable[["AverageLearner"], float],
3939
atol: Optional[float] = None,
4040
rtol: Optional[float] = None,
4141
min_npoints: int = 2,
@@ -49,7 +49,7 @@ def __init__(
4949

5050
self.data = {}
5151
self.pending_points = set()
52-
self.function = function
52+
self.function = function # type: ignore
5353
self.atol = atol
5454
self.rtol = rtol
5555
self.npoints = 0

adaptive/learner/balancing_learner.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,7 @@ def __init__(
101101
# Naively we would make 'function' a method, but this causes problems
102102
# when using executors from 'concurrent.futures' because we have to
103103
# pickle the whole learner.
104-
self.function = partial(dispatch, [l.function for l in self.learners])
104+
self.function = partial(dispatch, [l.function for l in self.learners]) # type: ignore
105105

106106
self._ask_cache = {}
107107
self._loss = {}

adaptive/learner/base_learner.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -82,8 +82,13 @@ class BaseLearner(metaclass=_RequireAttrsABCMeta):
8282
"""
8383

8484
data: dict
85-
npoints: int
8685
pending_points: set
86+
function: Callable
87+
88+
@property
89+
@abc.abstractmethod
90+
def npoints(self) -> int:
91+
"""Number of learned points."""
8792

8893
def tell(self, x: Any, y) -> None:
8994
"""Tell the learner about a single value.
@@ -149,7 +154,7 @@ def _get_data(self):
149154
pass
150155

151156
@abc.abstractmethod
152-
def _set_data(self):
157+
def _set_data(self, data: Any):
153158
pass
154159

155160
def copy_from(self, other):

adaptive/learner/integrator_learner.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -379,7 +379,7 @@ def __init__(self, function: Callable, bounds: Tuple[int, int], tol: float) -> N
379379
plot : hv.Scatter
380380
Plots all the points that are evaluated.
381381
"""
382-
self.function = function
382+
self.function = function # type: ignore
383383
self.bounds = bounds
384384
self.tol = tol
385385
self.max_ivals = 1000

adaptive/learner/learner1D.py

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -26,9 +26,11 @@
2626
from adaptive.notebook_integration import ensure_holoviews
2727
from adaptive.utils import cache_latest
2828

29+
Point = Tuple[float, float]
30+
2931

3032
@uses_nth_neighbors(0)
31-
def uniform_loss(xs: Tuple[float, float], ys: Tuple[float, float]) -> float:
33+
def uniform_loss(xs: Point, ys: Any) -> float:
3234
"""Loss function that samples the domain uniformly.
3335
3436
Works with `~adaptive.Learner1D` only.
@@ -49,8 +51,8 @@ def uniform_loss(xs: Tuple[float, float], ys: Tuple[float, float]) -> float:
4951

5052
@uses_nth_neighbors(0)
5153
def default_loss(
52-
xs: Tuple[float, float],
53-
ys: Union[Tuple[Iterable[float], Iterable[float]], Tuple[float, float]],
54+
xs: Point,
55+
ys: Union[Tuple[Iterable[float], Iterable[float]], Point],
5456
) -> float:
5557
"""Calculate loss on a single interval.
5658
@@ -60,8 +62,8 @@ def default_loss(
6062
"""
6163
dx = xs[1] - xs[0]
6264
if isinstance(ys[0], collections.abc.Iterable):
63-
dy = [abs(a - b) for a, b in zip(*ys)]
64-
return np.hypot(dx, dy).max()
65+
dy_vec = [abs(a - b) for a, b in zip(*ys)]
66+
return np.hypot(dx, dy_vec).max()
6567
else:
6668
dy = ys[1] - ys[0]
6769
return np.hypot(dx, dy)
@@ -200,7 +202,7 @@ def __init__(
200202
bounds: Tuple[float, float],
201203
loss_per_interval: Optional[Callable] = None,
202204
) -> None:
203-
self.function = function
205+
self.function = function # type: ignore
204206

205207
if hasattr(loss_per_interval, "nth_neighbors"):
206208
self.nth_neighbors = loss_per_interval.nth_neighbors
@@ -238,7 +240,7 @@ def __init__(
238240

239241
self.bounds = list(bounds)
240242

241-
self._vdim = None
243+
self._vdim: Optional[int] = None
242244

243245
@property
244246
def vdim(self) -> int:
@@ -565,7 +567,8 @@ def _ask_points_without_adding(self, n: int) -> Any:
565567
# Add bound intervals to quals if bounds were missing.
566568
if len(self.data) + len(self.pending_points) == 0:
567569
# We don't have any points, so return a linspace with 'n' points.
568-
return np.linspace(*self.bounds, n).tolist(), [np.inf] * n
570+
a, b = self.bounds
571+
return np.linspace(a, b, n).tolist(), [np.inf] * n
569572

570573
quals = loss_manager(self._scale[0])
571574
if len(missing_bounds) > 0:
@@ -601,7 +604,7 @@ def _ask_points_without_adding(self, n: int) -> Any:
601604
quals[(*xs, n + 1)] = loss_qual * n / (n + 1)
602605

603606
points = list(
604-
itertools.chain.from_iterable(linspace(*ival, n) for (*ival, n) in quals)
607+
itertools.chain.from_iterable(linspace(a, b, n) for ((a, b), n) in quals)
605608
)
606609

607610
loss_improvements = list(
@@ -665,7 +668,8 @@ def _get_data(self) -> Dict[float, float]:
665668

666669
def _set_data(self, data: Dict[float, float]) -> None:
667670
if data:
668-
self.tell_many(*zip(*data.items()))
671+
xs, ys = zip(*data.items())
672+
self.tell_many(xs, ys)
669673

670674
def __getstate__(self):
671675
return (

adaptive/learner/learner2D.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -372,7 +372,7 @@ def __init__(
372372

373373
self._bounds_points = list(itertools.product(*bounds))
374374
self._stack.update({p: np.inf for p in self._bounds_points})
375-
self.function = function
375+
self.function = function # type: ignore
376376
self._ip = self._ip_combined = None
377377

378378
self.stack_size = 10
@@ -600,7 +600,7 @@ def _fill_stack(
600600

601601
def ask(
602602
self, n: int, tell_pending: bool = True
603-
) -> Tuple[List[Tuple[float, float]], List[float]]:
603+
) -> Tuple[List[Union[Tuple[float, float], np.array]], List[float]]:
604604
# Even if tell_pending is False we add the point such that _fill_stack
605605
# will return new points, later we remove these points if needed.
606606
points = list(self._stack.keys())

adaptive/learner/learnerND.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -305,7 +305,7 @@ class LearnerND(BaseLearner):
305305

306306
def __init__(
307307
self,
308-
func: Callable,
308+
function: Callable,
309309
bounds: Union[Sequence[Tuple[float, float]], ConvexHull],
310310
loss_per_simplex: Optional[Callable] = None,
311311
) -> None:
@@ -339,8 +339,8 @@ def __init__(
339339

340340
self.ndim = len(self._bbox)
341341

342-
self.function = func
343-
self._tri = None
342+
self.function = function # type: ignore
343+
self._tri: Optional[Triangulation] = None
344344
self._losses: Dict[Simplex, float] = dict()
345345

346346
self._pending_to_simplex: Dict[Point, Simplex] = dict() # vertex → simplex
@@ -455,6 +455,7 @@ def tell(self, point: Tuple[float, ...], value: Union[float, np.ndarray]) -> Non
455455
self._update_range(value)
456456
if tri is not None:
457457
simplex = self._pending_to_simplex.get(point)
458+
assert self.tri is not None
458459
if simplex is not None and not self._simplex_exists(simplex):
459460
simplex = None
460461
to_delete, to_add = tri.add_point(point, simplex, transform=self._transform)

adaptive/learner/sequence_learner.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ class _IgnoreFirstArgument:
1818
"""
1919

2020
def __init__(self, function: Callable) -> None:
21-
self.function = function
21+
self.function = function # type: ignore
2222

2323
def __call__(
2424
self, index_point: Tuple[int, Union[float, np.ndarray]], *args, **kwargs
@@ -62,7 +62,7 @@ class SequenceLearner(BaseLearner):
6262

6363
def __init__(self, function: Callable, sequence: Iterable) -> None:
6464
self._original_function = function
65-
self.function = _IgnoreFirstArgument(function)
65+
self.function = _IgnoreFirstArgument(function) # type: ignore
6666
self._to_do_indices = SortedSet({i for i, _ in enumerate(sequence)})
6767
self._ntotal = len(sequence)
6868
self.sequence = copy(sequence)

adaptive/learner/skopt_learner.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ class SKOptLearner(Optimizer, BaseLearner):
2525
"""
2626

2727
def __init__(self, function: Callable, **kwargs) -> None:
28-
self.function = function
28+
self.function = function # type: ignore
2929
self.pending_points = set()
3030
self.data = collections.OrderedDict()
3131
super().__init__(**kwargs)

adaptive/learner/triangulation.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,9 +27,9 @@
2727
from numpy.linalg import matrix_rank, norm, slogdet, solve
2828

2929
SimplexPoints = Union[List[Tuple[float, ...]], ndarray]
30-
Simplex = Union[Iterable[numbers.Integral], ndarray]
30+
Simplex = Union[Sequence[numbers.Integral], ndarray]
3131
Point = Union[Tuple[float, ...], ndarray]
32-
Points = Union[Sequence[Point], ndarray]
32+
Points = Union[Sequence[Tuple[float, ...]], ndarray]
3333

3434

3535
def fast_norm(v: Union[Tuple[float, ...], ndarray]) -> float:
@@ -168,7 +168,7 @@ def fast_det(matrix: ndarray) -> float:
168168
return ndet(matrix)
169169

170170

171-
def circumsphere(pts: ndarray) -> Tuple[Tuple[float, ...], float]:
171+
def circumsphere(pts: Simplex) -> Tuple[Tuple[float, ...], float]:
172172
"""Compute the center and radius of a N dimension sphere which touches each point in pts.
173173
174174
Parameters

0 commit comments

Comments
 (0)