Skip to content

Commit fa0a0ef

Browse files
committed
add type annotations for adaptive/utils.py
1 parent d6ba748 commit fa0a0ef

File tree

1 file changed

+9
-6
lines changed

1 file changed

+9
-6
lines changed

adaptive/utils.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,12 @@
55
import pickle
66
from contextlib import contextmanager
77
from itertools import product
8+
from typing import Any, Callable, Iterator
89

910
from atomicwrites import AtomicWriter
1011

12+
from adaptive.learner.base_learner import BaseLearner
13+
1114

1215
def named_product(**items):
1316
names = items.keys()
@@ -16,7 +19,7 @@ def named_product(**items):
1619

1720

1821
@contextmanager
19-
def restore(*learners):
22+
def restore(*learners) -> Iterator[None]:
2023
states = [learner.__getstate__() for learner in learners]
2124
try:
2225
yield
@@ -25,7 +28,7 @@ def restore(*learners):
2528
learner.__setstate__(state)
2629

2730

28-
def cache_latest(f):
31+
def cache_latest(f: Callable) -> Callable:
2932
"""Cache the latest return value of the function and add it
3033
as 'self._cache[f.__name__]'."""
3134

@@ -40,7 +43,7 @@ def wrapper(*args, **kwargs):
4043
return wrapper
4144

4245

43-
def save(fname, data, compress=True):
46+
def save(fname: str, data: Any, compress: bool = True) -> None:
4447
fname = os.path.expanduser(fname)
4548
dirname = os.path.dirname(fname)
4649
if dirname:
@@ -54,22 +57,22 @@ def save(fname, data, compress=True):
5457
f.write(blob)
5558

5659

57-
def load(fname, compress=True):
60+
def load(fname: str, compress: bool = True):
5861
fname = os.path.expanduser(fname)
5962
_open = gzip.open if compress else open
6063
with _open(fname, "rb") as f:
6164
return pickle.load(f)
6265

6366

64-
def copy_docstring_from(other):
67+
def copy_docstring_from(other: Callable) -> Callable:
6568
def decorator(method):
6669
return functools.wraps(other)(method)
6770

6871
return decorator
6972

7073

7174
class _RequireAttrsABCMeta(abc.ABCMeta):
72-
def __call__(self, *args, **kwargs):
75+
def __call__(self, *args, **kwargs) -> BaseLearner:
7376
obj = super().__call__(*args, **kwargs)
7477
for name, type_ in obj.__annotations__.items():
7578
try:

0 commit comments

Comments
 (0)