Skip to content

Commit bdb85bf

Browse files
committed
use a normal float instead of np.float64
1 parent f4ecae6 commit bdb85bf

File tree

9 files changed

+130
-185
lines changed

9 files changed

+130
-185
lines changed

adaptive/learner/learner1D.py

Lines changed: 27 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,9 @@
2020

2121
@uses_nth_neighbors(0)
2222
def uniform_loss(
23-
xs: Union[Tuple[float, float], Tuple[np.float64, np.float64]],
24-
ys: Union[Tuple[float, float], Tuple[np.float64, np.float64]],
25-
) -> Union[np.float64, float]:
23+
xs: Union[Tuple[float, float], Tuple[float, float]],
24+
ys: Union[Tuple[float, float], Tuple[float, float]],
25+
) -> Union[float, float]:
2626
"""Loss function that samples the domain uniformly.
2727
2828
Works with `~adaptive.Learner1D` only.
@@ -43,18 +43,9 @@ def uniform_loss(
4343

4444
@uses_nth_neighbors(0)
4545
def default_loss(
46-
xs: Union[
47-
Tuple[float, float],
48-
Tuple[np.float64, float],
49-
Tuple[np.float64, np.float64],
50-
Tuple[float, np.float64],
51-
],
52-
ys: Union[
53-
Tuple[float, float],
54-
Tuple[np.ndarray, np.ndarray],
55-
Tuple[np.float64, np.float64],
56-
],
57-
) -> np.float64:
46+
xs: Tuple[float, float],
47+
ys: Union[Tuple[np.ndarray, np.ndarray], Tuple[float, float]],
48+
) -> float:
5849
"""Calculate loss on a single interval.
5950
6051
Currently returns the rescaled length of the interval. If one of the
@@ -71,7 +62,7 @@ def default_loss(
7162

7263

7364
@uses_nth_neighbors(1)
74-
def triangle_loss(xs: Any, ys: Any) -> Union[np.float64, float]:
65+
def triangle_loss(xs: Any, ys: Any) -> Union[float, float]:
7566
xs = [x for x in xs if x is not None]
7667
ys = [y for y in ys if y is not None]
7768

@@ -110,10 +101,8 @@ def curvature_loss(xs, ys):
110101

111102

112103
def linspace(
113-
x_left: Union[int, np.float64, float],
114-
x_right: Union[int, np.float64, float],
115-
n: int,
116-
) -> Union[List[float], List[np.float64]]:
104+
x_left: Union[int, float, float], x_right: Union[int, float, float], n: int,
105+
) -> Union[List[float], List[float]]:
117106
"""This is equivalent to
118107
'np.linspace(x_left, x_right, n, endpoint=False)[1:]',
119108
but it is 15-30 times faster for small 'n'."""
@@ -136,7 +125,7 @@ def _get_neighbors_from_list(xs: np.ndarray) -> SortedDict:
136125

137126

138127
def _get_intervals(
139-
x: Union[int, np.float64, float], neighbors: SortedDict, nth_neighbors: int
128+
x: Union[int, float, float], neighbors: SortedDict, nth_neighbors: int
140129
) -> Any:
141130
nn = nth_neighbors
142131
i = neighbors.index(x)
@@ -262,38 +251,36 @@ def npoints(self) -> int:
262251
return len(self.data)
263252

264253
@cache_latest
265-
def loss(self, real: bool = True) -> Union[int, np.float64, float]:
254+
def loss(self, real: bool = True) -> Union[int, float, float]:
266255
losses = self.losses if real else self.losses_combined
267256
if not losses:
268257
return np.inf
269258
max_interval, max_loss = losses.peekitem(0)
270259
return max_loss
271260

272261
def _scale_x(
273-
self, x: Optional[Union[float, int, np.float64]]
274-
) -> Optional[Union[float, np.float64]]:
262+
self, x: Optional[Union[float, int, float]]
263+
) -> Optional[Union[float, float]]:
275264
if x is None:
276265
return None
277266
return x / self._scale[0]
278267

279268
def _scale_y(
280-
self, y: Optional[Union[int, np.ndarray, np.float64, float]]
281-
) -> Optional[Union[float, np.float64, np.ndarray]]:
269+
self, y: Optional[Union[int, np.ndarray, float, float]]
270+
) -> Optional[Union[float, float, np.ndarray]]:
282271
if y is None:
283272
return None
284273
y_scale = self._scale[1] or 1
285274
return y / y_scale
286275

