Skip to content

Commit 4a0ac28

Browse files
committed
add type annotations for adaptive/learner/base_learner.py
1 parent fcae8e0 commit 4a0ac28

File tree

1 file changed

+8
-7
lines changed

1 file changed

+8
-7
lines changed

adaptive/learner/base_learner.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
11
import abc
22
from contextlib import suppress
33
from copy import deepcopy
4+
from typing import Any, Callable, Dict
45

56
from adaptive.utils import _RequireAttrsABCMeta, load, save
67

78

8-
def uses_nth_neighbors(n):
9+
def uses_nth_neighbors(n: int) -> Callable:
910
"""Decorator to specify how many neighboring intervals the loss function uses.
1011
1112
Wraps loss functions to indicate that they expect intervals together
@@ -84,7 +85,7 @@ class BaseLearner(metaclass=_RequireAttrsABCMeta):
8485
npoints: int
8586
pending_points: set
8687

87-
def tell(self, x, y):
88+
def tell(self, x: Any, y) -> None:
8889
"""Tell the learner about a single value.
8990
9091
Parameters
@@ -94,7 +95,7 @@ def tell(self, x, y):
9495
"""
9596
self.tell_many([x], [y])
9697

97-
def tell_many(self, xs, ys):
98+
def tell_many(self, xs: Any, ys: Any) -> None:
9899
"""Tell the learner about some values.
99100
100101
Parameters
@@ -161,7 +162,7 @@ def copy_from(self, other):
161162
"""
162163
self._set_data(other._get_data())
163164

164-
def save(self, fname, compress=True):
165+
def save(self, fname: str, compress: bool = True) -> None:
165166
"""Save the data of the learner into a pickle file.
166167
167168
Parameters
@@ -175,7 +176,7 @@ def save(self, fname, compress=True):
175176
data = self._get_data()
176177
save(fname, data, compress)
177178

178-
def load(self, fname, compress=True):
179+
def load(self, fname: str, compress: bool = True) -> None:
179180
"""Load the data of a learner from a pickle file.
180181
181182
Parameters
@@ -190,8 +191,8 @@ def load(self, fname, compress=True):
190191
data = load(fname, compress)
191192
self._set_data(data)
192193

193-
def __getstate__(self):
194+
def __getstate__(self) -> Dict[str, Any]:
194195
return deepcopy(self.__dict__)
195196

196-
def __setstate__(self, state):
197+
def __setstate__(self, state: Dict[str, Any]) -> None:
197198
self.__dict__ = state

0 commit comments

Comments
 (0)