1
1
from functools import partial
2
2
from operator import attrgetter
3
+ from typing import Callable , List , Set , Union
3
4
4
5
import numpy as np
5
6
import pytest
6
7
7
- from adaptive .learner import IntegratorLearner
8
8
from adaptive .learner .integrator_coeffs import ns
9
- from adaptive .learner .integrator_learner import DivergentIntegralError
9
+ from adaptive .learner .integrator_learner import (
10
+ DivergentIntegralError ,
11
+ IntegratorLearner ,
12
+ _Interval ,
13
+ )
10
14
11
15
from .algorithm_4 import DivergentIntegralError as A4DivergentIntegralError
12
16
from .algorithm_4 import algorithm_4 , f0 , f7 , f21 , f24 , f63 , fdiv
13
17
14
18
eps = np .spacing (1 )
15
19
16
20
17
- def run_integrator_learner (f , a , b , tol , n ):
21
+ def run_integrator_learner (
22
+ f : Union [partial , Callable ], a : int , b : int , tol : float , n : int
23
+ ) -> IntegratorLearner :
18
24
learner = IntegratorLearner (f , bounds = (a , b ), tol = tol )
19
25
for _ in range (n ):
20
26
points , _ = learner .ask (1 )
21
27
learner .tell_many (points , map (learner .function , points ))
22
28
return learner
23
29
24
30
25
- def equal_ival (ival , other , * , verbose = False ):
31
+ def equal_ival (ival : _Interval , other : _Interval , * , verbose = False ) -> bool :
26
32
"""Note: Implementing __eq__ breaks SortedContainers in some way."""
27
33
if ival .depth_complete is None :
28
34
if verbose :
@@ -42,7 +48,9 @@ def equal_ival(ival, other, *, verbose=False):
42
48
return all (same_slots )
43
49
44
50
45
- def equal_ivals (ivals , other , * , verbose = False ):
51
+ def equal_ivals (
52
+ ivals : Set [_Interval ], other : List [_Interval ], * , verbose = False
53
+ ) -> bool :
46
54
"""Note: `other` is a list of ivals."""
47
55
if len (ivals ) != len (other ):
48
56
if verbose :
@@ -56,7 +64,7 @@ def equal_ivals(ivals, other, *, verbose=False):
56
64
)
57
65
58
66
59
- def same_ivals (f , a , b , tol ) :
67
+ def same_ivals (f : Callable , a : int , b : int , tol : float ) -> bool :
60
68
igral , err , n , ivals = algorithm_4 (f , a , b , tol )
61
69
62
70
learner = run_integrator_learner (f , a , b , tol , n )
@@ -71,15 +79,15 @@ def same_ivals(f, a, b, tol):
71
79
72
80
# XXX: This *should* pass (https://github.com/python-adaptive/adaptive/issues/55)
73
81
@pytest .mark .xfail
74
- def test_that_gives_same_intervals_as_reference_implementation ():
82
+ def test_that_gives_same_intervals_as_reference_implementation () -> None :
75
83
for i , args in enumerate (
76
84
[[f0 , 0 , 3 , 1e-5 ], [f7 , 0 , 1 , 1e-6 ], [f21 , 0 , 1 , 1e-3 ], [f24 , 0 , 3 , 1e-3 ]]
77
85
):
78
86
assert same_ivals (* args ), f"Function { i } "
79
87
80
88
81
89
@pytest .mark .xfail
82
- def test_machine_precision ():
90
+ def test_machine_precision () -> None :
83
91
f , a , b , tol = [partial (f63 , alpha = 0.987654321 , beta = 0.45 ), 0 , 1 , 1e-10 ]
84
92
igral , err , n , ivals = algorithm_4 (f , a , b , tol )
85
93
@@ -92,7 +100,7 @@ def test_machine_precision():
92
100
assert equal_ivals (learner .ivals , ivals , verbose = True )
93
101
94
102
95
- def test_machine_precision2 ():
103
+ def test_machine_precision2 () -> None :
96
104
f , a , b , tol = [partial (f63 , alpha = 0.987654321 , beta = 0.45 ), 0 , 1 , 1e-10 ]
97
105
igral , err , n , ivals = algorithm_4 (f , a , b , tol )
98
106
@@ -102,7 +110,7 @@ def test_machine_precision2():
102
110
np .testing .assert_almost_equal (err , learner .err )
103
111
104
112
105
- def test_divergence ():
113
+ def test_divergence () -> None :
106
114
"""This function should raise a DivergentIntegralError."""
107
115
f , a , b , tol = fdiv , 0 , 1 , 1e-6
108
116
with pytest .raises (A4DivergentIntegralError ) as e :
@@ -114,22 +122,22 @@ def test_divergence():
114
122
run_integrator_learner (f , a , b , tol , n )
115
123
116
124
117
- def test_choosing_and_adding_points_one_by_one ():
125
+ def test_choosing_and_adding_points_one_by_one () -> None :
118
126
learner = IntegratorLearner (f24 , bounds = (0 , 3 ), tol = 1e-10 )
119
127
for _ in range (1000 ):
120
128
xs , _ = learner .ask (1 )
121
129
for x in xs :
122
130
learner .tell (x , learner .function (x ))
123
131
124
132
125
- def test_choosing_and_adding_multiple_points_at_once ():
133
+ def test_choosing_and_adding_multiple_points_at_once () -> None :
126
134
learner = IntegratorLearner (f24 , bounds = (0 , 3 ), tol = 1e-10 )
127
135
xs , _ = learner .ask (100 )
128
136
for x in xs :
129
137
learner .tell (x , learner .function (x ))
130
138
131
139
132
- def test_adding_points_and_skip_one_point ():
140
+ def test_adding_points_and_skip_one_point () -> None :
133
141
learner = IntegratorLearner (f24 , bounds = (0 , 3 ), tol = 1e-10 )
134
142
xs , _ = learner .ask (17 )
135
143
skip_x = xs [1 ]
@@ -160,7 +168,7 @@ def test_adding_points_and_skip_one_point():
160
168
161
169
# XXX: This *should* pass (https://github.com/python-adaptive/adaptive/issues/55)
162
170
@pytest .mark .xfail
163
- def test_tell_in_random_order (first_add_33 = False ):
171
+ def test_tell_in_random_order (first_add_33 : bool = False ) -> None :
164
172
from operator import attrgetter
165
173
import random
166
174
@@ -219,11 +227,11 @@ def test_tell_in_random_order(first_add_33=False):
219
227
220
228
# XXX: This *should* pass (https://github.com/python-adaptive/adaptive/issues/55)
221
229
@pytest .mark .xfail
222
- def test_tell_in_random_order_first_add_33 ():
230
+ def test_tell_in_random_order_first_add_33 () -> None :
223
231
test_tell_in_random_order (first_add_33 = True )
224
232
225
233
226
- def test_approximating_intervals ():
234
+ def test_approximating_intervals () -> None :
227
235
import random
228
236
229
237
learner = IntegratorLearner (f24 , bounds = (0 , 3 ), tol = 1e-10 )
@@ -252,7 +260,7 @@ def test_removed_choose_mutiple_points_at_once():
252
260
assert list (learner .approximating_intervals )[0 ] == learner .first_ival
253
261
254
262
255
- def test_removed_ask_one_by_one ():
263
+ def test_removed_ask_one_by_one () -> None :
256
264
with pytest .raises (RuntimeError ):
257
265
# This test should raise because integrating np.exp should be done
258
266
# after the 33th point
0 commit comments