2
2
import warnings
3
3
from collections import OrderedDict
4
4
from copy import copy
5
- from functools import partial
6
5
from math import sqrt
7
- from typing import Any , Callable , List , Optional , Tuple , Union
6
+ from typing import Any , Callable , Iterable , List , Optional , Tuple , Union
8
7
9
8
import numpy as np
10
9
from scipy import interpolate
@@ -107,7 +106,9 @@ def uniform_loss(ip: LinearNDInterpolator) -> np.ndarray:
107
106
return np .sqrt (areas (ip ))
108
107
109
108
110
- def resolution_loss_function (min_distance : int = 0 , max_distance : int = 1 ) -> Callable :
109
+ def resolution_loss_function (
110
+ min_distance : float = 0 , max_distance : float = 1
111
+ ) -> Callable :
111
112
"""Loss function that is similar to the `default_loss` function, but you
112
113
can set the maximimum and minimum size of a triangle.
113
114
@@ -353,10 +354,8 @@ class Learner2D(BaseLearner):
353
354
354
355
def __init__ (
355
356
self ,
356
- function : Union [partial , Callable ],
357
- bounds : Union [
358
- List [Tuple [int , int ]], Tuple [Tuple [int , int ], Tuple [int , int ]], np .ndarray
359
- ],
357
+ function : Callable ,
358
+ bounds : Tuple [Tuple [int , int ], Tuple [int , int ]],
360
359
loss_per_triangle : Optional [Callable ] = None ,
361
360
) -> None :
362
361
self .ndim = len (bounds )
@@ -462,7 +461,7 @@ def _data_in_bounds(self) -> Tuple[np.ndarray, np.ndarray]:
462
461
return points [inds ], values [inds ].reshape (- 1 , self .vdim )
463
462
return np .zeros ((0 , 2 )), np .zeros ((0 , self .vdim ), dtype = float )
464
463
465
- def _data_interp (self ) -> Any :
464
+ def _data_interp (self ) -> Tuple [ np . ndarray , np . ndarray ] :
466
465
if self .pending_points :
467
466
points = list (self .pending_points )
468
467
if self .bounds_are_done :
@@ -493,7 +492,7 @@ def ip(self):
493
492
)
494
493
return self .interpolator (scaled = True )
495
494
496
- def interpolator (self , * , scaled = False ) -> LinearNDInterpolator :
495
+ def interpolator (self , * , scaled : bool = False ) -> LinearNDInterpolator :
497
496
"""A `scipy.interpolate.LinearNDInterpolator` instance
498
497
containing the learner's data.
499
498
@@ -534,28 +533,13 @@ def _interpolator_combined(self) -> LinearNDInterpolator:
534
533
self ._ip_combined = interpolate .LinearNDInterpolator (points , values )
535
534
return self ._ip_combined
536
535
537
- def inside_bounds (
538
- self ,
539
- xy : Union [
540
- Tuple [int , int ],
541
- Tuple [float , float ],
542
- Tuple [float , float ],
543
- Tuple [float , float ],
544
- ],
545
- ) -> Union [bool , np .bool_ ]:
536
+ def inside_bounds (self , xy : Tuple [float , float ],) -> Union [bool , np .bool_ ]:
546
537
x , y = xy
547
538
(xmin , xmax ), (ymin , ymax ) = self .bounds
548
539
return xmin <= x <= xmax and ymin <= y <= ymax
549
540
550
541
def tell (
551
- self ,
552
- point : Union [
553
- Tuple [int , int ],
554
- Tuple [float , float ],
555
- Tuple [float , float ],
556
- Tuple [float , float ],
557
- ],
558
- value : Union [List [int ], float , float ],
542
+ self , point : Tuple [float , float ], value : Union [float , Iterable [float ]],
559
543
) -> None :
560
544
point = tuple (point )
561
545
self .data [point ] = value
@@ -565,15 +549,7 @@ def tell(
565
549
self ._ip = None
566
550
self ._stack .pop (point , None )
567
551
568
- def tell_pending (
569
- self ,
570
- point : Union [
571
- Tuple [int , int ],
572
- Tuple [float , float ],
573
- Tuple [float , float ],
574
- Tuple [float , float ],
575
- ],
576
- ) -> None :
552
+ def tell_pending (self , point : Tuple [float , float ],) -> None :
577
553
point = tuple (point )
578
554
if not self .inside_bounds (point ):
579
555
return
@@ -622,7 +598,7 @@ def _fill_stack(
622
598
623
599
return points_new , losses_new
624
600
625
- def ask (self , n : int , tell_pending : bool = True ) -> Any :
601
+ def ask (self , n : int , tell_pending : bool = True ) -> Tuple [ np . ndarray , np . ndarray ] :
626
602
# Even if tell_pending is False we add the point such that _fill_stack
627
603
# will return new points, later we remove these points if needed.
628
604
points = list (self ._stack .keys ())
0 commit comments