2
2
3
3
import sys
4
4
from collections import defaultdict
5
+ from functools import partial
5
6
from math import sqrt
6
7
from operator import attrgetter
8
+ from typing import Any , Callable , List , Optional , Set , Tuple , Union
7
9
8
10
import numpy as np
11
+ from numpy import bool_ , float64 , ndarray , ufunc
9
12
from scipy .linalg import norm
10
13
from sortedcontainers import SortedSet
11
14
30
33
)
31
34
32
35
33
- def _downdate (c , nans , depth ) :
36
+ def _downdate (c : ndarray , nans : List [ int ] , depth : int ) -> ndarray :
34
37
# This is algorithm 5 from the thesis of Pedro Gonnet.
35
38
b = b_def [depth ].copy ()
36
39
m = ns [depth ] - 1
@@ -48,7 +51,7 @@ def _downdate(c, nans, depth):
48
51
return c
49
52
50
53
51
- def _zero_nans (fx ) :
54
+ def _zero_nans (fx : ndarray ) -> List [ int ] :
52
55
"""Caution: this function modifies fx."""
53
56
nans = []
54
57
for i in range (len (fx )):
@@ -58,7 +61,7 @@ def _zero_nans(fx):
58
61
return nans
59
62
60
63
61
- def _calc_coeffs (fx , depth ) :
64
+ def _calc_coeffs (fx : ndarray , depth : int ) -> ndarray :
62
65
"""Caution: this function modifies fx."""
63
66
nans = _zero_nans (fx )
64
67
c_new = V_inv [depth ] @ fx
@@ -138,7 +141,9 @@ class _Interval:
138
141
"removed" ,
139
142
]
140
143
141
- def __init__ (self , a , b , depth , rdepth ):
144
+ def __init__ (
145
+ self , a : Union [int , float64 ], b : Union [int , float64 ], depth : int , rdepth : int
146
+ ) -> None :
142
147
self .children = []
143
148
self .data = {}
144
149
self .a = a
@@ -150,15 +155,15 @@ def __init__(self, a, b, depth, rdepth):
150
155
self .removed = False
151
156
152
157
@classmethod
153
- def make_first (cls , a , b , depth = 2 ) :
158
+ def make_first (cls , a : int , b : int , depth : int = 2 ) -> "_Interval" :
154
159
ival = _Interval (a , b , depth , rdepth = 1 )
155
160
ival .ndiv = 0
156
161
ival .parent = None
157
162
ival .err = sys .float_info .max # needed because inf/2 == inf
158
163
return ival
159
164
160
165
@property
161
- def T (self ):
166
+ def T (self ) -> ndarray :
162
167
"""Get the correct shift matrix.
163
168
164
169
Should only be called on children of a split interval.
@@ -169,24 +174,24 @@ def T(self):
169
174
assert left != right
170
175
return T_left if left else T_right
171
176
172
- def refinement_complete (self , depth ) :
177
+ def refinement_complete (self , depth : int ) -> bool :
173
178
"""The interval has all the y-values to calculate the intergral."""
174
179
if len (self .data ) < ns [depth ]:
175
180
return False
176
181
return all (p in self .data for p in self .points (depth ))
177
182
178
- def points (self , depth = None ):
183
+ def points (self , depth : Optional [ int ] = None ) -> ndarray :
179
184
if depth is None :
180
185
depth = self .depth
181
186
a = self .a
182
187
b = self .b
183
188
return (a + b ) / 2 + (b - a ) * xi [depth ] / 2
184
189
185
- def refine (self ):
190
+ def refine (self ) -> "_Interval" :
186
191
self .depth += 1
187
192
return self
188
193
189
- def split (self ):
194
+ def split (self ) -> List [ "_Interval" ] :
190
195
points = self .points ()
191
196
m = points [len (points ) // 2 ]
192
197
ivals = [
@@ -201,10 +206,10 @@ def split(self):
201
206
202
207
return ivals
203
208
204
- def calc_igral (self ):
209
+ def calc_igral (self ) -> None :
205
210
self .igral = (self .b - self .a ) * self .c [0 ] / sqrt (2 )
206
211
207
- def update_heuristic_err (self , value ) :
212
+ def update_heuristic_err (self , value : Union [ float64 , float ]) -> None :
208
213
"""Sets the error of an interval using a heuristic (half the error of
209
214
the parent) when the actual error cannot be calculated due to its
210
215
parents not being finished yet. This error is propagated down to its
@@ -217,7 +222,7 @@ def update_heuristic_err(self, value):
217
222
continue
218
223
child .update_heuristic_err (value / 2 )
219
224
220
- def calc_err (self , c_old ) :
225
+ def calc_err (self , c_old : ndarray ) -> float :
221
226
c_new = self .c
222
227
c_diff = np .zeros (max (len (c_old ), len (c_new )))
223
228
c_diff [: len (c_old )] = c_old
@@ -229,7 +234,7 @@ def calc_err(self, c_old):
229
234
child .update_heuristic_err (self .err / 2 )
230
235
return c_diff
231
236
232
- def calc_ndiv (self ):
237
+ def calc_ndiv (self ) -> None :
233
238
div = self .parent .c00 and self .c00 / self .parent .c00 > 2
234
239
self .ndiv += div
235
240
@@ -240,15 +245,17 @@ def calc_ndiv(self):
240
245
for child in self .children :
241
246
child .update_ndiv_recursively ()
242
247
243
- def update_ndiv_recursively (self ):
248
+ def update_ndiv_recursively (self ) -> None :
244
249
self .ndiv += 1
245
250
if self .ndiv > ndiv_max and 2 * self .ndiv > self .rdepth :
246
251
raise DivergentIntegralError
247
252
248
253
for child in self .children :
249
254
child .update_ndiv_recursively ()
250
255
251
- def complete_process (self , depth ):
256
+ def complete_process (
257
+ self , depth : int
258
+ ) -> Union [Tuple [bool , bool ], Tuple [bool , bool_ ]]:
252
259
"""Calculate the integral contribution and error from this interval,
253
260
and update the done leaves of all ancestor intervals."""
254
261
assert self .depth_complete is None or self .depth_complete == depth - 1
@@ -323,7 +330,7 @@ def complete_process(self, depth):
323
330
324
331
return force_split , remove
325
332
326
- def __repr__ (self ):
333
+ def __repr__ (self ) -> str :
327
334
lst = [
328
335
f"(a, b)=({ self .a :.5f} , { self .b :.5f} )" ,
329
336
f"depth={ self .depth } " ,
@@ -335,7 +342,12 @@ def __repr__(self):
335
342
336
343
337
344
class IntegratorLearner (BaseLearner ):
338
- def __init__ (self , function , bounds , tol ):
345
+ def __init__ (
346
+ self ,
347
+ function : Union [partial , ufunc , Callable ],
348
+ bounds : Tuple [int , int ],
349
+ tol : float ,
350
+ ) -> None :
339
351
"""
340
352
Parameters
341
353
----------
@@ -384,10 +396,10 @@ def __init__(self, function, bounds, tol):
384
396
self .first_ival = ival
385
397
386
398
@property
387
- def approximating_intervals (self ):
399
+ def approximating_intervals (self ) -> Set [ "_Interval" ] :
388
400
return self .first_ival .done_leaves
389
401
390
- def tell (self , point , value ) :
402
+ def tell (self , point : float64 , value : float64 ) -> None :
391
403
if point not in self .x_mapping :
392
404
raise ValueError (f"Point { point } doesn't belong to any interval" )
393
405
self .data [point ] = value
@@ -423,7 +435,7 @@ def tell(self, point, value):
423
435
def tell_pending (self ):
424
436
pass
425
437
426
- def propagate_removed (self , ival ) :
438
+ def propagate_removed (self , ival : "_Interval" ) -> None :
427
439
def _propagate_removed_down (ival ):
428
440
ival .removed = True
429
441
self .ivals .discard (ival )
@@ -433,7 +445,7 @@ def _propagate_removed_down(ival):
433
445
434
446
_propagate_removed_down (ival )
435
447
436
- def add_ival (self , ival ) :
448
+ def add_ival (self , ival : "_Interval" ) -> None :
437
449
for x in ival .points ():
438
450
# Update the mappings
439
451
self .x_mapping [x ].add (ival )
@@ -444,15 +456,19 @@ def add_ival(self, ival):
444
456
self ._stack .append (x )
445
457
self .ivals .add (ival )
446
458
447
- def ask (self , n , tell_pending = True ):
459
+ def ask (
460
+ self , n : int , tell_pending : bool = True
461
+ ) -> Union [Tuple [List [float64 ], List [float64 ]], Tuple [List [float64 ], List [float ]]]:
448
462
"""Choose points for learners."""
449
463
if not tell_pending :
450
464
with restore (self ):
451
465
return self ._ask_and_tell_pending (n )
452
466
else :
453
467
return self ._ask_and_tell_pending (n )
454
468
455
- def _ask_and_tell_pending (self , n ):
469
+ def _ask_and_tell_pending (
470
+ self , n : int
471
+ ) -> Union [Tuple [List [float64 ], List [float64 ]], Tuple [List [float64 ], List [float ]]]:
456
472
points , loss_improvements = self .pop_from_stack (n )
457
473
n_left = n - len (points )
458
474
while n_left > 0 :
@@ -468,7 +484,13 @@ def _ask_and_tell_pending(self, n):
468
484
469
485
return points , loss_improvements
470
486
471
- def pop_from_stack (self , n ):
487
+ def pop_from_stack (
488
+ self , n : int
489
+ ) -> Union [
490
+ Tuple [List [float64 ], List [float64 ]],
491
+ Tuple [List [Any ], List [Any ]],
492
+ Tuple [List [float64 ], List [float ]],
493
+ ]:
472
494
points = self ._stack [:n ]
473
495
self ._stack = self ._stack [n :]
474
496
loss_improvements = [
@@ -479,7 +501,7 @@ def pop_from_stack(self, n):
479
501
def remove_unfinished (self ):
480
502
pass
481
503
482
- def _fill_stack (self ):
504
+ def _fill_stack (self ) -> List [ float64 ] :
483
505
# XXX: to-do if all the ivals have err=inf, take the interval
484
506
# with the lowest rdepth and no children.
485
507
force_split = bool (self .priority_split )
@@ -515,16 +537,16 @@ def _fill_stack(self):
515
537
return self ._stack
516
538
517
539
@property
518
- def npoints (self ):
540
+ def npoints (self ) -> int :
519
541
"""Number of evaluated points."""
520
542
return len (self .data )
521
543
522
544
@property
523
- def igral (self ):
545
+ def igral (self ) -> float64 :
524
546
return sum (i .igral for i in self .approximating_intervals )
525
547
526
548
@property
527
- def err (self ):
549
+ def err (self ) -> float64 :
528
550
if self .approximating_intervals :
529
551
err = sum (i .err for i in self .approximating_intervals )
530
552
if err > sys .float_info .max :
0 commit comments