1
1
import functools
2
2
from collections import OrderedDict
3
+ from operator import itemgetter
4
+ from typing import Callable , Dict , Tuple , Union
3
5
6
+ from adaptive .learner .average_learner import AverageLearner
4
7
from adaptive .learner .base_learner import BaseLearner
8
+ from adaptive .learner .learner1D import Learner1D
9
+ from adaptive .learner .learner2D import Learner2D
10
+ from adaptive .learner .learnerND import LearnerND
5
11
from adaptive .utils import copy_docstring_from
6
12
7
13
@@ -25,13 +31,17 @@ class DataSaver:
25
31
>>> learner = DataSaver(_learner, arg_picker=itemgetter('y'))
26
32
"""
27
33
28
- def __init__ (self , learner , arg_picker ):
34
+ def __init__ (
35
+ self ,
36
+ learner : Union [Learner2D , Learner1D , LearnerND , AverageLearner ],
37
+ arg_picker : itemgetter ,
38
+ ) -> None :
29
39
self .learner = learner
30
40
self .extra_data = OrderedDict ()
31
41
self .function = learner .function
32
42
self .arg_picker = arg_picker
33
43
34
- def __getattr__ (self , attr ) :
44
+ def __getattr__ (self , attr : str ) -> Union [ Callable , int ] :
35
45
return getattr (self .learner , attr )
36
46
37
47
@copy_docstring_from (BaseLearner .tell )
@@ -44,10 +54,23 @@ def tell(self, x, result):
44
54
def tell_pending (self , x ):
45
55
self .learner .tell_pending (x )
46
56
47
- def _get_data (self ):
57
+ def _get_data (
58
+ self ,
59
+ ) -> Union [
60
+ Tuple [Dict [Union [int , float ], float ], OrderedDict ],
61
+ Tuple [OrderedDict , OrderedDict ],
62
+ Tuple [Tuple [Dict [int , float ], int , float , float ], OrderedDict ],
63
+ ]:
48
64
return self .learner ._get_data (), self .extra_data
49
65
50
- def _set_data (self , data ):
66
+ def _set_data (
67
+ self ,
68
+ data : Union [
69
+ Tuple [OrderedDict , OrderedDict ],
70
+ Tuple [Dict [Union [int , float ], float ], OrderedDict ],
71
+ Tuple [Tuple [Dict [int , float ], int , float , float ], OrderedDict ],
72
+ ],
73
+ ) -> None :
51
74
learner_data , self .extra_data = data
52
75
self .learner ._set_data (learner_data )
53
76
0 commit comments