Skip to content

Commit 4e43f39

Browse files
committed
improve type-hints in average_learner1D.py
1 parent cf9b6fe commit 4e43f39

File tree

1 file changed

+53
-36
lines changed

1 file changed

+53
-36
lines changed

adaptive/learner/average_learner1D.py

Lines changed: 53 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,17 @@
22
from collections import defaultdict
33
from copy import deepcopy
44
from math import hypot
5-
from numbers import Number
6-
from typing import Dict, List, Sequence, Tuple, Union
5+
from typing import (
6+
Callable,
7+
DefaultDict,
8+
Dict,
9+
List,
10+
Optional,
11+
Sequence,
12+
Set,
13+
Tuple,
14+
Union,
15+
)
716

817
import numpy as np
918
import scipy.stats
@@ -13,9 +22,13 @@
1322
from adaptive.learner.learner1D import Learner1D, _get_intervals
1423
from adaptive.notebook_integration import ensure_holoviews
1524

16-
Point = Tuple[int, Number]
25+
number = Union[int, float, np.int_, np.float_]
26+
27+
Point = Tuple[int, number]
1728
Points = List[Point]
18-
Value = Union[Number, Sequence[Number]]
29+
Value = Union[number, Sequence[number], np.ndarray]
30+
31+
__all__ = ["AverageLearner1D"]
1932

2033

2134
class AverageLearner1D(Learner1D):
@@ -37,21 +50,21 @@ class AverageLearner1D(Learner1D):
3750
This parameter controls the resampling condition. A point is resampled
3851
if its uncertainty is larger than delta times the smallest neighboring
3952
interval.
40-
We strongly recommend 0 < delta <= 1.
41-
alpha : float (0 < alpha < 1)
53+
We strongly recommend ``0 < delta <= 1``.
54+
alpha : float (0 < alpha < 1), default 0.005
4255
The true value of the function at x is within the confidence interval
43-
[self.data[x] - self.error[x], self.data[x] +
44-
self.error[x]] with probability 1-2*alpha.
45-
We recommend to keep alpha=0.005.
46-
neighbor_sampling : float (0 < neighbor_sampling <= 1)
56+
``[self.data[x] - self.error[x], self.data[x] + self.error[x]]`` with
57+
probability ``1-2*alpha``.
58+
We recommend to keep ``alpha=0.005``.
59+
neighbor_sampling : float (0 < neighbor_sampling <= 1), default 0.3
4760
Each new point is initially sampled at least a (neighbor_sampling*100)%
4861
of the average number of samples of its neighbors.
49-
min_samples : int (min_samples > 0)
62+
min_samples : int (min_samples > 0), default 50
5063
Minimum number of samples at each point x. Each new point is initially
5164
sampled at least min_samples times.
52-
max_samples : int (min_samples < max_samples)
65+
max_samples : int (min_samples < max_samples), default np.inf
5366
Maximum number of samples at each point x.
54-
min_error : float (min_error >= 0)
67+
min_error : float (min_error >= 0), default 0
5568
Minimum size of the confidence intervals. The true value of the
5669
function at x is within the confidence interval [self.data[x] -
5770
self.error[x], self.data[x] + self.error[x]] with
@@ -63,15 +76,17 @@ class AverageLearner1D(Learner1D):
6376

