Skip to content

Commit e003a8b

Browse files
committed
type hint fixes for adaptive/learner/learnerND.py
1 parent e92c4fb commit e003a8b

File tree

1 file changed

+16
-87
lines changed

1 file changed

+16
-87
lines changed

adaptive/learner/learnerND.py

Lines changed: 16 additions & 87 deletions
Original file line numberDiff line numberDiff line change
@@ -29,15 +29,7 @@ def to_list(inp: float) -> List[float]:
2929
return [inp]
3030

3131

32-
def volume(
33-
simplex: Union[
34-
List[Tuple[float, float]],
35-
List[Tuple[float, float]],
36-
List[Tuple[float, float]],
37-
np.ndarray,
38-
],
39-
ys: None = None,
40-
) -> float:
32+
def volume(simplex: List[Tuple[float, float]], ys: None = None,) -> float:
4133
# Notice the parameter ys is there so you can use this volume method as
4234
# as loss function
4335
matrix = np.subtract(simplex[:-1], simplex[-1], dtype=float)
@@ -207,13 +199,7 @@ def curvature_loss(simplex, values, value_scale, neighbors, neighbor_values):
207199

208200

209201
def choose_point_in_simplex(
210-
simplex: Union[
211-
List[Union[Tuple[int, int], Tuple[float, float]]],
212-
List[Union[Tuple[float, float, float], Tuple[int, int, int]]],
213-
List[Tuple[float, float, float]],
214-
List[Tuple[float, float]],
215-
],
216-
transform: Optional[np.ndarray] = None,
202+
simplex: np.ndarray, transform: Optional[np.ndarray] = None,
217203
) -> np.ndarray:
218204
"""Choose a new point in inside a simplex.
219205
@@ -318,13 +304,7 @@ class LearnerND(BaseLearner):
318304
def __init__(
319305
self,
320306
func: Callable,
321-
bounds: Union[
322-
Tuple[Tuple[int, int], Tuple[int, int], Tuple[int, int]],
323-
np.ndarray,
324-
Tuple[Tuple[int, int], Tuple[int, int]],
325-
List[Tuple[int, int]],
326-
ConvexHull,
327-
],
307+
bounds: Union[Tuple[Tuple[float, float], ...], ConvexHull],
328308
loss_per_simplex: Optional[Callable] = None,
329309
) -> None:
330310
self._vdim = None
@@ -452,17 +432,7 @@ def points(self) -> np.ndarray:
452432
"""Get the points from `data` as a numpy array."""
453433
return np.array(list(self.data.keys()), dtype=float)
454434

455-
def tell(
456-
self,
457-
point: Union[
458-
Tuple[float, float],
459-
Tuple[int, int],
460-
Tuple[int, int, int],
461-
Tuple[float, float, float],
462-
Tuple[float, float, float],
463-
],
464-
value: Union[List[int], float, float, np.ndarray],
465-
) -> None:
435+
def tell(self, point: Tuple[float, ...], value: Union[float, np.ndarray],) -> None:
466436
point = tuple(point)
467437

468438
if point in self.data:
@@ -486,20 +456,11 @@ def tell(
486456
to_delete, to_add = tri.add_point(point, simplex, transform=self._transform)
487457
self._update_losses(to_delete, to_add)
488458

489-
def _simplex_exists(self, simplex: Any) -> bool:
459+
def _simplex_exists(self, simplex: Any) -> bool: # XXX: specify simplex: Any
490460
simplex = tuple(sorted(simplex))
491461
return simplex in self.tri.simplices
492462

493-
def inside_bounds(
494-
self,
495-
point: Union[
496-
Tuple[float, float],
497-
Tuple[float, float, float],
498-
Tuple[int, int, int],
499-
Tuple[int, int],
500-
Tuple[float, float, float],
501-
],
502-
) -> Union[bool, np.bool_]:
463+
def inside_bounds(self, point: Tuple[float, ...],) -> Union[bool, np.bool_]:
503464
"""Check whether a point is inside the bounds."""
504465
if hasattr(self, "_interior"):
505466
return self._interior.find_simplex(point, tol=1e-8) >= 0
@@ -509,17 +470,7 @@ def inside_bounds(
509470
(mn - eps) <= p <= (mx + eps) for p, (mn, mx) in zip(point, self._bbox)
510471
)
511472

512-
def tell_pending(
513-
self,
514-
point: Union[
515-
Tuple[int, int],
516-
Tuple[float, float, float],
517-
Tuple[float, float],
518-
Tuple[int, int, int],
519-
],
520-
*,
521-
simplex=None,
522-
) -> None:
473+
def tell_pending(self, point: Tuple[float, ...], *, simplex=None,) -> None:
523474
point = tuple(point)
524475
if not self.inside_bounds(point):
525476
return
@@ -547,9 +498,7 @@ def tell_pending(
547498
self._update_subsimplex_losses(simpl, to_add)
548499

549500
def _try_adding_pending_point_to_simplex(
550-
self,
551-
point: Union[Tuple[float, float, float], Tuple[float, float]],
552-
simplex: Any,
501+
self, point: Tuple[float, ...], simplex: Any, # XXX: specify simplex: Any
553502
) -> Any:
554503
# try to insert it
555504
if not self.tri.point_in_simplex(point, simplex):
@@ -562,7 +511,9 @@ def _try_adding_pending_point_to_simplex(
562511
self._pending_to_simplex[point] = simplex
563512
return self._subtriangulations[simplex].add_point(point)
564513

565-
def _update_subsimplex_losses(self, simplex: Any, new_subsimplices: Any) -> None:
514+
def _update_subsimplex_losses(
515+
self, simplex: Any, new_subsimplices: Any
516+
) -> None: # XXX: specify simplex: Any
566517
loss = self._losses[simplex]
567518

568519
loss_density = loss / self.tri.volume(simplex)
@@ -583,14 +534,7 @@ def ask(self, n: int, tell_pending: bool = True) -> Any:
583534
else:
584535
return self._ask_and_tell_pending(n)
585536

586-
def _ask_bound_point(
587-
self,
588-
) -> Union[
589-
Tuple[Tuple[int, int, int], float],
590-
Tuple[Tuple[int, int], float],
591-
Tuple[Tuple[float, float], float],
592-
Tuple[Tuple[float, float, float], float],
593-
]:
537+
def _ask_bound_point(self,) -> Tuple[Tuple[float, ...], float]:
594538
# get the next bound point that is still available
595539
new_point = next(
596540
p
@@ -600,11 +544,7 @@ def _ask_bound_point(
600544
self.tell_pending(new_point)
601545
return new_point, np.inf
602546

603-
def _ask_point_without_known_simplices(
604-
self,
605-
) -> Union[
606-
Tuple[Tuple[float, float], float], Tuple[Tuple[float, float, float], float],
607-
]:
547+
def _ask_point_without_known_simplices(self,) -> Tuple[Tuple[float, ...], float]:
608548
assert not self._bounds_available
609549
# pick a random point inside the bounds
610550
# XXX: change this into picking a point based on volume loss
@@ -645,11 +585,7 @@ def _pop_highest_existing_simplex(self) -> Any:
645585
" be a simplex available if LearnerND.tri() is not None."
646586
)
647587

648-
def _ask_best_point(
649-
self,
650-
) -> Union[
651-
Tuple[Tuple[float, float], float], Tuple[Tuple[float, float, float], float],
652-
]:
588+
def _ask_best_point(self,) -> Tuple[Tuple[float, ...], float]:
653589
assert self.tri is not None
654590

655591
loss, simplex, subsimplex = self._pop_highest_existing_simplex()
@@ -676,14 +612,7 @@ def _bounds_available(self) -> bool:
676612
for p in self._bounds_points
677613
)
678614

679-
def _ask(
680-
self,
681-
) -> Union[
682-
Tuple[Tuple[int, int, int], float],
683-
Tuple[Tuple[float, float, float], float],
684-
Tuple[Tuple[float, float], float],
685-
Tuple[Tuple[int, int], float],
686-
]:
615+
def _ask(self,) -> Tuple[Tuple[float, ...], float]:
687616
if self._bounds_available:
688617
return self._ask_bound_point() # O(1)
689618

@@ -695,7 +624,7 @@ def _ask(
695624

696625
return self._ask_best_point() # O(log N)
697626

698-
def _compute_loss(self, simplex: Any) -> float:
627+
def _compute_loss(self, simplex: Any) -> float: # XXX: specify simplex: Any
699628
# get the loss
700629
vertices = self.tri.get_vertices(simplex)
701630
values = [self.data[tuple(v)] for v in vertices]

0 commit comments

Comments
 (0)