5
5
from collections import defaultdict
6
6
from copy import deepcopy
7
7
from math import hypot
8
- from numbers import Integral as Int
9
- from numbers import Real
10
8
from typing import Callable , DefaultDict , Iterable , List , Sequence , Tuple
11
9
12
10
import numpy as np
16
14
17
15
from adaptive .learner .learner1D import Learner1D , _get_intervals
18
16
from adaptive .notebook_integration import ensure_holoviews
17
+ from adaptive .types import Int , Real
19
18
from adaptive .utils import assign_defaults , partial_function_from_dataframe
20
19
21
20
try :
@@ -99,7 +98,7 @@ def __init__(
99
98
if min_samples > max_samples :
100
99
raise ValueError ("max_samples should be larger than min_samples." )
101
100
102
- super ().__init__ (function , bounds , loss_per_interval )
101
+ super ().__init__ (function , bounds , loss_per_interval ) # type: ignore[arg-type]
103
102
104
103
self .delta = delta
105
104
self .alpha = alpha
@@ -110,7 +109,7 @@ def __init__(
110
109
111
110
# Contains all samples f(x) for each
112
111
# point x in the form {x0: {0: f_0(x0), 1: f_1(x0), ...}, ...}
113
- self ._data_samples = SortedDict ()
112
+ self ._data_samples : SortedDict [ float , dict [ int , Real ]] = SortedDict ()
114
113
# Contains the number of samples taken
115
114
# at each point x in the form {x0: n0, x1: n1, ...}
116
115
self ._number_samples = SortedDict ()
@@ -124,15 +123,14 @@ def __init__(
124
123
# form {xi: ((xii-xi)^2 + (yii-yi)^2)^0.5, ...}
125
124
self ._distances : dict [Real , float ] = decreasing_dict ()
126
125
# {xii: error[xii]/min(_distances[xi], _distances[xii], ...}
127
- self .rescaled_error : dict [Real , float ] = decreasing_dict ()
128
- self ._check_required_attributes ()
126
+ self .rescaled_error : ItemSortedDict [Real , float ] = decreasing_dict ()
129
127
130
128
def new (self ) -> AverageLearner1D :
131
129
"""Create a copy of `~adaptive.AverageLearner1D` without the data."""
132
130
return AverageLearner1D (
133
131
self .function ,
134
132
self .bounds ,
135
- self .loss_per_interval ,
133
+ self .loss_per_interval , # type: ignore[arg-type]
136
134
self .delta ,
137
135
self .alpha ,
138
136
self .neighbor_sampling ,
@@ -164,7 +162,7 @@ def to_numpy(self, mean: bool = False) -> np.ndarray:
164
162
]
165
163
)
166
164
167
- def to_dataframe (
165
+ def to_dataframe ( # type: ignore[override]
168
166
self ,
169
167
mean : bool = False ,
170
168
with_default_function_args : bool = True ,
@@ -202,10 +200,10 @@ def to_dataframe(
202
200
if not with_pandas :
203
201
raise ImportError ("pandas is not installed." )
204
202
if mean :
205
- data = sorted (self .data .items ())
203
+ data : list [ tuple [ Real , Real ]] = sorted (self .data .items ())
206
204
columns = [x_name , y_name ]
207
205
else :
208
- data = [
206
+ data : list [ tuple [ int , Real , Real ]] = [ # type: ignore[no-redef]
209
207
(seed , x , y )
210
208
for x , seed_y in sorted (self ._data_samples .items ())
211
209
for seed , y in sorted (seed_y .items ())
@@ -218,7 +216,7 @@ def to_dataframe(
218
216
assign_defaults (self .function , df , function_prefix )
219
217
return df
220
218
221
- def load_dataframe (
219
+ def load_dataframe ( # type: ignore[override]
222
220
self ,
223
221
df : pandas .DataFrame ,
224
222
with_default_function_args : bool = True ,
@@ -258,7 +256,7 @@ def load_dataframe(
258
256
self .function , df , function_prefix
259
257
)
260
258
261
- def ask (self , n : int , tell_pending : bool = True ) -> tuple [Points , list [float ]]:
259
+ def ask (self , n : int , tell_pending : bool = True ) -> tuple [Points , list [float ]]: # type: ignore[override]
262
260
"""Return 'n' points that are expected to maximally reduce the loss."""
263
261
# If some point is undersampled, resample it
264
262
if len (self ._undersampled_points ):
@@ -311,18 +309,18 @@ def _ask_for_new_point(self, n: int) -> tuple[Points, list[float]]:
311
309
new point, since in general n << min_samples and this point will need
312
310
to be resampled many more times"""
313
311
points , (loss_improvement ,) = self ._ask_points_without_adding (1 )
314
- points = [(seed , x ) for seed , x in zip (range (n ), n * points )]
312
+ seed_points = [(seed , x ) for seed , x in zip (range (n ), n * points )]
315
313
loss_improvements = [loss_improvement / n ] * n
316
- return points , loss_improvements
314
+ return seed_points , loss_improvements # type: ignore[return-value]
317
315
318
- def tell_pending (self , seed_x : Point ) -> None :
316
+ def tell_pending (self , seed_x : Point ) -> None : # type: ignore[override]
319
317
_ , x = seed_x
320
318
self .pending_points .add (seed_x )
321
319
if x not in self .data :
322
320
self ._update_neighbors (x , self .neighbors_combined )
323
321
self ._update_losses (x , real = False )
324
322
325
- def tell (self , seed_x : Point , y : Real ) -> None :
323
+ def tell (self , seed_x : Point , y : Real ) -> None : # type: ignore[override]
326
324
seed , x = seed_x
327
325
if y is None :
328
326
raise TypeError (
@@ -493,7 +491,7 @@ def _calc_error_in_mean(self, ys: Iterable[Real], y_avg: Real, n: int) -> float:
493
491
t_student = scipy .stats .t .ppf (1 - self .alpha , df = n - 1 )
494
492
return t_student * (variance_in_mean / n ) ** 0.5
495
493
496
- def tell_many (
494
+ def tell_many ( # type: ignore[override]
497
495
self , xs : Points | np .ndarray , ys : Sequence [Real ] | np .ndarray
498
496
) -> None :
499
497
# Check that all x are within the bounds
@@ -578,10 +576,10 @@ def tell_many_at_point(self, x: Real, seed_y_mapping: dict[int, Real]) -> None:
578
576
self ._update_interpolated_loss_in_interval (* interval )
579
577
self ._oldscale = deepcopy (self ._scale )
580
578
581
- def _get_data (self ) -> dict [Real , dict [Int , Real ]]:
579
+ def _get_data (self ) -> dict [Real , dict [Int , Real ]]: # type: ignore[override]
582
580
return self ._data_samples
583
581
584
- def _set_data (self , data : dict [Real , dict [Int , Real ]]) -> None :
582
+ def _set_data (self , data : dict [Real , dict [Int , Real ]]) -> None : # type: ignore[override]
585
583
if data :
586
584
for x , samples in data .items ():
587
585
self .tell_many_at_point (x , samples )
@@ -616,7 +614,7 @@ def plot(self):
616
614
return p .redim (x = {"range" : plot_bounds })
617
615
618
616
619
- def decreasing_dict () -> dict :
617
+ def decreasing_dict () -> ItemSortedDict :
620
618
"""This initialization orders the dictionary from large to small values"""
621
619
622
620
def sorting_rule (key , value ):
0 commit comments