Skip to content

Commit a0895e4

Browse files
committed
add type annotations for adaptive/learner/data_saver.py
1 parent 4a0ac28 commit a0895e4

File tree

1 file changed

+27
-4
lines changed

1 file changed

+27
-4
lines changed

adaptive/learner/data_saver.py

Lines changed: 27 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,13 @@
11
import functools
22
from collections import OrderedDict
3+
from operator import itemgetter
4+
from typing import Callable, Dict, Tuple, Union
35

6+
from adaptive.learner.average_learner import AverageLearner
47
from adaptive.learner.base_learner import BaseLearner
8+
from adaptive.learner.learner1D import Learner1D
9+
from adaptive.learner.learner2D import Learner2D
10+
from adaptive.learner.learnerND import LearnerND
511
from adaptive.utils import copy_docstring_from
612

713

@@ -25,13 +31,17 @@ class DataSaver:
2531
>>> learner = DataSaver(_learner, arg_picker=itemgetter('y'))
2632
"""
2733

28-
def __init__(self, learner, arg_picker):
34+
def __init__(
35+
self,
36+
learner: Union[Learner2D, Learner1D, LearnerND, AverageLearner],
37+
arg_picker: itemgetter,
38+
) -> None:
2939
self.learner = learner
3040
self.extra_data = OrderedDict()
3141
self.function = learner.function
3242
self.arg_picker = arg_picker
3343

34-
def __getattr__(self, attr):
44+
def __getattr__(self, attr: str) -> Union[Callable, int]:
3545
return getattr(self.learner, attr)
3646

3747
@copy_docstring_from(BaseLearner.tell)
@@ -44,10 +54,23 @@ def tell(self, x, result):
4454
def tell_pending(self, x):
4555
self.learner.tell_pending(x)
4656

47-
def _get_data(self):
57+
def _get_data(
58+
self,
59+
) -> Union[
60+
Tuple[Dict[Union[int, float], float], OrderedDict],
61+
Tuple[OrderedDict, OrderedDict],
62+
Tuple[Tuple[Dict[int, float], int, float, float], OrderedDict],
63+
]:
4864
return self.learner._get_data(), self.extra_data
4965

50-
def _set_data(self, data):
66+
def _set_data(
67+
self,
68+
data: Union[
69+
Tuple[OrderedDict, OrderedDict],
70+
Tuple[Dict[Union[int, float], float], OrderedDict],
71+
Tuple[Tuple[Dict[int, float], int, float, float], OrderedDict],
72+
],
73+
) -> None:
5174
learner_data, self.extra_data = data
5275
self.learner._set_data(learner_data)
5376

0 commit comments

Comments
 (0)