@@ -29,15 +29,7 @@ def to_list(inp: float) -> List[float]:
29
29
return [inp ]
30
30
31
31
32
- def volume (
33
- simplex : Union [
34
- List [Tuple [float , float ]],
35
- List [Tuple [float , float ]],
36
- List [Tuple [float , float ]],
37
- np .ndarray ,
38
- ],
39
- ys : None = None ,
40
- ) -> float :
32
+ def volume (simplex : List [Tuple [float , float ]], ys : None = None ,) -> float :
41
33
# Notice the parameter ys is there so you can use this volume method as
42
34
# as loss function
43
35
matrix = np .subtract (simplex [:- 1 ], simplex [- 1 ], dtype = float )
@@ -207,13 +199,7 @@ def curvature_loss(simplex, values, value_scale, neighbors, neighbor_values):
207
199
208
200
209
201
def choose_point_in_simplex (
210
- simplex : Union [
211
- List [Union [Tuple [int , int ], Tuple [float , float ]]],
212
- List [Union [Tuple [float , float , float ], Tuple [int , int , int ]]],
213
- List [Tuple [float , float , float ]],
214
- List [Tuple [float , float ]],
215
- ],
216
- transform : Optional [np .ndarray ] = None ,
202
+ simplex : np .ndarray , transform : Optional [np .ndarray ] = None ,
217
203
) -> np .ndarray :
218
204
"""Choose a new point in inside a simplex.
219
205
@@ -318,13 +304,7 @@ class LearnerND(BaseLearner):
318
304
def __init__ (
319
305
self ,
320
306
func : Callable ,
321
- bounds : Union [
322
- Tuple [Tuple [int , int ], Tuple [int , int ], Tuple [int , int ]],
323
- np .ndarray ,
324
- Tuple [Tuple [int , int ], Tuple [int , int ]],
325
- List [Tuple [int , int ]],
326
- ConvexHull ,
327
- ],
307
+ bounds : Union [Tuple [Tuple [float , float ], ...], ConvexHull ],
328
308
loss_per_simplex : Optional [Callable ] = None ,
329
309
) -> None :
330
310
self ._vdim = None
@@ -452,17 +432,7 @@ def points(self) -> np.ndarray:
452
432
"""Get the points from `data` as a numpy array."""
453
433
return np .array (list (self .data .keys ()), dtype = float )
454
434
455
- def tell (
456
- self ,
457
- point : Union [
458
- Tuple [float , float ],
459
- Tuple [int , int ],
460
- Tuple [int , int , int ],
461
- Tuple [float , float , float ],
462
- Tuple [float , float , float ],
463
- ],
464
- value : Union [List [int ], float , float , np .ndarray ],
465
- ) -> None :
435
+ def tell (self , point : Tuple [float , ...], value : Union [float , np .ndarray ],) -> None :
466
436
point = tuple (point )
467
437
468
438
if point in self .data :
@@ -486,7 +456,7 @@ def tell(
486
456
to_delete , to_add = tri .add_point (point , simplex , transform = self ._transform )
487
457
self ._update_losses (to_delete , to_add )
488
458
489
- def _simplex_exists (self , simplex : Any ) -> bool :
459
+ def _simplex_exists (self , simplex : Any ) -> bool : # XXX: specify simplex: Any
490
460
simplex = tuple (sorted (simplex ))
491
461
return simplex in self .tri .simplices
492
462
@@ -547,9 +517,7 @@ def tell_pending(
547
517
self ._update_subsimplex_losses (simpl , to_add )
548
518
549
519
def _try_adding_pending_point_to_simplex (
550
- self ,
551
- point : Union [Tuple [float , float , float ], Tuple [float , float ]],
552
- simplex : Any ,
520
+ self , point : Tuple [float , ...], simplex : Any , # XXX: specify simplex: Any
553
521
) -> Any :
554
522
# try to insert it
555
523
if not self .tri .point_in_simplex (point , simplex ):
@@ -562,7 +530,9 @@ def _try_adding_pending_point_to_simplex(
562
530
self ._pending_to_simplex [point ] = simplex
563
531
return self ._subtriangulations [simplex ].add_point (point )
564
532
565
- def _update_subsimplex_losses (self , simplex : Any , new_subsimplices : Any ) -> None :
533
+ def _update_subsimplex_losses (
534
+ self , simplex : Any , new_subsimplices : Any
535
+ ) -> None : # XXX: specify simplex: Any
566
536
loss = self ._losses [simplex ]
567
537
568
538
loss_density = loss / self .tri .volume (simplex )
@@ -583,14 +553,7 @@ def ask(self, n: int, tell_pending: bool = True) -> Any:
583
553
else :
584
554
return self ._ask_and_tell_pending (n )
585
555
586
- def _ask_bound_point (
587
- self ,
588
- ) -> Union [
589
- Tuple [Tuple [int , int , int ], float ],
590
- Tuple [Tuple [int , int ], float ],
591
- Tuple [Tuple [float , float ], float ],
592
- Tuple [Tuple [float , float , float ], float ],
593
- ]:
556
+ def _ask_bound_point (self ,) -> Tuple [Tuple [float , ...], float ]:
594
557
# get the next bound point that is still available
595
558
new_point = next (
596
559
p
@@ -600,11 +563,7 @@ def _ask_bound_point(
600
563
self .tell_pending (new_point )
601
564
return new_point , np .inf
602
565
603
- def _ask_point_without_known_simplices (
604
- self ,
605
- ) -> Union [
606
- Tuple [Tuple [float , float ], float ], Tuple [Tuple [float , float , float ], float ],
607
- ]:
566
+ def _ask_point_without_known_simplices (self ,) -> Tuple [Tuple [float , ...], float ]:
608
567
assert not self ._bounds_available
609
568
# pick a random point inside the bounds
610
569
# XXX: change this into picking a point based on volume loss
@@ -645,11 +604,7 @@ def _pop_highest_existing_simplex(self) -> Any:
645
604
" be a simplex available if LearnerND.tri() is not None."
646
605
)
647
606
648
- def _ask_best_point (
649
- self ,
650
- ) -> Union [
651
- Tuple [Tuple [float , float ], float ], Tuple [Tuple [float , float , float ], float ],
652
- ]:
607
+ def _ask_best_point (self ,) -> Tuple [Tuple [float , ...], float ]:
653
608
assert self .tri is not None
654
609
655
610
loss , simplex , subsimplex = self ._pop_highest_existing_simplex ()
@@ -676,14 +631,7 @@ def _bounds_available(self) -> bool:
676
631
for p in self ._bounds_points
677
632
)
678
633
679
- def _ask (
680
- self ,
681
- ) -> Union [
682
- Tuple [Tuple [int , int , int ], float ],
683
- Tuple [Tuple [float , float , float ], float ],
684
- Tuple [Tuple [float , float ], float ],
685
- Tuple [Tuple [int , int ], float ],
686
- ]:
634
+ def _ask (self ,) -> Tuple [Tuple [float , ...], float ]:
687
635
if self ._bounds_available :
688
636
return self ._ask_bound_point () # O(1)
689
637
@@ -695,7 +643,7 @@ def _ask(
695
643
696
644
return self ._ask_best_point () # O(log N)
697
645
698
- def _compute_loss (self , simplex : Any ) -> float :
646
+ def _compute_loss (self , simplex : Any ) -> float : # XXX: specify simplex: Any
699
647
# get the loss
700
648
vertices = self .tri .get_vertices (simplex )
701
649
values = [self .data [tuple (v )] for v in vertices ]
0 commit comments