Skip to content

Commit 4055f1a

Browse files
committed
typing fixes
1 parent 8682531 commit 4055f1a

File tree

4 files changed

+28
-18
lines changed

4 files changed

+28
-18
lines changed

adaptive/learner/balancing_learner.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import itertools
2+
import numbers
23
from collections import defaultdict
34
from collections.abc import Iterable
45
from contextlib import suppress
@@ -210,12 +211,12 @@ def _ask_and_tell_based_on_loss(
210211
return points, loss_improvements
211212

212213
def _ask_and_tell_based_on_npoints(
213-
self, n: int
214-
) -> Tuple[List[Tuple[int, Any]], List[float]]:
214+
self, n: numbers.Integral
215+
) -> Tuple[List[Tuple[numbers.Integral, Any]], List[float]]:
215216
selected = [] # tuples ((learner_index, point), loss_improvement)
216217
total_points = [l.npoints + len(l.pending_points) for l in self.learners]
217218
for _ in range(n):
218-
index = int(np.argmin(total_points))
219+
index = np.argmin(total_points)
219220
# Take the points from the cache
220221
if index not in self._ask_cache:
221222
self._ask_cache[index] = self.learners[index].ask(n=1)
@@ -229,7 +230,7 @@ def _ask_and_tell_based_on_npoints(
229230

230231
def _ask_and_tell_based_on_cycle(
231232
self, n: int
232-
) -> Tuple[List[Tuple[int, Any]], List[float]]:
233+
) -> Tuple[List[Tuple[numbers.Integral, Any]], List[float]]:
233234
points, loss_improvements = [], []
234235
for _ in range(n):
235236
index = next(self._cycle)
@@ -242,7 +243,7 @@ def _ask_and_tell_based_on_cycle(
242243

243244
def ask(
244245
self, n: int, tell_pending: bool = True
245-
) -> Tuple[List[Tuple[int, Any]], List[float]]:
246+
) -> Tuple[List[Tuple[numbers.Integral, Any]], List[float]]:
246247
"""Chose points for learners."""
247248
if n == 0:
248249
return [], []
@@ -253,14 +254,14 @@ def ask(
253254
else:
254255
return self._ask_and_tell(n)
255256

256-
def tell(self, x: Tuple[int, Any], y: Any) -> None:
257+
def tell(self, x: Tuple[numbers.Integral, Any], y: Any) -> None:
257258
index, x = x
258259
self._ask_cache.pop(index, None)
259260
self._loss.pop(index, None)
260261
self._pending_loss.pop(index, None)
261262
self.learners[index].tell(x, y)
262263

263-
def tell_pending(self, x: Tuple[int, Any]) -> None:
264+
def tell_pending(self, x: Tuple[numbers.Integral, Any]) -> None:
264265
index, x = x
265266
self._ask_cache.pop(index, None)
266267
self._loss.pop(index, None)

adaptive/learner/learner1D.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import collections.abc
22
import itertools
33
import math
4+
import numbers
45
from copy import deepcopy
56
from typing import (
67
Any,
@@ -414,7 +415,9 @@ def _update_scale(self, x: float, y: Union[float, np.ndarray]) -> None:
414415
self._bbox[1][1] = max(self._bbox[1][1], y)
415416
self._scale[1] = self._bbox[1][1] - self._bbox[1][0]
416417

417-
def tell(self, x: float, y: Union[float, Sequence[float], np.ndarray]) -> None:
418+
def tell(
419+
self, x: float, y: Union[float, Sequence[numbers.Number], np.ndarray]
420+
) -> None:
418421
if x in self.data:
419422
# The point is already evaluated before
420423
return

adaptive/learner/learnerND.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import random
44
from collections import OrderedDict
55
from collections.abc import Iterable
6-
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
6+
from typing import Any, Callable, Dict, List, Optional, Sequence, Set, Tuple, Union
77

88
import numpy as np
99
import scipy.spatial
@@ -306,7 +306,7 @@ class LearnerND(BaseLearner):
306306
def __init__(
307307
self,
308308
func: Callable,
309-
bounds: Union[Tuple[Tuple[float, float], ...], ConvexHull],
309+
bounds: Union[Sequence[Tuple[float, float]], ConvexHull],
310310
loss_per_simplex: Optional[Callable] = None,
311311
) -> None:
312312
self._vdim = None

adaptive/learner/triangulation.py

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
import collections.abc
2+
import numbers
23
from collections import Counter
34
from itertools import chain, combinations
45
from math import factorial, sqrt
56
from typing import Any, Iterable, Iterator, List, Optional, Sequence, Set, Tuple, Union
67

8+
import numpy as np
79
import scipy.spatial
810
from numpy import abs as np_abs
911
from numpy import (
@@ -25,7 +27,7 @@
2527
from numpy.linalg import matrix_rank, norm, slogdet, solve
2628

2729
SimplexPoints = Union[List[Tuple[float, ...]], ndarray] # XXX: check if this is correct
28-
Simplex = Tuple[int, ...]
30+
Simplex = Union[Tuple[numbers.Integral, ...], ndarray]
2931
Point = Union[Tuple[float, ...], ndarray] # XXX: check if this is correct
3032

3133

@@ -49,7 +51,7 @@ def fast_norm(v: Union[Tuple[float, ...], ndarray]) -> float:
4951

5052
def fast_2d_point_in_simplex(
5153
point: Point, simplex: SimplexPoints, eps: float = 1e-8
52-
) -> bool:
54+
) -> Union[np.bool_, bool]:
5355
(p0x, p0y), (p1x, p1y), (p2x, p2y) = simplex
5456
px, py = point
5557

@@ -63,7 +65,9 @@ def fast_2d_point_in_simplex(
6365
return (t >= -eps) and (s + t <= 1 + eps)
6466

6567

66-
def point_in_simplex(point: Point, simplex: SimplexPoints, eps: float = 1e-8) -> bool:
68+
def point_in_simplex(
69+
point: Point, simplex: SimplexPoints, eps: float = 1e-8
70+
) -> Union[np.bool_, bool]:
6771
if len(point) == 2:
6872
return fast_2d_point_in_simplex(point, simplex, eps)
6973

@@ -322,7 +326,7 @@ class Triangulation:
322326
or more simplices in the
323327
"""
324328

325-
def __init__(self, coords: Sequence[Point]) -> None:
329+
def __init__(self, coords: Union[Sequence[Point], ndarray]) -> None:
326330
if not is_iterable_and_sized(coords):
327331
raise TypeError("Please provide a 2-dimensional list of points")
328332
coords = list(coords)
@@ -373,10 +377,12 @@ def add_simplex(self, simplex: Simplex) -> None:
373377
for vertex in simplex:
374378
self.vertex_to_simplices[vertex].add(simplex)
375379

376-
def get_vertices(self, indices: Sequence[int]) -> List[Optional[Point]]:
380+
def get_vertices(
381+
self, indices: Sequence[numbers.Integral]
382+
) -> List[Optional[Point]]:
377383
return [self.get_vertex(i) for i in indices]
378384

379-
def get_vertex(self, index: Optional[int]) -> Optional[Point]:
385+
def get_vertex(self, index: Optional[numbers.Integral]) -> Optional[Point]:
380386
if index is None:
381387
return None
382388
return self.vertices[index]
@@ -410,7 +416,7 @@ def get_reduced_simplex(
410416

411417
def point_in_simplex(
412418
self, point: Point, simplex: Simplex, eps: float = 1e-8
413-
) -> bool:
419+
) -> Union[np.bool_, bool]:
414420
vertices = self.get_vertices(simplex)
415421
return point_in_simplex(point, vertices, eps)
416422

@@ -616,7 +622,7 @@ def add_point(
616622
point: Point,
617623
simplex: Optional[Simplex] = None,
618624
transform: Optional[ndarray] = None,
619-
) -> Any:
625+
) -> Tuple[Set[Simplex], Set[Simplex]]:
620626
"""Add a new vertex and create simplices as appropriate.
621627
622628
Parameters

0 commit comments

Comments
 (0)