Skip to content

Commit e69ab2d

Browse files
authored
Merge pull request #324 from python-adaptive/average-learner-type-hints
AverageLearner type hints
2 parents 82245b2 + ce02311 commit e69ab2d

File tree

4 files changed

+60
-53
lines changed

4 files changed

+60
-53
lines changed

adaptive/learner/average_learner.py

Lines changed: 25 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
11
from math import sqrt
2+
from typing import Callable, Dict, List, Optional, Tuple
23

34
import cloudpickle
45
import numpy as np
56

67
from adaptive.learner.base_learner import BaseLearner
78
from adaptive.notebook_integration import ensure_holoviews
9+
from adaptive.types import Float, Real
810
from adaptive.utils import cache_latest
911

1012

@@ -33,7 +35,13 @@ class AverageLearner(BaseLearner):
3335
Number of evaluated points.
3436
"""
3537

36-
def __init__(self, function, atol=None, rtol=None, min_npoints=2):
38+
def __init__(
39+
self,
40+
function: Callable[[int], Real],
41+
atol: Optional[float] = None,
42+
rtol: Optional[float] = None,
43+
min_npoints: int = 2,
44+
) -> None:
3745
if atol is None and rtol is None:
3846
raise Exception("At least one of `atol` and `rtol` should be set.")
3947
if atol is None:
@@ -43,24 +51,24 @@ def __init__(self, function, atol=None, rtol=None, min_npoints=2):
4351

4452
self.data = {}
4553
self.pending_points = set()
46-
self.function = function
54+
self.function = function # type: ignore
4755
self.atol = atol
4856
self.rtol = rtol
4957
self.npoints = 0
5058
# Cannot estimate standard deviation with fewer than 2 points.
5159
self.min_npoints = max(min_npoints, 2)
52-
self.sum_f = 0
53-
self.sum_f_sq = 0
60+
self.sum_f: Real = 0.0
61+
self.sum_f_sq: Real = 0.0
5462

5563
@property
56-
def n_requested(self):
64+
def n_requested(self) -> int:
5765
return self.npoints + len(self.pending_points)
5866

5967
def to_numpy(self):
6068
"""Data as NumPy array of size (npoints, 2) with seeds and values."""
6169
return np.array(sorted(self.data.items()))
6270

63-
def ask(self, n, tell_pending=True):
71+
def ask(self, n: int, tell_pending: bool = True) -> Tuple[List[int], List[Float]]:
6472
points = list(range(self.n_requested, self.n_requested + n))
6573

6674
if any(p in self.data or p in self.pending_points for p in points):
@@ -77,7 +85,7 @@ def ask(self, n, tell_pending=True):
7785
self.tell_pending(p)
7886
return points, loss_improvements
7987

80-
def tell(self, n, value):
88+
def tell(self, n: int, value: Real) -> None:
8189
if n in self.data:
8290
# The point has already been added before.
8391
return
@@ -88,16 +96,16 @@ def tell(self, n, value):
8896
self.sum_f_sq += value ** 2
8997
self.npoints += 1
9098

91-
def tell_pending(self, n):
99+
def tell_pending(self, n: int) -> None:
92100
self.pending_points.add(n)
93101

94102
@property
95-
def mean(self):
103+
def mean(self) -> Float:
96104
"""The average of all values in `data`."""
97105
return self.sum_f / self.npoints
98106

99107
@property
100-
def std(self):
108+
def std(self) -> Float:
101109
"""The corrected sample standard deviation of the values
102110
in `data`."""
103111
n = self.npoints
@@ -110,7 +118,7 @@ def std(self):
110118
return sqrt(numerator / (n - 1))
111119

112120
@cache_latest
113-
def loss(self, real=True, *, n=None):
121+
def loss(self, real: bool = True, *, n=None) -> Float:
114122
if n is None:
115123
n = self.npoints if real else self.n_requested
116124
else:
@@ -120,11 +128,12 @@ def loss(self, real=True, *, n=None):
120128
standard_error = self.std / sqrt(n)
121129
aloss = standard_error / self.atol
122130
rloss = standard_error / self.rtol
123-
if self.mean != 0:
124-
rloss /= abs(self.mean)
131+
mean = self.mean
132+
if mean != 0:
133+
rloss /= abs(mean)
125134
return max(aloss, rloss)
126135

127-
def _loss_improvement(self, n):
136+
def _loss_improvement(self, n: int) -> Float:
128137
loss = self.loss()
129138
if np.isfinite(loss):
130139
return loss - self.loss(n=self.npoints + n)
@@ -150,10 +159,10 @@ def plot(self):
150159
vals = hv.Points(vals)
151160
return hv.operation.histogram(vals, num_bins=num_bins, dimension="y")
152161

153-
def _get_data(self):
162+
def _get_data(self) -> Tuple[Dict[int, Real], int, Real, Real]:
154163
return (self.data, self.npoints, self.sum_f, self.sum_f_sq)
155164

156-
def _set_data(self, data):
165+
def _set_data(self, data: Tuple[Dict[int, Real], int, Real, Real]) -> None:
157166
self.data, self.npoints, self.sum_f, self.sum_f_sq = data
158167

159168
def __getstate__(self):

adaptive/learner/average_learner1D.py

Lines changed: 25 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -3,17 +3,7 @@
33
from collections import defaultdict
44
from copy import deepcopy
55
from math import hypot
6-
from typing import (
7-
Callable,
8-
DefaultDict,
9-
Dict,
10-
List,
11-
Optional,
12-
Sequence,
13-
Set,
14-
Tuple,
15-
Union,
16-
)
6+
from typing import Callable, DefaultDict, Dict, List, Optional, Sequence, Set, Tuple
177

188
import numpy as np
199
import scipy.stats
@@ -22,9 +12,9 @@
2212

2313
from adaptive.learner.learner1D import Learner1D, _get_intervals
2414
from adaptive.notebook_integration import ensure_holoviews
15+
from adaptive.types import Real
2516

26-
number = Union[int, float, np.int_, np.float_]
27-
Point = Tuple[int, number]
17+
Point = Tuple[int, Real]
2818
Points = List[Point]
2919

3020
__all__: List[str] = ["AverageLearner1D"]
@@ -45,7 +35,7 @@ class AverageLearner1D(Learner1D):
4535
If not provided, then a default is used, which uses the scaled distance
4636
in the x-y plane as the loss. See the notes
4737
of `adaptive.Learner1D` for more details.
48-
delta : float
38+
delta : float, optional, default 0.2
4939
This parameter controls the resampling condition. A point is resampled
5040
if its uncertainty is larger than delta times the smallest neighboring
5141
interval.
@@ -75,10 +65,10 @@ class AverageLearner1D(Learner1D):
7565

7666
def __init__(
7767
self,
78-
function: Callable[[Tuple[int, number]], number],
79-
bounds: Tuple[number, number],
68+
function: Callable[[Tuple[int, Real]], Real],
69+
bounds: Tuple[Real, Real],
8070
loss_per_interval: Optional[
81-
Callable[[Sequence[number], Sequence[number]], float]
71+
Callable[[Sequence[Real], Sequence[Real]], float]
8272
] = None,
8373
delta: float = 0.2,
8474
alpha: float = 0.005,
@@ -115,15 +105,15 @@ def __init__(
115105
self._number_samples = SortedDict()
116106
# This set contains the points x that have less than min_samples
117107
# samples or less than (neighbor_sampling*100)% of their neighbors
118-
self._undersampled_points: Set[number] = set()
108+
self._undersampled_points: Set[Real] = set()
119109
# Contains the error in the estimate of the
120110
# mean at each point x in the form {x0: error(x0), ...}
121-
self.error: ItemSortedDict[number, float] = decreasing_dict()
111+
self.error: ItemSortedDict[Real, float] = decreasing_dict()
122112
# Distance between two neighboring points in the
123113
# form {xi: ((xii-xi)^2 + (yii-yi)^2)^0.5, ...}
124-
self._distances: ItemSortedDict[number, float] = decreasing_dict()
114+
self._distances: ItemSortedDict[Real, float] = decreasing_dict()
125115
# {xii: error[xii]/min(_distances[xi], _distances[xii]), ...}
126-
self.rescaled_error: ItemSortedDict[number, float] = decreasing_dict()
116+
self.rescaled_error: ItemSortedDict[Real, float] = decreasing_dict()
127117

128118
@property
129119
def nsamples(self) -> int:
@@ -165,7 +155,7 @@ def ask(self, n: int, tell_pending: bool = True) -> Tuple[Points, List[float]]:
165155

166156
return points, loss_improvements
167157

168-
def _ask_for_more_samples(self, x: number, n: int) -> Tuple[Points, List[float]]:
158+
def _ask_for_more_samples(self, x: Real, n: int) -> Tuple[Points, List[float]]:
169159
"""When asking for n points, the learner returns n times an existing point
170160
to be resampled, since in general n << min_samples and this point will
171161
need to be resampled many more times"""
@@ -200,7 +190,7 @@ def tell_pending(self, seed_x: Point) -> None:
200190
self._update_neighbors(x, self.neighbors_combined)
201191
self._update_losses(x, real=False)
202192

203-
def tell(self, seed_x: Point, y: number) -> None:
193+
def tell(self, seed_x: Point, y: Real) -> None:
204194
seed, x = seed_x
205195
if y is None:
206196
raise TypeError(
@@ -216,7 +206,7 @@ def tell(self, seed_x: Point, y: number) -> None:
216206
self._update_data_structures(seed_x, y, "resampled")
217207
self.pending_points.discard(seed_x)
218208

219-
def _update_rescaled_error_in_mean(self, x: number, point_type: str) -> None:
209+
def _update_rescaled_error_in_mean(self, x: Real, point_type: str) -> None:
220210
"""Updates ``self.rescaled_error``.
221211
222212
Parameters
@@ -253,17 +243,15 @@ def _update_rescaled_error_in_mean(self, x: number, point_type: str) -> None:
253243
norm = min(d_left, d_right)
254244
self.rescaled_error[x] = self.error[x] / norm
255245

