
Commit e0809ae

Add mypy to pre-commit and fix all current typing issues (#414)
* Remove _RequireAttrsABCMeta metaclass and replace with simple check
* add mypy
* Fix some typing issues
* Fix some typing issues
* Fix some typing issues
* fix all Runner type issues
* fix all DataSaver type issues
* fix all IntegratorLearner type issues
* fix all SequenceLearner type issues
* some fixes
* some fixes
* some fixes
* Fix multiple issues
* Fix all mypy issues
* Make data a dict
* make BaseLearner a ABC
* remove BaseLearner._check_required_attributes()
* remove unused deps
* Wrap in TYPE_CHECKING
* pin ipython
* pin ipython
* Add NotImplemented methods
* remove unused import
1 parent 82ed0a4 commit e0809ae

28 files changed: +482 -269 lines changed

.pre-commit-config.yaml

Lines changed: 5 additions & 0 deletions
@@ -25,3 +25,8 @@ repos:
       - id: nbqa
         args: ["ruff", "--fix", "--ignore=E402,B018,F704"]
         additional_dependencies: [jupytext, ruff]
+  - repo: https://github.com/pre-commit/mirrors-mypy
+    rev: "v1.2.0"
+    hooks:
+      - id: mypy
+        exclude: ipynb_filter.py|docs/source/conf.py
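
With this hook in place, mypy runs over the staged Python files on every commit (or over the whole repository via pre-commit run mypy --all-files). As a minimal, hypothetical illustration of what the hook now catches before a commit lands, mypy rejects the last line of the sketch below with an [assignment] error:

    # Hypothetical example (not part of this commit) of an error the mypy
    # pre-commit hook reports before the commit is accepted.
    from __future__ import annotations


    def total_samples(counts: dict[str, int]) -> int:
        """Sum the number of samples per point."""
        return sum(counts.values())


    n: int = total_samples({"x0": 3, "x1": 5})
    n = "eight"  # mypy: incompatible types in assignment [assignment]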

adaptive/__init__.py

Lines changed: 2 additions & 2 deletions
@@ -54,5 +54,5 @@
     __all__.append("SKOptLearner")

 # to avoid confusion with `notebook_extension` and `__version__`
-del _version  # noqa: F821
-del notebook_integration  # noqa: F821
+del _version  # type: ignore[name-defined] # noqa: F821
+del notebook_integration  # type: ignore[name-defined] # noqa: F821
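
Both suppressions are needed on these lines: the names _version and notebook_integration exist at runtime only as a side effect of submodule imports earlier in the file, a binding that neither ruff/flake8 (F821, undefined name) nor mypy (name-defined) can see. The type: ignore part is also placed ahead of the noqa part, since mypy only recognizes the ignore when it leads the comment. A minimal, hypothetical package __init__.py with the same pattern:

    # Hypothetical package __init__.py illustrating the pattern: importing from
    # a submodule also binds the submodule name on the package at runtime, but
    # static checkers cannot see that binding, so deleting it needs both hints.
    from mypkg.helpers import make_widget  # binds ``helpers`` as a side effect

    __all__ = ["make_widget"]

    # Keep ``helpers`` out of the public namespace.
    del helpers  # type: ignore[name-defined] # noqa: F821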

adaptive/_version.py

Lines changed: 3 additions & 1 deletion
@@ -1,5 +1,7 @@
 # This file is part of 'miniver': https://github.com/jbweston/miniver
 #
+from __future__ import annotations
+
 import os
 import subprocess
 from collections import namedtuple
@@ -10,7 +12,7 @@
 Version = namedtuple("Version", ("release", "dev", "labels"))

 # No public API
-__all__ = []
+__all__: list[str] = []

 package_root = os.path.dirname(os.path.realpath(__file__))
 package_name = os.path.basename(package_root)
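
Two small changes work together here. Annotating the empty list tells mypy what __all__ will hold (an unannotated empty container otherwise triggers a "Need type annotation" error), and from __future__ import annotations keeps the built-in generic list[str] from being evaluated at runtime, where it would raise a TypeError on Python versions before 3.9. A minimal sketch of the same idea:

    # Minimal sketch: the future import turns annotations into strings, so
    # ``list[str]`` is fine even on Python 3.8, and the annotation gives mypy
    # the element type of a container that starts out empty.
    from __future__ import annotations

    __all__: list[str] = []  # without the annotation mypy asks for one

    __all__.append("Version")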

adaptive/learner/average_learner.py

Lines changed: 3 additions & 6 deletions
@@ -1,16 +1,14 @@
 from __future__ import annotations

 from math import sqrt
-from numbers import Integral as Int
-from numbers import Real
 from typing import Callable

 import cloudpickle
 import numpy as np

 from adaptive.learner.base_learner import BaseLearner
 from adaptive.notebook_integration import ensure_holoviews
-from adaptive.types import Float
+from adaptive.types import Float, Int, Real
 from adaptive.utils import (
     assign_defaults,
     cache_latest,
@@ -75,7 +73,6 @@ def __init__(
         self.min_npoints = max(min_npoints, 2)
         self.sum_f: Real = 0.0
         self.sum_f_sq: Real = 0.0
-        self._check_required_attributes()

     def new(self) -> AverageLearner:
         """Create a copy of `~adaptive.AverageLearner` without the data."""
@@ -89,7 +86,7 @@ def to_numpy(self):
         """Data as NumPy array of size (npoints, 2) with seeds and values."""
         return np.array(sorted(self.data.items()))

-    def to_dataframe(
+    def to_dataframe(  # type: ignore[override]
         self,
         with_default_function_args: bool = True,
         function_prefix: str = "function.",
@@ -129,7 +126,7 @@ def to_dataframe(
         assign_defaults(self.function, df, function_prefix)
         return df

-    def load_dataframe(
+    def load_dataframe(  # type: ignore[override]
         self,
         df: pandas.DataFrame,
         with_default_function_args: bool = True,
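
The new # type: ignore[override] markers acknowledge that the to_dataframe and load_dataframe signatures are not compatible with the ones declared on BaseLearner, which mypy reports as an unsafe override; the commit keeps the learner-specific signatures and silences the check per method. A small, hypothetical reproduction of this class of error:

    # Hypothetical classes showing the ``override`` error being acknowledged:
    # the subclass expects a (seed, x) pair where the base class expects a bare
    # float, which mypy reports as an unsafe (non-Liskov) override.
    from __future__ import annotations


    class Base:
        def tell(self, x: float, y: float) -> None:
            raise NotImplementedError


    class Seeded(Base):
        def tell(self, x: tuple[int, float], y: float) -> None:  # type: ignore[override]
            print(f"seed={x[0]}, x={x[1]}, y={y}")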

adaptive/learner/average_learner1D.py

Lines changed: 18 additions & 20 deletions
@@ -5,8 +5,6 @@
 from collections import defaultdict
 from copy import deepcopy
 from math import hypot
-from numbers import Integral as Int
-from numbers import Real
 from typing import Callable, DefaultDict, Iterable, List, Sequence, Tuple

 import numpy as np
@@ -16,6 +14,7 @@

 from adaptive.learner.learner1D import Learner1D, _get_intervals
 from adaptive.notebook_integration import ensure_holoviews
+from adaptive.types import Int, Real
 from adaptive.utils import assign_defaults, partial_function_from_dataframe

 try:
@@ -99,7 +98,7 @@ def __init__(
         if min_samples > max_samples:
             raise ValueError("max_samples should be larger than min_samples.")

-        super().__init__(function, bounds, loss_per_interval)
+        super().__init__(function, bounds, loss_per_interval)  # type: ignore[arg-type]

         self.delta = delta
         self.alpha = alpha
@@ -110,7 +109,7 @@ def __init__(

         # Contains all samples f(x) for each
         # point x in the form {x0: {0: f_0(x0), 1: f_1(x0), ...}, ...}
-        self._data_samples = SortedDict()
+        self._data_samples: SortedDict[float, dict[int, Real]] = SortedDict()
         # Contains the number of samples taken
         # at each point x in the form {x0: n0, x1: n1, ...}
         self._number_samples = SortedDict()
@@ -124,15 +123,14 @@ def __init__(
         # form {xi: ((xii-xi)^2 + (yii-yi)^2)^0.5, ...}
         self._distances: dict[Real, float] = decreasing_dict()
         # {xii: error[xii]/min(_distances[xi], _distances[xii], ...}
-        self.rescaled_error: dict[Real, float] = decreasing_dict()
-        self._check_required_attributes()
+        self.rescaled_error: ItemSortedDict[Real, float] = decreasing_dict()

     def new(self) -> AverageLearner1D:
         """Create a copy of `~adaptive.AverageLearner1D` without the data."""
         return AverageLearner1D(
             self.function,
             self.bounds,
-            self.loss_per_interval,
+            self.loss_per_interval,  # type: ignore[arg-type]
             self.delta,
             self.alpha,
             self.neighbor_sampling,
@@ -164,7 +162,7 @@ def to_numpy(self, mean: bool = False) -> np.ndarray:
             ]
         )

-    def to_dataframe(
+    def to_dataframe(  # type: ignore[override]
         self,
         mean: bool = False,
         with_default_function_args: bool = True,
@@ -202,10 +200,10 @@ def to_dataframe(
         if not with_pandas:
             raise ImportError("pandas is not installed.")
         if mean:
-            data = sorted(self.data.items())
+            data: list[tuple[Real, Real]] = sorted(self.data.items())
             columns = [x_name, y_name]
         else:
-            data = [
+            data: list[tuple[int, Real, Real]] = [  # type: ignore[no-redef]
                 (seed, x, y)
                 for x, seed_y in sorted(self._data_samples.items())
                 for seed, y in sorted(seed_y.items())
@@ -218,7 +216,7 @@ def to_dataframe(
         assign_defaults(self.function, df, function_prefix)
         return df

-    def load_dataframe(
+    def load_dataframe(  # type: ignore[override]
         self,
         df: pandas.DataFrame,
         with_default_function_args: bool = True,
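
In the to_dataframe hunk above, data is bound to a list of (x, y) pairs in one branch and to a list of (seed, x, y) triples in the other; once both assignments are annotated, mypy reports the second as a redefinition, which the no-redef ignore acknowledges. A stripped-down, hypothetical version of the same situation:

    # Hypothetical sketch of the ``no-redef`` case: the same name is annotated
    # with two different types in the two branches, which is fine at runtime
    # but counts as a redefinition for mypy.
    from __future__ import annotations


    def describe(mean: bool) -> None:
        if mean:
            data: list[tuple[float, float]] = [(0.1, 1.0), (0.2, 4.0)]
        else:
            data: list[tuple[int, float, float]] = [  # type: ignore[no-redef]
                (0, 0.1, 1.0),
                (1, 0.1, 1.2),
            ]
        print(f"{len(data)} rows")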
@@ -258,7 +256,7 @@ def load_dataframe(
                 self.function, df, function_prefix
             )

-    def ask(self, n: int, tell_pending: bool = True) -> tuple[Points, list[float]]:
+    def ask(self, n: int, tell_pending: bool = True) -> tuple[Points, list[float]]:  # type: ignore[override]
         """Return 'n' points that are expected to maximally reduce the loss."""
         # If some point is undersampled, resample it
         if len(self._undersampled_points):
@@ -311,18 +309,18 @@ def _ask_for_new_point(self, n: int) -> tuple[Points, list[float]]:
         new point, since in general n << min_samples and this point will need
         to be resampled many more times"""
         points, (loss_improvement,) = self._ask_points_without_adding(1)
-        points = [(seed, x) for seed, x in zip(range(n), n * points)]
+        seed_points = [(seed, x) for seed, x in zip(range(n), n * points)]
         loss_improvements = [loss_improvement / n] * n
-        return points, loss_improvements
+        return seed_points, loss_improvements  # type: ignore[return-value]

-    def tell_pending(self, seed_x: Point) -> None:
+    def tell_pending(self, seed_x: Point) -> None:  # type: ignore[override]
         _, x = seed_x
         self.pending_points.add(seed_x)
         if x not in self.data:
             self._update_neighbors(x, self.neighbors_combined)
             self._update_losses(x, real=False)

-    def tell(self, seed_x: Point, y: Real) -> None:
+    def tell(self, seed_x: Point, y: Real) -> None:  # type: ignore[override]
         seed, x = seed_x
         if y is None:
             raise TypeError(
@@ -493,7 +491,7 @@ def _calc_error_in_mean(self, ys: Iterable[Real], y_avg: Real, n: int) -> float:
         t_student = scipy.stats.t.ppf(1 - self.alpha, df=n - 1)
         return t_student * (variance_in_mean / n) ** 0.5

-    def tell_many(
+    def tell_many(  # type: ignore[override]
         self, xs: Points | np.ndarray, ys: Sequence[Real] | np.ndarray
     ) -> None:
         # Check that all x are within the bounds
@@ -578,10 +576,10 @@ def tell_many_at_point(self, x: Real, seed_y_mapping: dict[int, Real]) -> None:
             self._update_interpolated_loss_in_interval(*interval)
         self._oldscale = deepcopy(self._scale)

-    def _get_data(self) -> dict[Real, dict[Int, Real]]:
+    def _get_data(self) -> dict[Real, dict[Int, Real]]:  # type: ignore[override]
         return self._data_samples

-    def _set_data(self, data: dict[Real, dict[Int, Real]]) -> None:
+    def _set_data(self, data: dict[Real, dict[Int, Real]]) -> None:  # type: ignore[override]
         if data:
             for x, samples in data.items():
                 self.tell_many_at_point(x, samples)
@@ -616,7 +614,7 @@ def plot(self):
         return p.redim(x={"range": plot_bounds})


-def decreasing_dict() -> dict:
+def decreasing_dict() -> ItemSortedDict:
     """This initialization orders the dictionary from large to small values"""

     def sorting_rule(key, value):
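
The corrected return annotation matches what decreasing_dict actually builds: an ItemSortedDict from the sortedcollections package, i.e. a mapping whose iteration order follows its values rather than its keys, here from the largest value to the smallest. A usage sketch, assuming the behaviour described by the docstring:

    # Usage sketch of ``decreasing_dict``: items stay ordered by decreasing
    # value as they are inserted, which lets the learner read off the entry
    # with the largest value first.
    from adaptive.learner.average_learner1D import decreasing_dict

    d = decreasing_dict()
    d[0.1] = 2.5
    d[0.7] = 9.0
    d[0.3] = 1.0
    print(list(d.items()))  # expected: [(0.7, 9.0), (0.1, 2.5), (0.3, 1.0)]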
