Skip to content

Commit e92c4fb

Browse files
committed
type hint fixes for adaptive/learner/learner2D.py
1 parent 3034acb commit e92c4fb

File tree

1 file changed

+12
-36
lines changed

1 file changed

+12
-36
lines changed

adaptive/learner/learner2D.py

Lines changed: 12 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,8 @@
22
import warnings
33
from collections import OrderedDict
44
from copy import copy
5-
from functools import partial
65
from math import sqrt
7-
from typing import Any, Callable, List, Optional, Tuple, Union
6+
from typing import Any, Callable, Iterable, List, Optional, Tuple, Union
87

98
import numpy as np
109
from scipy import interpolate
@@ -107,7 +106,9 @@ def uniform_loss(ip: LinearNDInterpolator) -> np.ndarray:
107106
return np.sqrt(areas(ip))
108107

109108

110-
def resolution_loss_function(min_distance: int = 0, max_distance: int = 1) -> Callable:
109+
def resolution_loss_function(
110+
min_distance: float = 0, max_distance: float = 1
111+
) -> Callable:
111112
"""Loss function that is similar to the `default_loss` function, but you
112113
can set the maximimum and minimum size of a triangle.
113114
@@ -353,10 +354,8 @@ class Learner2D(BaseLearner):
353354

354355
def __init__(
355356
self,
356-
function: Union[partial, Callable],
357-
bounds: Union[
358-
List[Tuple[int, int]], Tuple[Tuple[int, int], Tuple[int, int]], np.ndarray
359-
],
357+
function: Callable,
358+
bounds: Tuple[Tuple[int, int], Tuple[int, int]],
360359
loss_per_triangle: Optional[Callable] = None,
361360
) -> None:
362361
self.ndim = len(bounds)
@@ -462,7 +461,7 @@ def _data_in_bounds(self) -> Tuple[np.ndarray, np.ndarray]:
462461
return points[inds], values[inds].reshape(-1, self.vdim)
463462
return np.zeros((0, 2)), np.zeros((0, self.vdim), dtype=float)
464463

465-
def _data_interp(self) -> Any:
464+
def _data_interp(self) -> Tuple[np.ndarray, np.ndarray]:
466465
if self.pending_points:
467466
points = list(self.pending_points)
468467
if self.bounds_are_done:
@@ -493,7 +492,7 @@ def ip(self):
493492
)
494493
return self.interpolator(scaled=True)
495494

496-
def interpolator(self, *, scaled=False) -> LinearNDInterpolator:
495+
def interpolator(self, *, scaled: bool = False) -> LinearNDInterpolator:
497496
"""A `scipy.interpolate.LinearNDInterpolator` instance
498497
containing the learner's data.
499498
@@ -534,28 +533,13 @@ def _interpolator_combined(self) -> LinearNDInterpolator:
534533
self._ip_combined = interpolate.LinearNDInterpolator(points, values)
535534
return self._ip_combined
536535

537-
def inside_bounds(
538-
self,
539-
xy: Union[
540-
Tuple[int, int],
541-
Tuple[float, float],
542-
Tuple[float, float],
543-
Tuple[float, float],
544-
],
545-
) -> Union[bool, np.bool_]:
536+
def inside_bounds(self, xy: Tuple[float, float],) -> Union[bool, np.bool_]:
546537
x, y = xy
547538
(xmin, xmax), (ymin, ymax) = self.bounds
548539
return xmin <= x <= xmax and ymin <= y <= ymax
549540

550541
def tell(
551-
self,
552-
point: Union[
553-
Tuple[int, int],
554-
Tuple[float, float],
555-
Tuple[float, float],
556-
Tuple[float, float],
557-
],
558-
value: Union[List[int], float, float],
542+
self, point: Tuple[float, float], value: Union[float, Iterable[float]],
559543
) -> None:
560544
point = tuple(point)
561545
self.data[point] = value
@@ -565,15 +549,7 @@ def tell(
565549
self._ip = None
566550
self._stack.pop(point, None)
567551

568-
def tell_pending(
569-
self,
570-
point: Union[
571-
Tuple[int, int],
572-
Tuple[float, float],
573-
Tuple[float, float],
574-
Tuple[float, float],
575-
],
576-
) -> None:
552+
def tell_pending(self, point: Tuple[float, float],) -> None:
577553
point = tuple(point)
578554
if not self.inside_bounds(point):
579555
return
@@ -622,7 +598,7 @@ def _fill_stack(
622598

623599
return points_new, losses_new
624600

625-
def ask(self, n: int, tell_pending: bool = True) -> Any:
601+
def ask(self, n: int, tell_pending: bool = True) -> Tuple[np.ndarray, np.ndarray]:
626602
# Even if tell_pending is False we add the point such that _fill_stack
627603
# will return new points, later we remove these points if needed.
628604
points = list(self._stack.keys())

0 commit comments

Comments
 (0)