Skip to content

Commit de4a8c7

Browse files
committed
type hint fixes for adaptive/learner/learnerND.py
1 parent e92c4fb commit de4a8c7

File tree

1 file changed

+14
-66
lines changed

1 file changed

+14
-66
lines changed

adaptive/learner/learnerND.py

Lines changed: 14 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -29,15 +29,7 @@ def to_list(inp: float) -> List[float]:
2929
return [inp]
3030

3131

32-
def volume(
33-
simplex: Union[
34-
List[Tuple[float, float]],
35-
List[Tuple[float, float]],
36-
List[Tuple[float, float]],
37-
np.ndarray,
38-
],
39-
ys: None = None,
40-
) -> float:
32+
def volume(simplex: List[Tuple[float, float]], ys: None = None,) -> float:
4133
# Notice the parameter ys is there so you can use this volume method as
4234
# as loss function
4335
matrix = np.subtract(simplex[:-1], simplex[-1], dtype=float)
@@ -207,13 +199,7 @@ def curvature_loss(simplex, values, value_scale, neighbors, neighbor_values):
207199

208200

209201
def choose_point_in_simplex(
210-
simplex: Union[
211-
List[Union[Tuple[int, int], Tuple[float, float]]],
212-
List[Union[Tuple[float, float, float], Tuple[int, int, int]]],
213-
List[Tuple[float, float, float]],
214-
List[Tuple[float, float]],
215-
],
216-
transform: Optional[np.ndarray] = None,
202+
simplex: np.ndarray, transform: Optional[np.ndarray] = None,
217203
) -> np.ndarray:
218204
"""Choose a new point in inside a simplex.
219205
@@ -318,13 +304,7 @@ class LearnerND(BaseLearner):
318304
def __init__(
319305
self,
320306
func: Callable,
321-
bounds: Union[
322-
Tuple[Tuple[int, int], Tuple[int, int], Tuple[int, int]],
323-
np.ndarray,
324-
Tuple[Tuple[int, int], Tuple[int, int]],
325-
List[Tuple[int, int]],
326-
ConvexHull,
327-
],
307+
bounds: Union[Tuple[Tuple[float, float], ...], ConvexHull],
328308
loss_per_simplex: Optional[Callable] = None,
329309
) -> None:
330310
self._vdim = None
@@ -452,17 +432,7 @@ def points(self) -> np.ndarray:
452432
"""Get the points from `data` as a numpy array."""
453433
return np.array(list(self.data.keys()), dtype=float)
454434

455-
def tell(
456-
self,
457-
point: Union[
458-
Tuple[float, float],
459-
Tuple[int, int],
460-
Tuple[int, int, int],
461-
Tuple[float, float, float],
462-
Tuple[float, float, float],
463-
],
464-
value: Union[List[int], float, float, np.ndarray],
465-
) -> None:
435+
def tell(self, point: Tuple[float, ...], value: Union[float, np.ndarray],) -> None:
466436
point = tuple(point)
467437

468438
if point in self.data:
@@ -486,7 +456,7 @@ def tell(
486456
to_delete, to_add = tri.add_point(point, simplex, transform=self._transform)
487457
self._update_losses(to_delete, to_add)
488458

489-
def _simplex_exists(self, simplex: Any) -> bool:
459+
def _simplex_exists(self, simplex: Any) -> bool: # XXX: specify simplex: Any
490460
simplex = tuple(sorted(simplex))
491461
return simplex in self.tri.simplices
492462

@@ -547,9 +517,7 @@ def tell_pending(
547517
self._update_subsimplex_losses(simpl, to_add)
548518

549519
def _try_adding_pending_point_to_simplex(
550-
self,
551-
point: Union[Tuple[float, float, float], Tuple[float, float]],
552-
simplex: Any,
520+
self, point: Tuple[float, ...], simplex: Any, # XXX: specify simplex: Any
553521
) -> Any:
554522
# try to insert it
555523
if not self.tri.point_in_simplex(point, simplex):
@@ -562,7 +530,9 @@ def _try_adding_pending_point_to_simplex(
562530
self._pending_to_simplex[point] = simplex
563531
return self._subtriangulations[simplex].add_point(point)
564532

565-
def _update_subsimplex_losses(self, simplex: Any, new_subsimplices: Any) -> None:
533+
def _update_subsimplex_losses(
534+
self, simplex: Any, new_subsimplices: Any
535+
) -> None: # XXX: specify simplex: Any
566536
loss = self._losses[simplex]
567537

568538
loss_density = loss / self.tri.volume(simplex)
@@ -583,14 +553,7 @@ def ask(self, n: int, tell_pending: bool = True) -> Any:
583553
else:
584554
return self._ask_and_tell_pending(n)
585555

586-
def _ask_bound_point(
587-
self,
588-
) -> Union[
589-
Tuple[Tuple[int, int, int], float],
590-
Tuple[Tuple[int, int], float],
591-
Tuple[Tuple[float, float], float],
592-
Tuple[Tuple[float, float, float], float],
593-
]:
556+
def _ask_bound_point(self,) -> Tuple[Tuple[float, ...], float]:
594557
# get the next bound point that is still available
595558
new_point = next(
596559
p
@@ -600,11 +563,7 @@ def _ask_bound_point(
600563
self.tell_pending(new_point)
601564
return new_point, np.inf
602565

603-
def _ask_point_without_known_simplices(
604-
self,
605-
) -> Union[
606-
Tuple[Tuple[float, float], float], Tuple[Tuple[float, float, float], float],
607-
]:
566+
def _ask_point_without_known_simplices(self,) -> Tuple[Tuple[float, ...], float]:
608567
assert not self._bounds_available
609568
# pick a random point inside the bounds
610569
# XXX: change this into picking a point based on volume loss
@@ -645,11 +604,7 @@ def _pop_highest_existing_simplex(self) -> Any:
645604
" be a simplex available if LearnerND.tri() is not None."
646605
)
647606

648-
def _ask_best_point(
649-
self,
650-
) -> Union[
651-
Tuple[Tuple[float, float], float], Tuple[Tuple[float, float, float], float],
652-
]:
607+
def _ask_best_point(self,) -> Tuple[Tuple[float, ...], float]:
653608
assert self.tri is not None
654609

655610
loss, simplex, subsimplex = self._pop_highest_existing_simplex()
@@ -676,14 +631,7 @@ def _bounds_available(self) -> bool:
676631
for p in self._bounds_points
677632
)
678633

679-
def _ask(
680-
self,
681-
) -> Union[
682-
Tuple[Tuple[int, int, int], float],
683-
Tuple[Tuple[float, float, float], float],
684-
Tuple[Tuple[float, float], float],
685-
Tuple[Tuple[int, int], float],
686-
]:
634+
def _ask(self,) -> Tuple[Tuple[float, ...], float]:
687635
if self._bounds_available:
688636
return self._ask_bound_point() # O(1)
689637

@@ -695,7 +643,7 @@ def _ask(
695643

696644
return self._ask_best_point() # O(log N)
697645

698-
def _compute_loss(self, simplex: Any) -> float:
646+
def _compute_loss(self, simplex: Any) -> float: # XXX: specify simplex: Any
699647
# get the loss
700648
vertices = self.tri.get_vertices(simplex)
701649
values = [self.data[tuple(v)] for v in vertices]

0 commit comments

Comments
 (0)