@@ -29,15 +29,7 @@ def to_list(inp: float) -> List[float]:
29
29
return [inp ]
30
30
31
31
32
- def volume (
33
- simplex : Union [
34
- List [Tuple [float , float ]],
35
- List [Tuple [float , float ]],
36
- List [Tuple [float , float ]],
37
- np .ndarray ,
38
- ],
39
- ys : None = None ,
40
- ) -> float :
32
+ def volume (simplex : List [Tuple [float , float ]], ys : None = None ,) -> float :
41
33
# Notice the parameter ys is there so you can use this volume method as
42
34
# as loss function
43
35
matrix = np .subtract (simplex [:- 1 ], simplex [- 1 ], dtype = float )
@@ -207,13 +199,7 @@ def curvature_loss(simplex, values, value_scale, neighbors, neighbor_values):
207
199
208
200
209
201
def choose_point_in_simplex (
210
- simplex : Union [
211
- List [Union [Tuple [int , int ], Tuple [float , float ]]],
212
- List [Union [Tuple [float , float , float ], Tuple [int , int , int ]]],
213
- List [Tuple [float , float , float ]],
214
- List [Tuple [float , float ]],
215
- ],
216
- transform : Optional [np .ndarray ] = None ,
202
+ simplex : np .ndarray , transform : Optional [np .ndarray ] = None ,
217
203
) -> np .ndarray :
218
204
"""Choose a new point in inside a simplex.
219
205
@@ -318,13 +304,7 @@ class LearnerND(BaseLearner):
318
304
def __init__ (
319
305
self ,
320
306
func : Callable ,
321
- bounds : Union [
322
- Tuple [Tuple [int , int ], Tuple [int , int ], Tuple [int , int ]],
323
- np .ndarray ,
324
- Tuple [Tuple [int , int ], Tuple [int , int ]],
325
- List [Tuple [int , int ]],
326
- ConvexHull ,
327
- ],
307
+ bounds : Union [Tuple [Tuple [float , float ], ...], ConvexHull ],
328
308
loss_per_simplex : Optional [Callable ] = None ,
329
309
) -> None :
330
310
self ._vdim = None
@@ -452,17 +432,7 @@ def points(self) -> np.ndarray:
452
432
"""Get the points from `data` as a numpy array."""
453
433
return np .array (list (self .data .keys ()), dtype = float )
454
434
455
- def tell (
456
- self ,
457
- point : Union [
458
- Tuple [float , float ],
459
- Tuple [int , int ],
460
- Tuple [int , int , int ],
461
- Tuple [float , float , float ],
462
- Tuple [float , float , float ],
463
- ],
464
- value : Union [List [int ], float , float , np .ndarray ],
465
- ) -> None :
435
+ def tell (self , point : Tuple [float , ...], value : Union [float , np .ndarray ],) -> None :
466
436
point = tuple (point )
467
437
468
438
if point in self .data :
@@ -486,20 +456,11 @@ def tell(
486
456
to_delete , to_add = tri .add_point (point , simplex , transform = self ._transform )
487
457
self ._update_losses (to_delete , to_add )
488
458
489
- def _simplex_exists (self , simplex : Any ) -> bool :
459
+ def _simplex_exists (self , simplex : Any ) -> bool : # XXX: specify simplex: Any
490
460
simplex = tuple (sorted (simplex ))
491
461
return simplex in self .tri .simplices
492
462
493
- def inside_bounds (
494
- self ,
495
- point : Union [
496
- Tuple [float , float ],
497
- Tuple [float , float , float ],
498
- Tuple [int , int , int ],
499
- Tuple [int , int ],
500
- Tuple [float , float , float ],
501
- ],
502
- ) -> Union [bool , np .bool_ ]:
463
+ def inside_bounds (self , point : Tuple [float , ...],) -> Union [bool , np .bool_ ]:
503
464
"""Check whether a point is inside the bounds."""
504
465
if hasattr (self , "_interior" ):
505
466
return self ._interior .find_simplex (point , tol = 1e-8 ) >= 0
@@ -509,17 +470,7 @@ def inside_bounds(
509
470
(mn - eps ) <= p <= (mx + eps ) for p , (mn , mx ) in zip (point , self ._bbox )
510
471
)
511
472
512
- def tell_pending (
513
- self ,
514
- point : Union [
515
- Tuple [int , int ],
516
- Tuple [float , float , float ],
517
- Tuple [float , float ],
518
- Tuple [int , int , int ],
519
- ],
520
- * ,
521
- simplex = None ,
522
- ) -> None :
473
+ def tell_pending (self , point : Tuple [float , ...], * , simplex = None ,) -> None :
523
474
point = tuple (point )
524
475
if not self .inside_bounds (point ):
525
476
return
@@ -547,9 +498,7 @@ def tell_pending(
547
498
self ._update_subsimplex_losses (simpl , to_add )
548
499
549
500
def _try_adding_pending_point_to_simplex (
550
- self ,
551
- point : Union [Tuple [float , float , float ], Tuple [float , float ]],
552
- simplex : Any ,
501
+ self , point : Tuple [float , ...], simplex : Any , # XXX: specify simplex: Any
553
502
) -> Any :
554
503
# try to insert it
555
504
if not self .tri .point_in_simplex (point , simplex ):
@@ -562,7 +511,9 @@ def _try_adding_pending_point_to_simplex(
562
511
self ._pending_to_simplex [point ] = simplex
563
512
return self ._subtriangulations [simplex ].add_point (point )
564
513
565
- def _update_subsimplex_losses (self , simplex : Any , new_subsimplices : Any ) -> None :
514
+ def _update_subsimplex_losses (
515
+ self , simplex : Any , new_subsimplices : Any
516
+ ) -> None : # XXX: specify simplex: Any
566
517
loss = self ._losses [simplex ]
567
518
568
519
loss_density = loss / self .tri .volume (simplex )
@@ -583,14 +534,7 @@ def ask(self, n: int, tell_pending: bool = True) -> Any:
583
534
else :
584
535
return self ._ask_and_tell_pending (n )
585
536
586
- def _ask_bound_point (
587
- self ,
588
- ) -> Union [
589
- Tuple [Tuple [int , int , int ], float ],
590
- Tuple [Tuple [int , int ], float ],
591
- Tuple [Tuple [float , float ], float ],
592
- Tuple [Tuple [float , float , float ], float ],
593
- ]:
537
+ def _ask_bound_point (self ,) -> Tuple [Tuple [float , ...], float ]:
594
538
# get the next bound point that is still available
595
539
new_point = next (
596
540
p
@@ -600,11 +544,7 @@ def _ask_bound_point(
600
544
self .tell_pending (new_point )
601
545
return new_point , np .inf
602
546
603
- def _ask_point_without_known_simplices (
604
- self ,
605
- ) -> Union [
606
- Tuple [Tuple [float , float ], float ], Tuple [Tuple [float , float , float ], float ],
607
- ]:
547
+ def _ask_point_without_known_simplices (self ,) -> Tuple [Tuple [float , ...], float ]:
608
548
assert not self ._bounds_available
609
549
# pick a random point inside the bounds
610
550
# XXX: change this into picking a point based on volume loss
@@ -645,11 +585,7 @@ def _pop_highest_existing_simplex(self) -> Any:
645
585
" be a simplex available if LearnerND.tri() is not None."
646
586
)
647
587
648
- def _ask_best_point (
649
- self ,
650
- ) -> Union [
651
- Tuple [Tuple [float , float ], float ], Tuple [Tuple [float , float , float ], float ],
652
- ]:
588
+ def _ask_best_point (self ,) -> Tuple [Tuple [float , ...], float ]:
653
589
assert self .tri is not None
654
590
655
591
loss , simplex , subsimplex = self ._pop_highest_existing_simplex ()
@@ -676,14 +612,7 @@ def _bounds_available(self) -> bool:
676
612
for p in self ._bounds_points
677
613
)
678
614
679
- def _ask (
680
- self ,
681
- ) -> Union [
682
- Tuple [Tuple [int , int , int ], float ],
683
- Tuple [Tuple [float , float , float ], float ],
684
- Tuple [Tuple [float , float ], float ],
685
- Tuple [Tuple [int , int ], float ],
686
- ]:
615
+ def _ask (self ,) -> Tuple [Tuple [float , ...], float ]:
687
616
if self ._bounds_available :
688
617
return self ._ask_bound_point () # O(1)
689
618
@@ -695,7 +624,7 @@ def _ask(
695
624
696
625
return self ._ask_best_point () # O(log N)
697
626
698
- def _compute_loss (self , simplex : Any ) -> float :
627
+ def _compute_loss (self , simplex : Any ) -> float : # XXX: specify simplex: Any
699
628
# get the loss
700
629
vertices = self .tri .get_vertices (simplex )
701
630
values = [self .data [tuple (v )] for v in vertices ]
0 commit comments