Skip to content

Commit 82245b2

Browse files
committed
fix mypy errors
1 parent 4e43f39 commit 82245b2

File tree

1 file changed

+16
-18
lines changed

1 file changed

+16
-18
lines changed

adaptive/learner/average_learner1D.py

Lines changed: 16 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import math
2+
import sys
23
from collections import defaultdict
34
from copy import deepcopy
45
from math import hypot
@@ -23,12 +24,10 @@
2324
from adaptive.notebook_integration import ensure_holoviews
2425

2526
number = Union[int, float, np.int_, np.float_]
26-
2727
Point = Tuple[int, number]
2828
Points = List[Point]
29-
Value = Union[number, Sequence[number], np.ndarray]
3029

31-
__all__ = ["AverageLearner1D"]
30+
__all__: List[str] = ["AverageLearner1D"]
3231

3332

3433
class AverageLearner1D(Learner1D):
@@ -76,7 +75,7 @@ class AverageLearner1D(Learner1D):
7675

7776
def __init__(
7877
self,
79-
function: Callable[[Tuple[int, number]], Value],
78+
function: Callable[[Tuple[int, number]], number],
8079
bounds: Tuple[number, number],
8180
loss_per_interval: Optional[
8281
Callable[[Sequence[number], Sequence[number]], float]
@@ -85,7 +84,7 @@ def __init__(
8584
alpha: float = 0.005,
8685
neighbor_sampling: float = 0.3,
8786
min_samples: int = 50,
88-
max_samples: int = np.inf,
87+
max_samples: int = sys.maxsize,
8988
min_error: float = 0,
9089
):
9190
if not (0 < delta <= 1):
@@ -201,16 +200,13 @@ def tell_pending(self, seed_x: Point) -> None:
201200
self._update_neighbors(x, self.neighbors_combined)
202201
self._update_losses(x, real=False)
203202

204-
def tell(self, seed_x: Point, y: Value) -> None:
203+
def tell(self, seed_x: Point, y: number) -> None:
205204
seed, x = seed_x
206205
if y is None:
207206
raise TypeError(
208207
"Y-value may not be None, use learner.tell_pending(x)"
209208
"to indicate that this value is currently being calculated"
210209
)
211-
# either it is a float/int, if not, try casting to a np.array
212-
if not isinstance(y, (float, int)):
213-
y = np.asarray(y, dtype=float)
214210

215211
if x not in self.data:
216212
self._update_data(x, y, "new")
@@ -257,15 +253,17 @@ def _update_rescaled_error_in_mean(self, x: number, point_type: str) -> None:
257253
norm = min(d_left, d_right)
258254
self.rescaled_error[x] = self.error[x] / norm
259255

260-
def _update_data(self, x: number, y: Value, point_type: str) -> None:
256+
def _update_data(self, x: number, y: number, point_type: str) -> None:
261257
if point_type == "new":
262258
self.data[x] = y
263259
elif point_type == "resampled":
264260
n = len(self._data_samples[x])
265261
new_average = self.data[x] * n / (n + 1) + y / (n + 1)
266262
self.data[x] = new_average
267263

268-
def _update_data_structures(self, seed_x: Point, y: Value, point_type: str) -> None:
264+
def _update_data_structures(
265+
self, seed_x: Point, y: number, point_type: str
266+
) -> None:
269267
seed, x = seed_x
270268
if point_type == "new":
271269
self._data_samples[x] = {seed: y}
@@ -370,12 +368,12 @@ def _update_losses_resampling(self, x: number, real=True) -> None:
370368
if (b is not None) and right_loss_is_unknown:
371369
self.losses_combined[x, b] = float("inf")
372370

373-
def _calc_error_in_mean(self, ys: Sequence[Value], y_avg: Value, n: int) -> float:
371+
def _calc_error_in_mean(self, ys: Sequence[number], y_avg: number, n: int) -> float:
374372
variance_in_mean = sum((y - y_avg) ** 2 for y in ys) / (n - 1)
375373
t_student = scipy.stats.t.ppf(1 - self.alpha, df=n - 1)
376374
return t_student * (variance_in_mean / n) ** 0.5
377375

378-
def tell_many(self, xs: Points, ys: Sequence[Value]) -> None:
376+
def tell_many(self, xs: Points, ys: Sequence[number]) -> None:
379377
# Check that all x are within the bounds
380378
# TODO: remove this requirement, all other learners add the data
381379
# but ignore it going forward.
@@ -386,7 +384,7 @@ def tell_many(self, xs: Points, ys: Sequence[Value]) -> None:
386384
)
387385

388386
# Create a mapping of points to a list of samples
389-
mapping: DefaultDict[number, DefaultDict[int, Value]] = defaultdict(
387+
mapping: DefaultDict[number, DefaultDict[int, number]] = defaultdict(
390388
lambda: defaultdict(dict)
391389
)
392390
for (seed, x), y in zip(xs, ys):
@@ -402,14 +400,14 @@ def tell_many(self, xs: Points, ys: Sequence[Value]) -> None:
402400
# simultaneously, before we move on to a new x
403401
self.tell_many_at_point(x, seed_y_mapping)
404402

405-
def tell_many_at_point(self, x: float, seed_y_mapping: Dict[int, Value]) -> None:
403+
def tell_many_at_point(self, x: number, seed_y_mapping: Dict[int, number]) -> None:
406404
"""Tell the learner about many samples at a certain location x.
407405
408406
Parameters
409407
----------
410408
x : float
411409
Value from the function domain.
412-
seed_y_mapping : Dict[int, Value]
410+
seed_y_mapping : Dict[int, number]
413411
Dictionary of ``seed`` -> ``y`` at ``x``.
414412
"""
415413
# Check x is within the bounds
@@ -458,10 +456,10 @@ def tell_many_at_point(self, x: float, seed_y_mapping: Dict[int, Value]) -> None
458456
self._update_interpolated_loss_in_interval(*interval)
459457
self._oldscale = deepcopy(self._scale)
460458

461-
def _get_data(self) -> SortedDict[number, Value]:
459+
def _get_data(self) -> SortedDict[number, number]:
462460
return self._data_samples
463461

464-
def _set_data(self, data: SortedDict[number, Value]) -> None:
462+
def _set_data(self, data: SortedDict[number, number]) -> None:
465463
if data:
466464
for x, samples in data.items():
467465
self.tell_many_at_point(x, samples)

0 commit comments

Comments
 (0)