256-
def _update_data(self, x: number, y: number, point_type: str) -> None:
246+
def _update_data(self, x: Real, y: Real, point_type: str) -> None:
257247
if point_type == "new":
258248
self.data[x] = y
259249
elif point_type == "resampled":
260250
n = len(self._data_samples[x])
261251
new_average = self.data[x] * n / (n + 1) + y / (n + 1)
262252
self.data[x] = new_average
263253

264-
def _update_data_structures(
265-
self, seed_x: Point, y: number, point_type: str
266-
) -> None:
254+
def _update_data_structures(self, seed_x: Point, y: Real, point_type: str) -> None:
267255
seed, x = seed_x
268256
if point_type == "new":
269257
self._data_samples[x] = {seed: y}
@@ -331,15 +319,15 @@ def _update_data_structures(
331319
self._update_interpolated_loss_in_interval(*interval)
332320
self._oldscale = deepcopy(self._scale)
333321

334-
def _update_distances(self, x: number) -> None:
322+
def _update_distances(self, x: Real) -> None:
335323
x_left, x_right = self.neighbors[x]
336324
y = self.data[x]
337325
if x_left is not None:
338326
self._distances[x_left] = hypot((x - x_left), (y - self.data[x_left]))
339327
if x_right is not None:
340328
self._distances[x] = hypot((x_right - x), (self.data[x_right] - y))
341329

342-
def _update_losses_resampling(self, x: number, real=True) -> None:
330+
def _update_losses_resampling(self, x: Real, real=True) -> None:
343331
"""Update all losses that depend on x, whenever the new point is a re-sampled point."""
344332
# (x_left, x_right) are the "real" neighbors of 'x'.
345333
x_left, x_right = self._find_neighbors(x, self.neighbors)
@@ -368,12 +356,12 @@ def _update_losses_resampling(self, x: number, real=True) -> None:
368356
if (b is not None) and right_loss_is_unknown:
369357
self.losses_combined[x, b] = float("inf")
370358

371-
def _calc_error_in_mean(self, ys: Sequence[number], y_avg: number, n: int) -> float:
359+
def _calc_error_in_mean(self, ys: Sequence[Real], y_avg: Real, n: int) -> float:
372360
variance_in_mean = sum((y - y_avg) ** 2 for y in ys) / (n - 1)
373361
t_student = scipy.stats.t.ppf(1 - self.alpha, df=n - 1)
374362
return t_student * (variance_in_mean / n) ** 0.5
375363

376-
def tell_many(self, xs: Points, ys: Sequence[number]) -> None:
364+
def tell_many(self, xs: Points, ys: Sequence[Real]) -> None:
377365
# Check that all x are within the bounds
378366
# TODO: remove this requirement, all other learners add the data
379367
# but ignore it going forward.
@@ -384,7 +372,7 @@ def tell_many(self, xs: Points, ys: Sequence[number]) -> None:
384372
)
385373

