2
2
3
3
import sys
4
4
from collections import defaultdict
5
- from functools import partial
6
5
from math import sqrt
7
6
from operator import attrgetter
8
- from typing import Any , Callable , List , Optional , Set , Tuple , Union
7
+ from typing import Callable , List , Optional , Set , Tuple , Union
9
8
10
9
import numpy as np
11
- from numpy import ufunc
12
10
from scipy .linalg import norm
13
11
from sortedcontainers import SortedSet
14
12
@@ -142,11 +140,7 @@ class _Interval:
142
140
]
143
141
144
142
def __init__ (
145
- self ,
146
- a : Union [int , np .float64 ],
147
- b : Union [int , np .float64 ],
148
- depth : int ,
149
- rdepth : int ,
143
+ self , a : Union [int , float ], b : Union [int , float ], depth : int , rdepth : int ,
150
144
) -> None :
151
145
self .children = []
152
146
self .data = {}
@@ -213,7 +207,7 @@ def split(self) -> List["_Interval"]:
213
207
def calc_igral (self ) -> None :
214
208
self .igral = (self .b - self .a ) * self .c [0 ] / sqrt (2 )
215
209
216
- def update_heuristic_err (self , value : Union [ np . float64 , float ] ) -> None :
210
+ def update_heuristic_err (self , value : float ) -> None :
217
211
"""Sets the error of an interval using a heuristic (half the error of
218
212
the parent) when the actual error cannot be calculated due to its
219
213
parents not being finished yet. This error is propagated down to its
@@ -347,10 +341,7 @@ def __repr__(self) -> str:
347
341
348
342
class IntegratorLearner (BaseLearner ):
349
343
def __init__ (
350
- self ,
351
- function : Union [partial , ufunc , Callable ],
352
- bounds : Tuple [int , int ],
353
- tol : float ,
344
+ self , function : Callable , bounds : Tuple [int , int ], tol : float ,
354
345
) -> None :
355
346
"""
356
347
Parameters
@@ -403,7 +394,7 @@ def __init__(
403
394
def approximating_intervals (self ) -> Set ["_Interval" ]:
404
395
return self .first_ival .done_leaves
405
396
406
- def tell (self , point : np . float64 , value : np . float64 ) -> None :
397
+ def tell (self , point : float , value : float ) -> None :
407
398
if point not in self .x_mapping :
408
399
raise ValueError (f"Point { point } doesn't belong to any interval" )
409
400
self .data [point ] = value
@@ -460,23 +451,15 @@ def add_ival(self, ival: "_Interval") -> None:
460
451
self ._stack .append (x )
461
452
self .ivals .add (ival )
462
453
463
- def ask (
464
- self , n : int , tell_pending : bool = True
465
- ) -> Union [
466
- Tuple [List [np .float64 ], List [np .float64 ]], Tuple [List [np .float64 ], List [float ]]
467
- ]:
454
+ def ask (self , n : int , tell_pending : bool = True ) -> Tuple [List [float ], List [float ]]:
468
455
"""Choose points for learners."""
469
456
if not tell_pending :
470
457
with restore (self ):
471
458
return self ._ask_and_tell_pending (n )
472
459
else :
473
460
return self ._ask_and_tell_pending (n )
474
461
475
- def _ask_and_tell_pending (
476
- self , n : int
477
- ) -> Union [
478
- Tuple [List [np .float64 ], List [np .float64 ]], Tuple [List [np .float64 ], List [float ]]
479
- ]:
462
+ def _ask_and_tell_pending (self , n : int ) -> Tuple [List [float ], List [float ]]:
480
463
points , loss_improvements = self .pop_from_stack (n )
481
464
n_left = n - len (points )
482
465
while n_left > 0 :
@@ -492,13 +475,7 @@ def _ask_and_tell_pending(
492
475
493
476
return points , loss_improvements
494
477
495
- def pop_from_stack (
496
- self , n : int
497
- ) -> Union [
498
- Tuple [List [np .float64 ], List [np .float64 ]],
499
- Tuple [List [Any ], List [Any ]],
500
- Tuple [List [np .float64 ], List [float ]],
501
- ]:
478
+ def pop_from_stack (self , n : int ) -> Tuple [List [float ], List [float ]]:
502
479
points = self ._stack [:n ]
503
480
self ._stack = self ._stack [n :]
504
481
loss_improvements = [
@@ -509,7 +486,7 @@ def pop_from_stack(
509
486
def remove_unfinished (self ):
510
487
pass
511
488
512
- def _fill_stack (self ) -> List [np . float64 ]:
489
+ def _fill_stack (self ) -> List [float ]:
513
490
# XXX: to-do if all the ivals have err=inf, take the interval
514
491
# with the lowest rdepth and no children.
515
492
force_split = bool (self .priority_split )
@@ -550,11 +527,11 @@ def npoints(self) -> int:
550
527
return len (self .data )
551
528
552
529
@property
553
- def igral (self ) -> np . float64 :
530
+ def igral (self ) -> float :
554
531
return sum (i .igral for i in self .approximating_intervals )
555
532
556
533
@property
557
- def err (self ) -> np . float64 :
534
+ def err (self ) -> float :
558
535
if self .approximating_intervals :
559
536
err = sum (i .err for i in self .approximating_intervals )
560
537
if err > sys .float_info .max :
0 commit comments