287-
def _get_point_by_index(self, ind: int) -> Optional[Union[int, np.float64, float]]:
276+
def _get_point_by_index(self, ind: int) -> Optional[Union[int, float, float]]:
288277
if ind < 0 or ind >= len(self.neighbors):
289278
return None
290279
return self.neighbors.keys()[ind]
291280

292281
def _get_loss_in_interval(
293-
self,
294-
x_left: Union[int, np.float64, float],
295-
x_right: Union[int, np.float64, float],
296-
) -> Union[int, np.float64, float]:
282+
self, x_left: Union[int, float, float], x_right: Union[int, float, float],
283+
) -> Union[int, float, float]:
297284
assert x_left is not None and x_right is not None
298285

299286
if x_right - x_left < self._dx_eps:
@@ -314,9 +301,7 @@ def _get_loss_in_interval(
314301
return self.loss_per_interval(xs_scaled, ys_scaled)
315302

316303
def _update_interpolated_loss_in_interval(
317-
self,
318-
x_left: Union[int, np.float64, float],
319-
x_right: Union[int, np.float64, float],
304+
self, x_left: Union[int, float, float], x_right: Union[int, float, float],
320305
) -> None:
321306
if x_left is None or x_right is None:
322307
return
@@ -333,9 +318,7 @@ def _update_interpolated_loss_in_interval(
333318
self.losses_combined[a, b] = (b - a) * loss / dx
334319
a = b
335320

336-
def _update_losses(
337-
self, x: Union[int, np.float64, float], real: bool = True
338-
) -> None:
321+
def _update_losses(self, x: Union[int, float, float], real: bool = True) -> None:
339322
"""Update all losses that depend on x"""
340323
# When we add a new point x, we should update the losses
341324
# (x_left, x_right) are the "real" neighbors of 'x'.
@@ -378,7 +361,7 @@ def _update_losses(
378361
self.losses_combined[x, b] = float("inf")
379362

380363
@staticmethod
381-
def _find_neighbors(x: Union[int, np.float64, float], neighbors: SortedDict) -> Any:
364+
def _find_neighbors(x: Union[int, float, float], neighbors: SortedDict) -> Any:
382365
if x in neighbors:
383366
return neighbors[x]
384367
pos = neighbors.bisect_left(x)
@@ -388,7 +371,7 @@ def _find_neighbors(x: Union[int, np.float64, float], neighbors: SortedDict) ->
388371
return x_left, x_right
389372

390373
def _update_neighbors(
391-
self, x: Union[int, np.float64, float], neighbors: SortedDict
374+
self, x: Union[int, float, float], neighbors: SortedDict
392375
) -> None:
393376
if x not in neighbors: # The point is new
394377
x_left, x_right = self._find_neighbors(x, neighbors)
@@ -397,9 +380,7 @@ def _update_neighbors(
397380
neighbors.get(x_right, [None, None])[0] = x
398381

399382
def _update_scale(
400-
self,
401-
x: Union[int, np.float64, float],
402-
y: Union[float, int, np.float64, np.ndarray],
383+
self, x: Union[int, float, float], y: Union[float, int, float, np.ndarray],
403384
) -> None:
404385
"""Update the scale with which the x and y-values are scaled.
405386
@@ -427,7 +408,7 @@ def _update_scale(
427408
self._bbox[1][1] = max(self._bbox[1][1], y)
428409
self._scale[1] = self._bbox[1][1] - self._bbox[1][0]
429410

430-
def tell(self, x: Union[int, np.float64, float], y: Any) -> None:
411+
def tell(self, x: Union[int, float, float], y: Any) -> None:
431412
if x in self.data:
432413
# The point is already evaluated before
433414
return
@@ -462,7 +443,7 @@ def tell(self, x: Union[int, np.float64, float], y: Any) -> None:
462443

463444
self._oldscale = deepcopy(self._scale)
464445

465-
def tell_pending(self, x: Union[int, np.float64, float]) -> None:
446+
def tell_pending(self, x: Union[int, float, float]) -> None:
466447
if x in self.data:
467448
# The point is already evaluated before
468449
return
@@ -678,7 +659,7 @@ def _set_data(self, data: Dict[Union[int, float], float]) -> None:
678659
self.tell_many(*zip(*data.items()))
679660

680661

681-
def loss_manager(x_scale: Union[int, np.float64, float]) -> ItemSortedDict:
662+
def loss_manager(x_scale: Union[int, float, float]) -> ItemSortedDict:
682663
def sort_key(ival, loss):
683664
loss, ival = finite_loss(ival, loss, x_scale)
684665
return -loss, ival
@@ -688,9 +669,7 @@ def sort_key(ival, loss):
688669

689670

690671
def finite_loss(
691-
ival: Any,
692-
loss: Union[int, np.float64, float],
693-
x_scale: Union[int, np.float64, float],
672+
ival: Any, loss: Union[int, float, float], x_scale: Union[int, float, float],
694673
) -> Any:
695674
"""Get the socalled finite_loss of an interval in order to be able to
696675
sort intervals that have infinite loss."""

adaptive/learner/learner2D.py

Lines changed: 16 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -538,9 +538,9 @@ def inside_bounds(
538538
self,
539539
xy: Union[
540540
Tuple[int, int],
541-
Tuple[np.float64, float],
542-
Tuple[np.float64, np.float64],
543-
Tuple[float, np.float64],
541+
Tuple[float, float],
542+
Tuple[float, float],
543+
Tuple[float, float],
544544
],
545545
) -> Union[bool, np.bool_]:
546546
x, y = xy
@@ -551,11 +551,11 @@ def tell(
551551
self,
552552
point: Union[
553553
Tuple[int, int],
554-
Tuple[np.float64, float],
555-
Tuple[np.float64, np.float64],
556-
Tuple[float, np.float64],
554+
Tuple[float, float],
555+
Tuple[float, float],
556+
Tuple[float, float],
557557
],
558-
value: Union[List[int], np.float64, float],
558+
value: Union[List[int], float, float],
559559
) -> None:
560560
point = tuple(point)
561561
self.data[point] = value
@@ -569,9 +569,9 @@ def tell_pending(
569569
self,
570570
point: Union[
571571
Tuple[int, int],
572-
Tuple[np.float64, float],
573-
Tuple[np.float64, np.float64],
574-
Tuple[float, np.float64],
572+
Tuple[float, float],
573+
Tuple[float, float],
574+
Tuple[float, float],
575575
],
576576
) -> None:
577577
point = tuple(point)
@@ -584,25 +584,13 @@ def tell_pending(
584584
def _fill_stack(
585585
self, stack_till: int = 1
586586
) -> Union[
587-
Tuple[List[Tuple[np.float64, np.float64]], List[np.float64]],
587+
Tuple[List[Tuple[float, float]], List[float]],
588588
Tuple[
589-
List[
590-
Union[
591-
Tuple[np.float64, np.float64],
592-
Tuple[float, np.float64],
593-
Tuple[np.float64, float],
594-
]
595-
],
596-
List[np.float64],
597-
],
598-
Tuple[
599-
List[Union[Tuple[float, np.float64], Tuple[np.float64, np.float64]]],
600-
List[np.float64],
601-
],
602-
Tuple[
603-
List[Union[Tuple[np.float64, np.float64], Tuple[np.float64, float]]],
604-
List[np.float64],
589+
List[Union[Tuple[float, float], Tuple[float, float], Tuple[float, float]]],
590+
List[float],
605591
],
592+
Tuple[List[Union[Tuple[float, float], Tuple[float, float]]], List[float]],
593+
Tuple[List[Union[Tuple[float, float], Tuple[float, float]]], List[float]],
606594
]:
607595
if len(self.data) + len(self.pending_points) < self.ndim + 1:
608596
raise ValueError("too few points...")
@@ -673,7 +661,7 @@ def ask(self, n: int, tell_pending: bool = True) -> Any:
673661
return points[:n], loss_improvements[:n]
674662

675663
@cache_latest
676-
def loss(self, real: bool = True) -> np.float64:
664+
def loss(self, real: bool = True) -> float:
677665
if not self.bounds_are_done:
678666
return np.inf
679667
ip = self.interpolator(scaled=True) if real else self._interpolator_combined()

0 commit comments

Comments
 (0)