Skip to content

Commit fb457d5

Browse files
committed
more type fixes
1 parent 6db3e75 commit fb457d5

File tree

5 files changed

+13
-11
lines changed

5 files changed

+13
-11
lines changed

adaptive/learner/base_learner.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,7 @@ def remove_unfinished(self):
118118
pass
119119

120120
@abc.abstractmethod
121-
def loss(self, real=True):
121+
def loss(self, real: bool = True):
122122
"""Return the loss for the current state of the learner.
123123
124124
Parameters
@@ -130,7 +130,7 @@ def loss(self, real=True):
130130
"""
131131

132132
@abc.abstractmethod
133-
def ask(self, n, tell_pending=True):
133+
def ask(self, n: int, tell_pending: bool = True):
134134
"""Choose the next 'n' points to evaluate.
135135
136136
Parameters

adaptive/learner/learner2D.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -598,7 +598,9 @@ def _fill_stack(
598598

599599
return points_new, losses_new
600600

601-
def ask(self, n: int, tell_pending: bool = True) -> Tuple[np.ndarray, np.ndarray]:
601+
def ask(
602+
self, n: int, tell_pending: bool = True
603+
) -> Tuple[List[Tuple[float, float]], List[float]]:
602604
# Even if tell_pending is False we add the point such that _fill_stack
603605
# will return new points, later we remove these points if needed.
604606
points = list(self._stack.keys())

adaptive/learner/learnerND.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ def to_list(inp: float) -> List[float]:
3131
return [inp]
3232

3333

34-
def volume(simplex: List[Tuple[float, float]], ys: None = None) -> float:
34+
def volume(simplex: Simplex, ys: None = None) -> float:
3535
# Notice the parameter ys is there so you can use this volume method as
3636
# as loss function
3737
matrix = np.subtract(simplex[:-1], simplex[-1], dtype=float)
@@ -69,7 +69,7 @@ def uniform_loss(simplex: np.ndarray, values: np.ndarray, value_scale: float) ->
6969
return volume(simplex)
7070

7171

72-
def std_loss(simplex: np.ndarray, values: np.ndarray, value_scale: float) -> np.ndarray:
72+
def std_loss(simplex: Simplex, values: np.ndarray, value_scale: float) -> np.ndarray:
7373
"""
7474
Computes the loss of the simplex based on the standard deviation.
7575

adaptive/learner/triangulation.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ def fast_norm(v: Union[Tuple[float, ...], ndarray]) -> float:
5252

5353
def fast_2d_point_in_simplex(
5454
point: Point, simplex: SimplexPoints, eps: float = 1e-8
55-
) -> Union[np.bool_, bool]:
55+
) -> Union[bool, np.bool_]:
5656
(p0x, p0y), (p1x, p1y), (p2x, p2y) = simplex
5757
px, py = point
5858

@@ -68,7 +68,7 @@ def fast_2d_point_in_simplex(
6868

6969
def point_in_simplex(
7070
point: Point, simplex: SimplexPoints, eps: float = 1e-8
71-
) -> Union[np.bool_, bool]:
71+
) -> Union[bool, np.bool_]:
7272
if len(point) == 2:
7373
return fast_2d_point_in_simplex(point, simplex, eps)
7474

@@ -417,7 +417,7 @@ def get_reduced_simplex(
417417

418418
def point_in_simplex(
419419
self, point: Point, simplex: Simplex, eps: float = 1e-8
420-
) -> Union[np.bool_, bool]:
420+
) -> Union[bool, np.bool_]:
421421
vertices = self.get_vertices(simplex)
422422
return point_in_simplex(point, vertices, eps)
423423

adaptive/utils.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,9 @@
33
import gzip
44
import os
55
import pickle
6-
from contextlib import contextmanager
6+
from contextlib import _GeneratorContextManager, contextmanager
77
from itertools import product
8-
from typing import Any, Callable, Dict, Iterator, Sequence
8+
from typing import Any, Callable, Dict, Sequence
99

1010
from atomicwrites import AtomicWriter
1111

@@ -17,7 +17,7 @@ def named_product(**items: Dict[str, Sequence[Any]]):
1717

1818

1919
@contextmanager
20-
def restore(*learners) -> Iterator[None]:
20+
def restore(*learners) -> _GeneratorContextManager:
2121
states = [learner.__getstate__() for learner in learners]
2222
try:
2323
yield

0 commit comments

Comments
 (0)