386374
# Create a mapping of points to a list of samples
387-
mapping: DefaultDict[number, DefaultDict[int, number]] = defaultdict(
375+
mapping: DefaultDict[Real, DefaultDict[int, Real]] = defaultdict(
388376
lambda: defaultdict(dict)
389377
)
390378
for (seed, x), y in zip(xs, ys):
@@ -400,14 +388,14 @@ def tell_many(self, xs: Points, ys: Sequence[number]) -> None:
400388
# simultaneously, before we move on to a new x
401389
self.tell_many_at_point(x, seed_y_mapping)
402390

403-
def tell_many_at_point(self, x: number, seed_y_mapping: Dict[int, number]) -> None:
391+
def tell_many_at_point(self, x: Real, seed_y_mapping: Dict[int, Real]) -> None:
404392
"""Tell the learner about many samples at a certain location x.
405393
406394
Parameters
407395
----------
408396
x : float
409397
Value from the function domain.
410-
seed_y_mapping : Dict[int, number]
398+
seed_y_mapping : Dict[int, Real]
411399
Dictionary of ``seed`` -> ``y`` at ``x``.
412400
"""
413401
# Check x is within the bounds
@@ -456,10 +444,10 @@ def tell_many_at_point(self, x: number, seed_y_mapping: Dict[int, number]) -> No
456444
self._update_interpolated_loss_in_interval(*interval)
457445
self._oldscale = deepcopy(self._scale)
458446

459-
def _get_data(self) -> SortedDict[number, number]:
447+
def _get_data(self) -> SortedDict[Real, Real]:
460448
return self._data_samples
461449

462-
def _set_data(self, data: SortedDict[number, number]) -> None:
450+
def _set_data(self, data: SortedDict[Real, Real]) -> None:
463451
if data:
464452
for x, samples in data.items():
465453
self.tell_many_at_point(x, samples)

adaptive/types.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
from typing import Union
2+
3+
import numpy as np
4+
5+
Float = Union[float, np.float_]
6+
Int = Union[int, np.int_]
7+
Real = Union[Float, Int]

tox.ini

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,3 +55,6 @@ exclude = .git, .tox, __pycache__, dist
5555

5656
[isort]
5757
profile=black
58+
59+
[mypy]
60+
ignore_missing_imports = True

0 commit comments

Comments
 (0)