1
1
import collections
2
+ from typing import Callable , List , Tuple , Union
2
3
3
4
import numpy as np
5
+ from numpy import float64
4
6
from skopt import Optimizer
5
7
6
8
from adaptive .learner .base_learner import BaseLearner
@@ -23,13 +25,15 @@ class SKOptLearner(Optimizer, BaseLearner):
23
25
Arguments to pass to ``skopt.Optimizer``.
24
26
"""
25
27
26
- def __init__ (self , function , ** kwargs ):
28
+ def __init__ (self , function : Callable , ** kwargs ) -> None :
27
29
self .function = function
28
30
self .pending_points = set ()
29
31
self .data = collections .OrderedDict ()
30
32
super ().__init__ (** kwargs )
31
33
32
- def tell (self , x , y , fit = True ):
34
+ def tell (
35
+ self , x : Union [float64 , List [float64 ]], y : float64 , fit : bool = True
36
+ ) -> None :
33
37
if isinstance (x , collections .abc .Iterable ):
34
38
self .pending_points .discard (tuple (x ))
35
39
self .data [tuple (x )] = y
@@ -48,7 +52,7 @@ def remove_unfinished(self):
48
52
pass
49
53
50
54
@cache_latest
51
- def loss (self , real = True ):
55
+ def loss (self , real : bool = True ) -> Union [ float64 , float ] :
52
56
if not self .models :
53
57
return np .inf
54
58
else :
@@ -58,7 +62,14 @@ def loss(self, real=True):
58
62
# estimator of loss, but it is the cheapest.
59
63
return 1 - model .score (self .Xi , self .yi )
60
64
61
- def ask (self , n , tell_pending = True ):
65
+ def ask (
66
+ self , n : int , tell_pending : bool = True
67
+ ) -> Union [
68
+ Tuple [List [float64 ], List [float64 ]],
69
+ Tuple [List [List [float64 ]], List [float64 ]],
70
+ Tuple [List [List [float64 ]], List [float ]],
71
+ Tuple [List [float64 ], List [float ]],
72
+ ]:
62
73
if not tell_pending :
63
74
raise NotImplementedError (
64
75
"Asking points is an irreversible "
@@ -72,7 +83,7 @@ def ask(self, n, tell_pending=True):
72
83
return [p [0 ] for p in points ], [self .loss () / n ] * n
73
84
74
85
@property
75
- def npoints (self ):
86
+ def npoints (self ) -> int :
76
87
"""Number of evaluated points."""
77
88
return len (self .Xi )
78
89
0 commit comments