8
8
from typing import Any , Callable , List , Optional , Set , Tuple , Union
9
9
10
10
import numpy as np
11
- from numpy import bool_ , float64 , ndarray , ufunc
11
+ from numpy import ufunc
12
12
from scipy .linalg import norm
13
13
from sortedcontainers import SortedSet
14
14
33
33
)
34
34
35
35
36
- def _downdate (c : ndarray , nans : List [int ], depth : int ) -> ndarray :
36
+ def _downdate (c : np . ndarray , nans : List [int ], depth : int ) -> np . ndarray :
37
37
# This is algorithm 5 from the thesis of Pedro Gonnet.
38
38
b = b_def [depth ].copy ()
39
39
m = ns [depth ] - 1
@@ -51,7 +51,7 @@ def _downdate(c: ndarray, nans: List[int], depth: int) -> ndarray:
51
51
return c
52
52
53
53
54
- def _zero_nans (fx : ndarray ) -> List [int ]:
54
+ def _zero_nans (fx : np . ndarray ) -> List [int ]:
55
55
"""Caution: this function modifies fx."""
56
56
nans = []
57
57
for i in range (len (fx )):
@@ -61,7 +61,7 @@ def _zero_nans(fx: ndarray) -> List[int]:
61
61
return nans
62
62
63
63
64
- def _calc_coeffs (fx : ndarray , depth : int ) -> ndarray :
64
+ def _calc_coeffs (fx : np . ndarray , depth : int ) -> np . ndarray :
65
65
"""Caution: this function modifies fx."""
66
66
nans = _zero_nans (fx )
67
67
c_new = V_inv [depth ] @ fx
@@ -142,7 +142,11 @@ class _Interval:
142
142
]
143
143
144
144
def __init__ (
145
- self , a : Union [int , float64 ], b : Union [int , float64 ], depth : int , rdepth : int
145
+ self ,
146
+ a : Union [int , np .float64 ],
147
+ b : Union [int , np .float64 ],
148
+ depth : int ,
149
+ rdepth : int ,
146
150
) -> None :
147
151
self .children = []
148
152
self .data = {}
@@ -163,7 +167,7 @@ def make_first(cls, a: int, b: int, depth: int = 2) -> "_Interval":
163
167
return ival
164
168
165
169
@property
166
- def T (self ) -> ndarray :
170
+ def T (self ) -> np . ndarray :
167
171
"""Get the correct shift matrix.
168
172
169
173
Should only be called on children of a split interval.
@@ -180,7 +184,7 @@ def refinement_complete(self, depth: int) -> bool:
180
184
return False
181
185
return all (p in self .data for p in self .points (depth ))
182
186
183
- def points (self , depth : Optional [int ] = None ) -> ndarray :
187
+ def points (self , depth : Optional [int ] = None ) -> np . ndarray :
184
188
if depth is None :
185
189
depth = self .depth
186
190
a = self .a
@@ -209,7 +213,7 @@ def split(self) -> List["_Interval"]:
209
213
def calc_igral (self ) -> None :
210
214
self .igral = (self .b - self .a ) * self .c [0 ] / sqrt (2 )
211
215
212
- def update_heuristic_err (self , value : Union [float64 , float ]) -> None :
216
+ def update_heuristic_err (self , value : Union [np . float64 , float ]) -> None :
213
217
"""Sets the error of an interval using a heuristic (half the error of
214
218
the parent) when the actual error cannot be calculated due to its
215
219
parents not being finished yet. This error is propagated down to its
@@ -222,7 +226,7 @@ def update_heuristic_err(self, value: Union[float64, float]) -> None:
222
226
continue
223
227
child .update_heuristic_err (value / 2 )
224
228
225
- def calc_err (self , c_old : ndarray ) -> float :
229
+ def calc_err (self , c_old : np . ndarray ) -> float :
226
230
c_new = self .c
227
231
c_diff = np .zeros (max (len (c_old ), len (c_new )))
228
232
c_diff [: len (c_old )] = c_old
@@ -255,7 +259,7 @@ def update_ndiv_recursively(self) -> None:
255
259
256
260
def complete_process (
257
261
self , depth : int
258
- ) -> Union [Tuple [bool , bool ], Tuple [bool , bool_ ]]:
262
+ ) -> Union [Tuple [bool , bool ], Tuple [bool , np . bool_ ]]:
259
263
"""Calculate the integral contribution and error from this interval,
260
264
and update the done leaves of all ancestor intervals."""
261
265
assert self .depth_complete is None or self .depth_complete == depth - 1
@@ -399,7 +403,7 @@ def __init__(
399
403
def approximating_intervals (self ) -> Set ["_Interval" ]:
400
404
return self .first_ival .done_leaves
401
405
402
- def tell (self , point : float64 , value : float64 ) -> None :
406
+ def tell (self , point : np . float64 , value : np . float64 ) -> None :
403
407
if point not in self .x_mapping :
404
408
raise ValueError (f"Point { point } doesn't belong to any interval" )
405
409
self .data [point ] = value
@@ -458,7 +462,9 @@ def add_ival(self, ival: "_Interval") -> None:
458
462
459
463
def ask (
460
464
self , n : int , tell_pending : bool = True
461
- ) -> Union [Tuple [List [float64 ], List [float64 ]], Tuple [List [float64 ], List [float ]]]:
465
+ ) -> Union [
466
+ Tuple [List [np .float64 ], List [np .float64 ]], Tuple [List [np .float64 ], List [float ]]
467
+ ]:
462
468
"""Choose points for learners."""
463
469
if not tell_pending :
464
470
with restore (self ):
@@ -468,7 +474,9 @@ def ask(
468
474
469
475
def _ask_and_tell_pending (
470
476
self , n : int
471
- ) -> Union [Tuple [List [float64 ], List [float64 ]], Tuple [List [float64 ], List [float ]]]:
477
+ ) -> Union [
478
+ Tuple [List [np .float64 ], List [np .float64 ]], Tuple [List [np .float64 ], List [float ]]
479
+ ]:
472
480
points , loss_improvements = self .pop_from_stack (n )
473
481
n_left = n - len (points )
474
482
while n_left > 0 :
@@ -487,9 +495,9 @@ def _ask_and_tell_pending(
487
495
def pop_from_stack (
488
496
self , n : int
489
497
) -> Union [
490
- Tuple [List [float64 ], List [float64 ]],
498
+ Tuple [List [np . float64 ], List [np . float64 ]],
491
499
Tuple [List [Any ], List [Any ]],
492
- Tuple [List [float64 ], List [float ]],
500
+ Tuple [List [np . float64 ], List [float ]],
493
501
]:
494
502
points = self ._stack [:n ]
495
503
self ._stack = self ._stack [n :]
@@ -501,7 +509,7 @@ def pop_from_stack(
501
509
def remove_unfinished (self ):
502
510
pass
503
511
504
- def _fill_stack (self ) -> List [float64 ]:
512
+ def _fill_stack (self ) -> List [np . float64 ]:
505
513
# XXX: to-do if all the ivals have err=inf, take the interval
506
514
# with the lowest rdepth and no children.
507
515
force_split = bool (self .priority_split )
@@ -542,11 +550,11 @@ def npoints(self) -> int:
542
550
return len (self .data )
543
551
544
552
@property
545
- def igral (self ) -> float64 :
553
+ def igral (self ) -> np . float64 :
546
554
return sum (i .igral for i in self .approximating_intervals )
547
555
548
556
@property
549
- def err (self ) -> float64 :
557
+ def err (self ) -> np . float64 :
550
558
if self .approximating_intervals :
551
559
err = sum (i .err for i in self .approximating_intervals )
552
560
if err > sys .float_info .max :
0 commit comments