6477
def __init__(
6578
self,
66-
function,
67-
bounds,
68-
loss_per_interval=None,
69-
delta=0.2,
70-
alpha=0.005,
71-
neighbor_sampling=0.3,
72-
min_samples=50,
73-
max_samples=np.inf,
74-
min_error=0,
79+
function: Callable[[Tuple[int, number]], Value],
80+
bounds: Tuple[number, number],
81+
loss_per_interval: Optional[
82+
Callable[[Sequence[number], Sequence[number]], float]
83+
] = None,
84+
delta: float = 0.2,
85+
alpha: float = 0.005,
86+
neighbor_sampling: float = 0.3,
87+
min_samples: int = 50,
88+
max_samples: int = np.inf,
89+
min_error: float = 0,
7590
):
7691
if not (0 < delta <= 1):
7792
raise ValueError("Learner requires 0 < delta <= 1.")
@@ -101,15 +116,15 @@ def __init__(
101116
self._number_samples = SortedDict()
102117
# This set contains the points x that have less than min_samples
103118
# samples or less than a (neighbor_sampling*100)% of their neighbors
104-
self._undersampled_points = set()
119+
self._undersampled_points: Set[number] = set()
105120
# Contains the error in the estimate of the
106121
# mean at each point x in the form {x0: error(x0), ...}
107-
self.error = decreasing_dict()
122+
self.error: ItemSortedDict[number, float] = decreasing_dict()
108123
#  Distance between two neighboring points in the
109124
# form {xi: ((xii-xi)^2 + (yii-yi)^2)^0.5, ...}
110-
self._distances = decreasing_dict()
125+
self._distances: ItemSortedDict[number, float] = decreasing_dict()
111126
# {xii: error[xii]/min(_distances[xi], _distances[xii], ...}
112-
self.rescaled_error = decreasing_dict()
127+
self.rescaled_error: ItemSortedDict[number, float] = decreasing_dict()
113128

114129
@property
115130
def nsamples(self) -> int:
@@ -151,7 +166,7 @@ def ask(self, n: int, tell_pending: bool = True) -> Tuple[Points, List[float]]:
151166

152167
return points, loss_improvements
153168

154-
def _ask_for_more_samples(self, x: Number, n: int) -> Tuple[Points, List[float]]:
169+
def _ask_for_more_samples(self, x: number, n: int) -> Tuple[Points, List[float]]:
155170
"""When asking for n points, the learner returns n times an existing point
156171
to be resampled, since in general n << min_samples and this point will
157172
need to be resampled many more times"""
@@ -205,7 +220,7 @@ def tell(self, seed_x: Point, y: Value) -> None:
205220
self._update_data_structures(seed_x, y, "resampled")
206221
self.pending_points.discard(seed_x)
207222

208-
def _update_rescaled_error_in_mean(self, x: Number, point_type: str) -> None:
223+
def _update_rescaled_error_in_mean(self, x: number, point_type: str) -> None:
209224
"""Updates ``self.rescaled_error``.
210225
211226
Parameters
@@ -242,7 +257,7 @@ def _update_rescaled_error_in_mean(self, x: Number, point_type: str) -> None:
242257
norm = min(d_left, d_right)
243258
self.rescaled_error[x] = self.error[x] / norm
244259

245-
def _update_data(self, x: Number, y: Value, point_type: str) -> None:
260+
def _update_data(self, x: number, y: Value, point_type: str) -> None:
246261
if point_type == "new":
247262
self.data[x] = y
248263
elif point_type == "resampled":
@@ -318,15 +333,15 @@ def _update_data_structures(self, seed_x: Point, y: Value, point_type: str) -> N
318333
self._update_interpolated_loss_in_interval(*interval)
319334
self._oldscale = deepcopy(self._scale)
320335

321-
def _update_distances(self, x: Number) -> None:
336+
def _update_distances(self, x: number) -> None:
322337
x_left, x_right = self.neighbors[x]
323338
y = self.data[x]
324339
if x_left is not None:
325340
self._distances[x_left] = hypot((x - x_left), (y - self.data[x_left]))
326341
if x_right is not None:
327342
self._distances[x] = hypot((x_right - x), (self.data[x_right] - y))
328343

329-
def _update_losses_resampling(self, x: Number, real=True) -> None:
344+
def _update_losses_resampling(self, x: number, real=True) -> None:
330345
"""Update all losses that depend on x, whenever the new point is a re-sampled point."""
331346
# (x_left, x_right) are the "real" neighbors of 'x'.
332347
x_left, x_right = self._find_neighbors(x, self.neighbors)
@@ -371,7 +386,9 @@ def tell_many(self, xs: Points, ys: Sequence[Value]) -> None:
371386
)
372387

373388
# Create a mapping of points to a list of samples
374-
mapping = defaultdict(lambda: defaultdict(dict))
389+
mapping: DefaultDict[number, DefaultDict[int, Value]] = defaultdict(
390+
lambda: defaultdict(dict)
391+
)
375392
for (seed, x), y in zip(xs, ys):
376393
mapping[x][seed] = y
377394

@@ -411,7 +428,7 @@ def tell_many_at_point(self, x: float, seed_y_mapping: Dict[int, Value]) -> None
411428
self._update_data(x, y, "new")
412429
self._update_data_structures((seed, x), y, "new")
413430

414-
ys = list(seed_y_mapping.values()) # cast to list *and* make a copy
431+
ys = np.array(list(seed_y_mapping.values()))
415432

416433
# If x is not a new point or if there were more than 1 sample in ys:
417434
if len(ys) > 0:
@@ -441,10 +458,10 @@ def tell_many_at_point(self, x: float, seed_y_mapping: Dict[int, Value]) -> None
441458
self._update_interpolated_loss_in_interval(*interval)
442459
self._oldscale = deepcopy(self._scale)
443460

444-
def _get_data(self) -> SortedDict:
461+
def _get_data(self) -> SortedDict[number, Value]:
445462
return self._data_samples
446463

447-
def _set_data(self, data: SortedDict) -> None:
464+
def _set_data(self, data: SortedDict[number, Value]) -> None:
448465
if data:
449466
for x, samples in data.items():
450467
self.tell_many_at_point(x, samples)
@@ -478,7 +495,7 @@ def plot(self):
478495
return p.redim(x=dict(range=plot_bounds))
479496

480497

481-
def decreasing_dict():
498+
def decreasing_dict() -> ItemSortedDict:
482499
"""This initialization orders the dictionary from large to small values"""
483500

484501
def sorting_rule(key, value):

0 commit comments

Comments
 (0)